Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refrain from using globals in Validation(). #495

Merged
merged 1 commit into from
Jan 13, 2016
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
192 changes: 85 additions & 107 deletions tools/torch/main.lua
Original file line number Diff line number Diff line change
Expand Up @@ -181,8 +181,8 @@ if opt.subtractMean ~= 'none' then
end

local classes
local confusion
local validation_confusion
local trainConfusion
local valConfusion

if opt.labels ~= '' then
logmessage.display(0,'Loading label definitions from '.. opt.labels ..' file')
Expand All @@ -195,12 +195,12 @@ if opt.labels ~= '' then
end

-- This matrix records the current confusion across classes
confusion = optim.ConfusionMatrix(classes)
trainConfusion = optim.ConfusionMatrix(classes)

-- separate validation matrix for validation data
validation_confusion = nil
valConfusion = nil
if opt.validation ~= '' then
validation_confusion = optim.ConfusionMatrix(classes)
valConfusion = optim.ConfusionMatrix(classes)
end

logmessage.display(0,'found ' .. #classes .. ' categories')
Expand All @@ -213,33 +213,35 @@ logmessage.display(0,'creating data readers')
-- (e.g. cropping, mean subtraction, mirroring) are
-- performed from separate threads
local trainDataLoader, trainSize, inputTensorShape
local validationDataLoader, valSize
local valDataLoader, valSize

if opt.train ~= '' then
-- create data loader for training dataset
trainDataLoader = DataLoader:new(4, -- num threads
package.path,
opt.dbbackend, opt.train, opt.train_labels,
opt.mirror, meanTensor,
true, -- train
opt.shuffle,
classes ~= nil -- whether this is a classification task
)
trainDataLoader = DataLoader:new(
4, -- num threads
package.path,
opt.dbbackend, opt.train, opt.train_labels,
opt.mirror, meanTensor,
true, -- train
opt.shuffle,
classes ~= nil -- whether this is a classification task
)
-- retrieve info from train DB (number of records and shape of input tensors)
trainSize, inputTensorShape = trainDataLoader:getInfo()
logmessage.display(0,'found ' .. trainSize .. ' images in train db' .. opt.train)
if opt.validation ~= '' then
local shape
validationDataLoader = DataLoader:new(4, -- num threads
package.path,
opt.dbbackend, opt.validation, opt.validation_labels,
false, -- no need to do random mirrorring
meanTensor,
false, -- train
false, -- shuffle
classes ~= nil -- whether this is a classification task
)
valSize, shape = validationDataLoader:getInfo()
valDataLoader = DataLoader:new(
4, -- num threads
package.path,
opt.dbbackend, opt.validation, opt.validation_labels,
false, -- no need to do random mirrorring
meanTensor,
false, -- train
false, -- shuffle
classes ~= nil -- whether this is a classification task
)
valSize, shape = valDataLoader:getInfo()
logmessage.display(0,'found ' .. valSize .. ' images in train db' .. opt.validation)
end
else
Expand Down Expand Up @@ -294,16 +296,16 @@ end
-- can be separate batch sizes for the training and validation
-- sets)
local trainBatchSize
local validationBatchSize
local valBatchSize
if opt.batchSize==0 then
local defaultBatchSize = 16
trainBatchSize = network.trainBatchSize or defaultBatchSize
validationBatchSize = network.validationBatchSize or defaultBatchSize
valBatchSize = network.validBatchSize or defaultBatchSize
else
trainBatchSize = opt.batchSize
validationBatchSize = opt.batchSize
valBatchSize = opt.batchSize
end
logmessage.display(0,'Train batch size is '.. trainBatchSize .. ' and validation batch size is ' .. validationBatchSize)
logmessage.display(0,'Train batch size is '.. trainBatchSize .. ' and validation batch size is ' .. valBatchSize)

-- if we were instructed to print a visualization of the model,
-- do it now and return immediately
Expand Down Expand Up @@ -387,8 +389,8 @@ if opt.crop and inputTensorShape then
opt.croplen = math.min(opt.croplen, inputTensorShape[2], inputTensorShape[3])
-- set crop length in data readers
trainDataLoader:setCropLen(opt.croplen)
if validationDataLoader then
validationDataLoader:setCropLen(opt.croplen)
if valDataLoader then
valDataLoader:setCropLen(opt.croplen)
end
end

