Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CPU support for pre-trained models. Working fine without issue. #19

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
975 changes: 975 additions & 0 deletions Demo.ipynb

Large diffs are not rendered by default.

Binary file added __pycache__/config.cpython-37.pyc
Binary file not shown.
Binary file added __pycache__/dataloader.cpython-37.pyc
Binary file not shown.
Binary file added __pycache__/inference_demo_helper.cpython-37.pyc
Binary file not shown.
1 change: 1 addition & 0 deletions arch.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
7
55 changes: 15 additions & 40 deletions config.py
Original file line number Diff line number Diff line change
@@ -1,47 +1,22 @@
import argparse

def getConfig():
parser = argparse.ArgumentParser()
parser.add_argument('action', type=str, default='train', help='Model Training or Testing options')
parser.add_argument('--exp_num', default=0, type=str, help='experiment_number')
parser.add_argument('--dataset', type=str, default='DUTS', help='DUTS')
parser.add_argument('--data_path', type=str, default='data/')

# Model parameter settings
parser.add_argument('--arch', type=str, default='0', help='Backbone Architecture')
parser.add_argument('--channels', type=list, default=[24, 40, 112, 320])
parser.add_argument('--RFB_aggregated_channel', type=int, nargs='*', default=[32, 64, 128])
parser.add_argument('--frequency_radius', type=int, default=16, help='Frequency radius r in FFT')
parser.add_argument('--denoise', type=float, default=0.93, help='Denoising background ratio')
parser.add_argument('--gamma', type=float, default=0.1, help='Confidence ratio')

# Training parameter settings
parser.add_argument('--img_size', type=int, default=320)
parser.add_argument('--batch_size', type=int, default=32)
parser.add_argument('--epochs', type=int, default=100)
parser.add_argument('--lr', type=float, default=5e-5)
parser.add_argument('--optimizer', type=str, default='Adam')
parser.add_argument('--weight_decay', type=float, default=1e-4)
parser.add_argument('--criterion', type=str, default='API', help='API or bce')
parser.add_argument('--scheduler', type=str, default='Reduce', help='Reduce or Step')
parser.add_argument('--aug_ver', type=int, default=2, help='1=Normal, 2=Hard')
parser.add_argument('--lr_factor', type=float, default=0.1)
parser.add_argument('--clipping', type=float, default=2, help='Gradient clipping')
parser.add_argument('--patience', type=int, default=5, help="Scheduler ReduceLROnPlateau's parameter & Early Stopping(+5)")
parser.add_argument('--model_path', type=str, default='results/')
parser.add_argument('--seed', type=int, default=42)
parser.add_argument('--save_map', type=bool, default=None, help='Save prediction map')
class DummyArgs():
def __init__(self, arch = 7):
d = {0:320, 1:320, 2:352, 3:384, 4:448, 5:512, 6:576, 7:640}
self.arch = str(arch)
self.channels = [24, 40, 112, 320]
self.RFB_aggregated_channel = [32, 64, 128]
self.frequency_radius = 16
self.denoise = 0.93
self.gamma = 0.1
self.multi_gpu = False
self.img_size = d[int(arch)] # image_size is based on architecture


# Hardware settings
parser.add_argument('--multi_gpu', type=bool, default=True)
parser.add_argument('--num_workers', type=int, default=4)
cfg = parser.parse_args()

return cfg
def getConfig():
with open ('./arch.txt') as f: arch = int(f.read())
return DummyArgs(arch)


if __name__ == '__main__':
cfg = getConfig()
cfg = vars(cfg)
print(cfg)
cfg = getConfig()
Binary file added image.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
4 changes: 2 additions & 2 deletions inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,10 @@ def __init__(self, args, save_path):

# Network
self.model = TRACER(args).to(self.device)
if args.multi_gpu:
if args.multi_gpu or self.device.type == 'cpu': # original code does not infer with CPU conditions because it was saved with nn.DataParallel
self.model = nn.DataParallel(self.model).to(self.device)

path = load_pretrained(f'TE-{args.arch}')
path = load_pretrained(f'TE-{args.arch}', self.device)
self.model.load_state_dict(path)
print('###### pre-trained Model restored #####')

Expand Down
84 changes: 84 additions & 0 deletions inference_demo_helper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
"""
author: Min Seok Lee and Wooseok Shin
"""
from PIL import Image
import cv2
import numpy as np
import torch
import torch.nn.functional as F
from dataloader import get_test_augmentation
from model.TRACER import TRACER
from util.utils import load_pretrained
import torch.nn as nn
import urllib
from torchvision.transforms import transforms


class Inference():
def __init__(self, args):
super(Inference, self).__init__()
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
self.transform = get_test_augmentation(img_size=args.img_size)
self.args = args

self.invTrans = transforms.Compose([ transforms.Normalize(mean = [ 0., 0., 0. ],
std = [ 1/0.229, 1/0.224, 1/0.225 ]),
transforms.Normalize(mean = [ -0.485, -0.456, -0.406 ],
std = [ 1., 1., 1. ]),
])

# Network
self.model = TRACER(args).to(self.device)
self.model = nn.DataParallel(self.model).to(self.device)

path = load_pretrained(f'TE-{args.arch}', self.device)
self.model.load_state_dict(path)
self.model.eval()
print('###### pre-trained Model restored #####')


def test(self, image):
if isinstance(image, Image.Image):
image = np.array(image)

elif isinstance(image, str): # if path or URL
if "http" in image or "https" in image:
req = urllib.request.urlopen(image)
arr = np.asarray(bytearray(req.read()), dtype=np.uint8)
image = cv2.imdecode(arr, -1) # 'Load it as it is'

