Skip to content

Commit

Permalink
checkpoint and nnp for imagenet
Browse files Browse the repository at this point in the history
  • Loading branch information
TE-ShreenidhiRamachnadran authored and TakuyaYashima committed Dec 3, 2019
1 parent 62ff854 commit 1083cab
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 4 deletions.
24 changes: 24 additions & 0 deletions imagenet-classification/_checkpoint_nnp_util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# Copyright (c) 2017 Sony Corporation. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import

import os
import sys

# Import save/load_checkpoint from utils
sys.path.append(os.path.join(
os.path.dirname(__file__), '../utils/'))
from checkpoint_util import save_checkpoint, load_checkpoint
from save_nnp import save_nnp
2 changes: 2 additions & 0 deletions imagenet-classification/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,8 @@ def parse_tuple(x):
help="Random area of the RandomResizedCrop augmentation.")
parser.add_argument("--num-threads", "-N", type=int, default=num_threads,
help="DALI's the number of threads.")
parser.add_argument("--checkpoint", type=str, default=None,
help='path to checkpoint file')

args = parser.parse_args()
if not os.path.isdir(args.model_save_path):
Expand Down
27 changes: 23 additions & 4 deletions imagenet-classification/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,13 @@
import nnabla.parametric_functions as PF
import nnabla.solvers as S

from _checkpoint_nnp_util import load_checkpoint, save_checkpoint, save_nnp
from args import get_args
from tiny_imagenet_data import data_iterator_tiny_imagenet
from imagenet_data import data_iterator_imagenet
import model_resnet

import model_resnet
import nnabla.utils.save as save
import os
from collections import namedtuple

Expand Down Expand Up @@ -135,10 +137,21 @@ def train():

v_e = F.mean(F.top_n_error(v_pred2, v_model.label))

# Save_nnp_Epoch0
contents = save_nnp({'x': v_model.image}, {
'y': v_model.pred}, args.batch_size)
save.save(os.path.join(args.model_save_path,
'Imagenet_result_epoch0.nnp'), contents)

# Create Solver.
solver = S.Momentum(args.learning_rate, 0.9)
solver.set_parameters(nn.get_parameters())

start_point = 0
if args.checkpoint is not None:
# load weights and solver state info from specified checkpoint file.
start_point = load_checkpoint(args.checkpoint, solver)

# Create monitor.
import nnabla.monitor as M
monitor = M.Monitor(args.monitor_path)
Expand All @@ -151,11 +164,11 @@ def train():
"Validation time", monitor, interval=10)

# Training loop.
for i in range(args.max_iter):
for i in range(start_point, args.max_iter):
# Save parameters
if i % args.model_save_interval == 0:
nn.save_parameters(os.path.join(
args.model_save_path, 'param_%06d.h5' % i))
# save checkpoint file
save_checkpoint(args.model_save_path, i, solver)

# Validation
if i % args.val_interval == 0 and i != 0:
Expand Down Expand Up @@ -217,6 +230,12 @@ def accumulate_error(l, e, t_model, t_e):
nn.save_parameters(os.path.join(args.model_save_path,
'param_%06d.h5' % args.max_iter))

# Save_nnp
contents = save_nnp({'x': v_model.image}, {
'y': v_model.pred}, args.batch_size)
save.save(os.path.join(args.model_save_path,
'Imagenet_result.nnp'), contents)


if __name__ == '__main__':
train()

0 comments on commit 1083cab

Please sign in to comment.