Expand Down Expand Up @@ -485,11 +487,13 @@ if opt.optimState ~= '' then
end

-- Accumulate one batch of predictions into the training confusion matrix.
-- No-op when trainConfusion is nil (i.e. this is not a classification task,
-- so no label set / confusion matrix was created).
-- y:  network output batch (predictions)
-- yt: ground-truth label batch from the training DB
local function updateConfusion(y,yt)
    if trainConfusion ~= nil then
        trainConfusion:batchAdd(y,yt)
    end
end

local labelFunction = network.labelHook or function (input, dblabel) return dblabel end

-- Optimization configuration
logmessage.display(0,'initializing the parameters for Optimizer')
local optimizer = Optimizer{
Expand All @@ -501,7 +505,7 @@ local optimizer = Optimizer{
Parameters = {Weights, Gradients},
HookFunction = COMPUTE_TRAIN_ACCURACY and updateConfusion or nil,
lrPolicy = lrpolicy,
LabelFunction = network.labelHook or function (input,dblabel) return dblabel end,
LabelFunction = labelFunction,
}

-- During training, loss rate should be displayed at max 8 times or for every 5000 images, whichever lower.
Expand Down Expand Up @@ -582,70 +586,76 @@ end


-- Validation function
-- Run one full validation pass over the validation set.
-- Everything the function needs is passed in explicitly (no globals except
-- `opt` for the tensor type and `logmessage` for output):
--   model          - the network to evaluate (switched to evaluate mode here)
--   loss           - criterion; loss:forward(y, labels) yields the batch loss
--   epoch          - epoch number, used only for the log message
--   data_loader    - DataLoader for the validation DB (acceptsjob /
--                    scheduleNextBatch / waitNext protocol)
--   data_size      - number of records in the validation DB
--   batch_size     - validation mini-batch size
--   confusion      - optim.ConfusionMatrix or nil (nil when not a
--                    classification task); zeroed, filled and logged here
--   label_function - maps (inputs, db labels) -> labels fed to the criterion
-- Logs average loss (and accuracy, when a confusion matrix is given) at the
-- end of the pass; returns nothing.
local function Validation(model, loss, epoch, data_loader, data_size, batch_size, confusion, label_function)

    -- switch model to evaluation mode (disables dropout etc.)
    model:evaluate()

    local batch_count = 0
    local loss_sum = 0
    local data_index = 1
    local data = {}

    if confusion ~= nil then
        confusion:zero()
    end

    local count = 1
    while count <= data_size do

        -- schedule mini-batch load jobs while the loader has free slots
        while data_loader:acceptsjob() do
            -- last batch may be smaller than batch_size
            local curr_batch_size = math.min(data_size - data_index + 1, batch_size)
            if curr_batch_size > 0 then
                data_loader:scheduleNextBatch(curr_batch_size, data_index, data, true)
                data_index = data_index + curr_batch_size
            else break end
        end

        -- wait for next data loader job to complete
        data_loader:waitNext()

        -- get data from last load job
        local data_batch_size = data.batchSize
        local inputs = data.inputs
        local targets = data.outputs

        if inputs ~= nil then
            -- move batch to the configured tensor type
            if opt.type == 'cuda' then
                inputs = inputs:cuda()
                targets = targets:cuda()
            else
                inputs = inputs:float()
                targets = targets:float()
            end

            local y = model:forward(inputs)
            local labels = label_function(inputs, targets)
            local err = loss:forward(y, labels)
            loss_sum = loss_sum + err
            if confusion ~= nil then
                confusion:batchAdd(y, labels)
            end

            batch_count = batch_count + 1
            -- reclaim memory periodically; Torch tensors churn a lot
            if math.fmod(batch_count, 50) == 0 then
                collectgarbage()
            end

            count = count + data_batch_size
        else
            -- failed to read from database (possibly due to disabled thread):
            -- rewind the index so the records are re-scheduled
            data_index = data_index - data_batch_size
        end
    end

    -- guard against division by zero when no batch was processed
    local avg_loss = batch_count > 0 and loss_sum / batch_count or 0
    if confusion ~= nil then
        confusion:updateValids()
        logmessage.display(0, 'Validation (epoch ' .. epoch .. '): loss = ' .. avg_loss .. ', accuracy = ' .. confusion.totalValid)
    else
        logmessage.display(0, 'Validation (epoch ' .. epoch .. '): loss = ' .. avg_loss)
    end
end

-- Train function
Expand Down Expand Up @@ -733,17 +743,7 @@ local function Train(epoch, dataLoader)
end

if opt.validation ~= '' and current_epoch >= next_validation then
if validation_confusion ~= nil then
validation_confusion:zero()
end
local avg_loss=Validation(validationDataLoader, model, current_epoch)
-- log details at the end of validation
if validation_confusion ~= nil then
validation_confusion:updateValids()
logmessage.display(0, 'Validation (epoch ' .. current_epoch .. '): loss = ' .. avg_loss .. ', accuracy = ' .. validation_confusion.totalValid)
else
logmessage.display(0, 'Validation (epoch ' .. current_epoch .. '): loss = ' .. avg_loss )
end
Validation(model, loss, current_epoch, valDataLoader, valSize, valBatchSize, valConfusion, labelFunction)

next_validation = (utils.round(current_epoch/opt.interval) + 1) * opt.interval -- To find next nearest epoch value that exactly divisible by opt.interval
last_validation_epoch = current_epoch
Expand Down Expand Up @@ -781,48 +781,26 @@ logmessage.display(0,'started training the model')

-- run an initial validation before the first train epoch
if opt.validation ~= '' then
model:evaluate()
if validation_confusion ~= nil then
validation_confusion:zero()
end
local avg_loss=Validation(validationDataLoader, model, 0)
-- log details at the end of validation
if validation_confusion ~= nil then
validation_confusion:updateValids()
logmessage.display(0, 'Validation (epoch ' .. epoch-1 .. '): loss = ' .. avg_loss .. ', accuracy = ' .. validation_confusion.totalValid)
else
logmessage.display(0, 'Validation (epoch ' .. epoch-1 .. '): loss = ' .. avg_loss )
end
model:training() -- to reset model to training
Validation(model, loss, 0, valDataLoader, valSize, valBatchSize, valConfusion, labelFunction)
end

while epoch<=opt.epoch do
local ErrTrain = 0
if confusion ~= nil then
confusion:zero()
if trainConfusion ~= nil then
trainConfusion:zero()
end
Train(epoch, trainDataLoader)
if confusion ~= nil then
confusion:updateValids()
--print(confusion)
ErrTrain = (1-confusion.totalValid)
if trainConfusion ~= nil then
trainConfusion:updateValids()
--print(trainConfusion)
ErrTrain = (1-trainConfusion.totalValid)
end
epoch = epoch+1
end

-- if required, perform validation at the end
if opt.validation ~= '' and opt.epoch > last_validation_epoch then
if validation_confusion ~= nil then
validation_confusion:zero()
end
local avg_loss=Validation(validationDataLoader, model, opt.epoch)
-- log details at the end of validation
if validation_confusion ~= nil then
validation_confusion:updateValids()
logmessage.display(0, 'Validation (epoch ' .. opt.epoch .. '): loss = ' .. avg_loss .. ', accuracy = ' .. validation_confusion.totalValid)
else
logmessage.display(0, 'Validation (epoch ' .. opt.epoch .. '): loss = ' .. avg_loss )
end
Validation(model, loss, opt.epoch, valDataLoader, valSize, valBatchSize, valConfusion, labelFunction)
end

-- if required, save snapshot at the end
Expand All @@ -836,7 +814,7 @@ end
-- close databases
trainDataLoader:close()
if opt.validation ~= '' then
validationDataLoader:close()
valDataLoader:close()
end

-- enforce clean exit
Expand Down