else: # if path in directory
image = cv2.imread(image)

image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
h, w = image.shape[:2]

image = self.transform(image=image)['image']

with torch.no_grad():
image = torch.tensor(image.unsqueeze(0), device=self.device, dtype=torch.float32)

output, edge_mask, ds_map = self.model(image)
output = F.interpolate(output, size=(h, w), mode='bilinear')
output = (output.squeeze().detach().cpu().numpy() * 255.0).astype(np.uint8) # convert uint8 type

salient_object = self.post_processing(image, output, h, w)
return output, salient_object


def post_processing(self, original_image, output_image, height, width, threshold=200):

original_image = self.invTrans(original_image)

original_image = F.interpolate(original_image, size=(height, width), mode='bilinear')
original_image = (original_image.squeeze().permute(1,2,0).detach().cpu().numpy() * 255.0).astype(np.uint8)

rgba_image = cv2.cvtColor(original_image, cv2.COLOR_BGR2BGRA)
output_rbga_image = cv2.cvtColor(output_image, cv2.COLOR_BGR2BGRA)

output_rbga_image[:, :, 3] = output_image # Extract edges
edge_y, edge_x, _ = np.where(output_rbga_image <= threshold) # Edge coordinates

rgba_image[edge_y, edge_x, 3] = 0
return rgba_image

Binary file added model/__pycache__/EfficientNet.cpython-37.pyc
Binary file not shown.
Binary file added model/__pycache__/TRACER.cpython-37.pyc
Binary file not shown.
Binary file added modules/__pycache__/att_modules.cpython-37.pyc
Binary file not shown.
Binary file added modules/__pycache__/conv_modules.cpython-37.pyc
Binary file not shown.
5 changes: 3 additions & 2 deletions modules/att_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,9 @@
from config import getConfig
from modules.conv_modules import BasicConv2d, DWConv, DWSConv

cfg = getConfig()

cfg = getConfig()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

class Frequency_Edge_Module(nn.Module):
def __init__(self, radius, channel):
Expand Down Expand Up @@ -64,7 +65,7 @@ def forward(self, x):
x_fft = fftshift(x_fft)

# Mask -> low, high separate
mask = self.mask_radial(img=x, r=self.radius).cuda()
mask = self.mask_radial(img=x, r=self.radius).to(device)
high_frequency = x_fft * (1 - mask)
x_fft = ifftshift(high_frequency)
x_fft = ifft2(x_fft, dim=(-2, -1))
Expand Down
4 changes: 2 additions & 2 deletions trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ def validate(self):

def test(self, args, save_path):
path = os.path.join(save_path, 'best_model.pth')
self.model.load_state_dict(torch.load(path))
self.model.load_state_dict(torch.load(path, map_location = self.device))
print('###### pre-trained Model restored #####')

te_img_folder = os.path.join(args.data_path, args.dataset, 'Test/images/')
Expand Down Expand Up @@ -226,7 +226,7 @@ def __init__(self, args, save_path):
self.model = nn.DataParallel(self.model).to(self.device)

path = os.path.join(save_path, 'best_model.pth')
self.model.load_state_dict(torch.load(path))
self.model.load_state_dict(torch.load(path, map_location = self.device))
print('###### pre-trained Model restored #####')

self.criterion = Criterion(args)
Expand Down
Binary file added util/__pycache__/effi_utils.cpython-37.pyc
Binary file not shown.
Binary file added util/__pycache__/utils.cpython-37.pyc
Binary file not shown.
5 changes: 3 additions & 2 deletions util/effi_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -616,12 +616,13 @@ def load_pretrained_weights(model, model_name, weights_path=None, load_fc=True,
advprop (bool): Whether to load pretrained weights
trained with advprop (valid when weights_path is None).
"""
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if isinstance(weights_path, str):
state_dict = torch.load(weights_path, strict=False)
state_dict = torch.load(weights_path, strict=False, map_location = device)
else:
# AutoAugment or Advprop (different preprocessing)
url_map_ = url_map_advprop if advprop else url_map
state_dict = model_zoo.load_url(url_map_[model_name])
state_dict = model_zoo.load_url(url_map_[model_name], map_location=device)

if load_fc:
ret = model.load_state_dict(state_dict, strict=False)
Expand Down
4 changes: 2 additions & 2 deletions util/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def update(self, val, n=1):
}


def load_pretrained(model_name):
state_dict = model_zoo.load_url(url_TRACER[model_name])
def load_pretrained(model_name, device):
state_dict = model_zoo.load_url(url_TRACER[model_name], map_location = device)

return state_dict
4 changes: 2 additions & 2 deletions w.o_edges/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ def validate(self):

def test(self, args, save_path):
path = os.path.join(save_path, 'best_model.pth')
self.model.load_state_dict(torch.load(path))
self.model.load_state_dict(torch.load(path, map_location=self.device))
print('###### pre-trained Model restored #####')

te_img_folder = os.path.join(args.data_path, args.dataset, 'Test/images/')
Expand Down Expand Up @@ -223,7 +223,7 @@ def __init__(self, args, save_path):
self.model = nn.DataParallel(self.model).to(self.device)

path = os.path.join(save_path, 'best_model.pth')
self.model.load_state_dict(torch.load(path))
self.model.load_state_dict(torch.load(path, map_location=self.device))
print('###### pre-trained Model restored #####')

self.criterion = Criterion(args)
Expand Down