Testing produces no image output #175
Comments
Hi, where should the test dataset be placed?
I ran it with the command line given in the repo. The one I eventually got working was the TensorFlow version.
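So this one doesn't run? My test failed. Using his command line I get an error saying a file does not exist, and I can't find where that file is defined.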
It is in the psenet folder under the dataset folder. I tested on IC15; open the corresponding file and edit it there.
OK.
Hi, did you use his training command? Did training report any errors? Mine prints a single "C" and then terminates. What could be causing this?
You need to post-process the model's predictions (the bboxes) yourself and draw them onto the image.

import cv2
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import torchvision.transforms as transforms
import torch
from mmcv import Config
# Relative imports: this file is meant to live inside the PSENet package.
from .models import build_model
from .models.utils import fuse_module
from .dataset.psenet import psenet_ctw


class OcrTextDetector(object):
    def __init__(self, ckpt_path, config_path, device='cpu'):
        self.ckpt_path = ckpt_path
        self.cfg_path = config_path
        self.device = device
        self.model = None
        self.cfg = None

    def build_model(self):
        cfg = Config.fromfile(self.cfg_path)
        # Turn off speed reporting in both the top-level and the test config.
        for d in [cfg, cfg.data.test]:
            d.update(dict(
                report_speed=False
            ))
        self.cfg = cfg
        model = build_model(self.cfg.model)
        model.to(self.device)
        checkpoint = torch.load(self.ckpt_path, map_location=self.device)
        d = dict()
        for key, value in checkpoint['state_dict'].items():
            # Drop the first 7 characters of each key (presumably the
            # 'module.' prefix added by torch.nn.DataParallel).
            tmp = key[7:]
            d[tmp] = value
        model.load_state_dict(d)
        model = fuse_module(model)
        model.eval()
        self.model = model
        return self

    def preprocess_img(self, img_path):
        img = psenet_ctw.get_img(img_path=img_path, read_type='pil')
        img_meta = dict(
            org_img_size=[np.array(img.shape[:2])]
        )
        img = psenet_ctw.scale_aligned_short(img)
        img_meta.update(dict(
            img_size=[np.array(img.shape[:2])]
        ))
        img = Image.fromarray(img)
        img = img.convert('RGB')
        img = transforms.ToTensor()(img)
        img = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])(img)
        return img, img_meta

    def predict(self, img_path):
        img, img_meta = self.preprocess_img(img_path)
        # Add the batch dimension and move the input to the model's device.
        img = torch.unsqueeze(img, 0).to(self.device)
        with torch.no_grad():
            outputs = self.model(img, img_metas=img_meta, cfg=self.cfg)
        return outputs


def draw_bbox(bboxs, img):
    bboxs_res = []
    for bbox in bboxs:
        # Each bbox is a flat [x1, y1, x2, y2, ...] array; OpenCV expects int32 points.
        bbox = np.reshape(bbox, (-1, 2)).astype(np.int32)
        cv2.drawContours(img, [bbox], -1, (0, 255, 0), 2)
        bboxs_res.append(bbox)
    return bboxs_res, img


if __name__ == "__main__":
    ckpt_path = './checkpoints/psenet_r50_ctw_finetune/checkpoint.pth'
    cfg = './config/psenet/psenet_r50_ctw_finetune.py'
    img_path = '../../dataset/ctw1500/train_images/0002.jpg'
    # Build the detector once, then call predict() on it.
    detector = OcrTextDetector(ckpt_path, cfg).build_model()
    box = detector.predict(img_path)
    img = cv2.imread(img_path)
    bboxs_res, box_img = draw_bbox(box['bboxes'], img)
    # cv2 loads images as BGR; reverse the channels so matplotlib shows RGB.
    plt.imshow(box_img[:, :, ::-1])
    plt.show()
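A side note on the checkpoint loading in the snippet above: slicing with `key[7:]` drops the first seven characters of every key, which matches the `module.` prefix that `torch.nn.DataParallel` puts in front of parameter names. If a checkpoint was saved without that wrapper, the slice would cut into the real names. Below is a minimal sketch of a more defensive variant, assuming the prefix is exactly `module.`. The helper name `strip_module_prefix` is made up for illustration and is not part of the PSENet code.

```python
def strip_module_prefix(state_dict, prefix="module."):
    """Return a copy of state_dict with `prefix` removed from keys that start with it."""
    cleaned = {}
    for key, value in state_dict.items():
        # Only strip when the prefix is really there; other keys pass through unchanged.
        if key.startswith(prefix):
            key = key[len(prefix):]
        cleaned[key] = value
    return cleaned


# Plain integers stand in for tensors in this toy example.
example = {"module.backbone.conv1.weight": 1, "det_head.bias": 2}
print(strip_module_prefix(example))
# → {'backbone.conv1.weight': 1, 'det_head.bias': 2}
```

With a helper like this, the loading step could become `model.load_state_dict(strip_module_prefix(checkpoint['state_dict']))`, which should behave the same whether or not the checkpoint came from a `DataParallel` model.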
Hello, I have received your email. Thank you for writing!
Could you provide a pretrained model?
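After testing, only the text files with the box positions were generated; no result images were produced. How should this be changed?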
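One way to get result images from the position files is to draw the saved polygons onto each test image and write the result to disk. This is a rough sketch, assuming each line of a result text file holds one polygon as comma-separated coordinates `x1,y1,x2,y2,...`. The function names `parse_polygons` and `save_result_image` are hypothetical and not part of the repository.

```python
def parse_polygons(text):
    """Parse lines of 'x1,y1,x2,y2,...' into lists of (x, y) integer points."""
    polygons = []
    for line in text.splitlines():
        line = line.strip()
        if not line:
            continue
        try:
            values = [int(float(v)) for v in line.split(",")]
        except ValueError:
            continue  # skip lines that are not purely numeric
        if len(values) < 6 or len(values) % 2 != 0:
            continue  # a polygon needs at least 3 complete (x, y) pairs
        polygons.append(list(zip(values[0::2], values[1::2])))
    return polygons


def save_result_image(img_path, txt_path, out_path):
    """Draw the polygons listed in txt_path on the image and save it to out_path."""
    # Imported here so parse_polygons stays usable without OpenCV installed.
    import cv2
    import numpy as np

    img = cv2.imread(img_path)
    if img is None:
        raise FileNotFoundError(img_path)
    with open(txt_path, "r", encoding="utf-8") as f:
        polygons = parse_polygons(f.read())
    for poly in polygons:
        pts = np.array(poly, dtype=np.int32).reshape(-1, 1, 2)
        cv2.polylines(img, [pts], isClosed=True, color=(0, 255, 0), thickness=2)
    cv2.imwrite(out_path, img)
    return len(polygons)
```

A call might look like `save_result_image("img_1.jpg", "res_img_1.txt", "vis_img_1.jpg")`, where all three file names are placeholders to be replaced with the actual test image, its result file, and the desired output path.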