diff --git a/tools/torch/main.lua b/tools/torch/main.lua index 63f4ebe60..38719874a 100644 --- a/tools/torch/main.lua +++ b/tools/torch/main.lua @@ -181,8 +181,8 @@ if opt.subtractMean ~= 'none' then end local classes -local confusion -local validation_confusion +local trainConfusion +local valConfusion if opt.labels ~= '' then logmessage.display(0,'Loading label definitions from '.. opt.labels ..' file') @@ -195,12 +195,12 @@ if opt.labels ~= '' then end -- This matrix records the current confusion across classes - confusion = optim.ConfusionMatrix(classes) + trainConfusion = optim.ConfusionMatrix(classes) -- separate validation matrix for validation data - validation_confusion = nil + valConfusion = nil if opt.validation ~= '' then - validation_confusion = optim.ConfusionMatrix(classes) + valConfusion = optim.ConfusionMatrix(classes) end logmessage.display(0,'found ' .. #classes .. ' categories') @@ -213,33 +213,35 @@ logmessage.display(0,'creating data readers') -- (e.g. cropping, mean subtraction, mirroring) are -- performed from separate threads local trainDataLoader, trainSize, inputTensorShape -local validationDataLoader, valSize +local valDataLoader, valSize if opt.train ~= '' then -- create data loader for training dataset - trainDataLoader = DataLoader:new(4, -- num threads - package.path, - opt.dbbackend, opt.train, opt.train_labels, - opt.mirror, meanTensor, - true, -- train - opt.shuffle, - classes ~= nil -- whether this is a classification task - ) + trainDataLoader = DataLoader:new( + 4, -- num threads + package.path, + opt.dbbackend, opt.train, opt.train_labels, + opt.mirror, meanTensor, + true, -- train + opt.shuffle, + classes ~= nil -- whether this is a classification task + ) -- retrieve info from train DB (number of records and shape of input tensors) trainSize, inputTensorShape = trainDataLoader:getInfo() logmessage.display(0,'found ' .. trainSize .. ' images in train db' .. opt.train) if opt.validation ~= '' then local shape - validationDataLoader = DataLoader:new(4, -- num threads - package.path, - opt.dbbackend, opt.validation, opt.validation_labels, - false, -- no need to do random mirrorring - meanTensor, - false, -- train - false, -- shuffle - classes ~= nil -- whether this is a classification task - ) - valSize, shape = validationDataLoader:getInfo() + valDataLoader = DataLoader:new( + 4, -- num threads + package.path, + opt.dbbackend, opt.validation, opt.validation_labels, + false, -- no need to do random mirrorring + meanTensor, + false, -- train + false, -- shuffle + classes ~= nil -- whether this is a classification task + ) + valSize, shape = valDataLoader:getInfo() logmessage.display(0,'found ' .. valSize .. ' images in train db' .. opt.validation) end else @@ -294,16 +296,16 @@ end -- can be separate batch sizes for the training and validation -- sets) local trainBatchSize -local validationBatchSize +local valBatchSize if opt.batchSize==0 then local defaultBatchSize = 16 trainBatchSize = network.trainBatchSize or defaultBatchSize - validationBatchSize = network.validationBatchSize or defaultBatchSize + valBatchSize = network.validBatchSize or defaultBatchSize else trainBatchSize = opt.batchSize - validationBatchSize = opt.batchSize + valBatchSize = opt.batchSize end -logmessage.display(0,'Train batch size is '.. trainBatchSize .. ' and validation batch size is ' .. validationBatchSize) +logmessage.display(0,'Train batch size is '.. trainBatchSize .. ' and validation batch size is ' .. valBatchSize) -- if we were instructed to print a visualization of the model, -- do it now and return immediately @@ -387,8 +389,8 @@ if opt.crop and inputTensorShape then opt.croplen = math.min(opt.croplen, inputTensorShape[2], inputTensorShape[3]) -- set crop length in data readers trainDataLoader:setCropLen(opt.croplen) - if validationDataLoader then - validationDataLoader:setCropLen(opt.croplen) + if valDataLoader then + valDataLoader:setCropLen(opt.croplen) end end @@ -485,11 +487,13 @@ if opt.optimState ~= '' then end local function updateConfusion(y,yt) - if confusion ~= nil then - confusion:batchAdd(y,yt) + if trainConfusion ~= nil then + trainConfusion:batchAdd(y,yt) end end +local labelFunction = network.labelHook or function (input, dblabel) return dblabel end + -- Optimization configuration logmessage.display(0,'initializing the parameters for Optimizer') local optimizer = Optimizer{ @@ -501,7 +505,7 @@ local optimizer = Optimizer{ Parameters = {Weights, Gradients}, HookFunction = COMPUTE_TRAIN_ACCURACY and updateConfusion or nil, lrPolicy = lrpolicy, - LabelFunction = network.labelHook or function (input,dblabel) return dblabel end, + LabelFunction = labelFunction, } -- During training, loss rate should be displayed at max 8 times or for every 5000 images, whichever lower. @@ -582,70 +586,76 @@ end -- Validation function -local function Validation(dataLoader, nn_model, epoch) +local function Validation(model, loss, epoch, data_loader, data_size, batch_size, confusion, label_function) -- switch model to evaluation mode - nn_model:evaluate() + model:evaluate() - local NumBatches = 0 + local batch_count = 0 local loss_sum = 0 - local inputs, targets - local dataLoaderIdx = 1 + local data_index = 1 local data = {} - local t = 1 - while t <= valSize do + if confusion ~= nil then + confusion:zero() + end - -- create mini batch - NumBatches = NumBatches + 1 + local count = 1 + while count <= data_size do - while dataLoader:acceptsjob() do - local dataBatchSize = math.min(valSize-dataLoaderIdx+1,validationBatchSize) - if dataBatchSize > 0 then - dataLoader:scheduleNextBatch(dataBatchSize, dataLoaderIdx, data, true) - dataLoaderIdx = dataLoaderIdx + dataBatchSize + -- create mini batch + while data_loader:acceptsjob() do + local curr_batch_size = math.min(data_size - data_index + 1, batch_size) + if curr_batch_size > 0 then + data_loader:scheduleNextBatch(curr_batch_size, data_index, data, true) + data_index = data_index + curr_batch_size else break end end -- wait for next data loader job to complete - dataLoader:waitNext() + data_loader:waitNext() -- get data from last load job - local thisBatchSize = data.batchSize - inputs = data.inputs - targets = data.outputs + local data_batch_size = data.batchSize + local inputs = data.inputs + local targets = data.outputs - if inputs then - if opt.type =='cuda' then - inputs=inputs:cuda() + if inputs ~= nil then + if opt.type == 'cuda' then + inputs = inputs:cuda() targets = targets:cuda() else - inputs=inputs:float() + inputs = inputs:float() targets = targets:float() end local y = model:forward(inputs) - local labels = network.labelHook and network.labelHook(inputs, targets) or targets - local err = loss:forward(y,labels) + local labels = label_function(inputs, targets) + local err = loss:forward(y, labels) loss_sum = loss_sum + err - if validation_confusion then - validation_confusion:batchAdd(y,labels) + if confusion ~= nil then + confusion:batchAdd(y, labels) end - if math.fmod(NumBatches,50)==0 then + batch_count = batch_count + 1 + if math.fmod(batch_count, 50) == 0 then collectgarbage() end - t = t + thisBatchSize + count = count + data_batch_size else -- failed to read from database (possibly due to disabled thread) - dataLoaderIdx = dataLoaderIdx - data.batchSize + data_index = data_index - data_batch_size end end - return (loss_sum/NumBatches) - - --xlua.progress(valSize, valSize) + local avg_loss = batch_count > 0 and loss_sum / batch_count or 0 + if confusion ~= nil then + confusion:updateValids() + logmessage.display(0, 'Validation (epoch ' .. epoch .. '): loss = ' .. avg_loss .. ', accuracy = ' .. confusion.totalValid) + else + logmessage.display(0, 'Validation (epoch ' .. epoch .. '): loss = ' .. avg_loss) + end end -- Train function @@ -733,17 +743,7 @@ local function Train(epoch, dataLoader) end if opt.validation ~= '' and current_epoch >= next_validation then - if validation_confusion ~= nil then - validation_confusion:zero() - end - local avg_loss=Validation(validationDataLoader, model, current_epoch) - -- log details at the end of validation - if validation_confusion ~= nil then - validation_confusion:updateValids() - logmessage.display(0, 'Validation (epoch ' .. current_epoch .. '): loss = ' .. avg_loss .. ', accuracy = ' .. validation_confusion.totalValid) - else - logmessage.display(0, 'Validation (epoch ' .. current_epoch .. '): loss = ' .. avg_loss ) - end + Validation(model, loss, current_epoch, valDataLoader, valSize, valBatchSize, valConfusion, labelFunction) next_validation = (utils.round(current_epoch/opt.interval) + 1) * opt.interval -- To find next nearest epoch value that exactly divisible by opt.interval last_validation_epoch = current_epoch @@ -781,48 +781,26 @@ logmessage.display(0,'started training the model') -- run an initial validation before the first train epoch if opt.validation ~= '' then - model:evaluate() - if validation_confusion ~= nil then - validation_confusion:zero() - end - local avg_loss=Validation(validationDataLoader, model, 0) - -- log details at the end of validation - if validation_confusion ~= nil then - validation_confusion:updateValids() - logmessage.display(0, 'Validation (epoch ' .. epoch-1 .. '): loss = ' .. avg_loss .. ', accuracy = ' .. validation_confusion.totalValid) - else - logmessage.display(0, 'Validation (epoch ' .. epoch-1 .. '): loss = ' .. avg_loss ) - end - model:training() -- to reset model to training + Validation(model, loss, 0, valDataLoader, valSize, valBatchSize, valConfusion, labelFunction) end while epoch<=opt.epoch do local ErrTrain = 0 - if confusion ~= nil then - confusion:zero() + if trainConfusion ~= nil then + trainConfusion:zero() end Train(epoch, trainDataLoader) - if confusion ~= nil then - confusion:updateValids() - --print(confusion) - ErrTrain = (1-confusion.totalValid) + if trainConfusion ~= nil then + trainConfusion:updateValids() + --print(trainConfusion) + ErrTrain = (1-trainConfusion.totalValid) end epoch = epoch+1 end -- if required, perform validation at the end if opt.validation ~= '' and opt.epoch > last_validation_epoch then - if validation_confusion ~= nil then - validation_confusion:zero() - end - local avg_loss=Validation(validationDataLoader, model, opt.epoch) - -- log details at the end of validation - if validation_confusion ~= nil then - validation_confusion:updateValids() - logmessage.display(0, 'Validation (epoch ' .. opt.epoch .. '): loss = ' .. avg_loss .. ', accuracy = ' .. validation_confusion.totalValid) - else - logmessage.display(0, 'Validation (epoch ' .. opt.epoch .. '): loss = ' .. avg_loss ) - end + Validation(model, loss, opt.epoch, valDataLoader, valSize, valBatchSize, valConfusion, labelFunction) end -- if required, save snapshot at the end @@ -836,7 +814,7 @@ end -- close databases trainDataLoader:close() if opt.validation ~= '' then - validationDataLoader:close() + valDataLoader:close() end -- enforce clean exit