-
Notifications
You must be signed in to change notification settings - Fork 76
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
大老多分类怎么写呢? #15
Comments
import torch
import argparse
# torch.utils.data.DataLoader是一个迭代器,方便我们去多线程地读取数据,并且可以实现batch以及shuffle的读取等。
from torch.utils.data import DataLoader
from torch import nn, optim
from torchvision.transforms import transforms
from unet import Unet
from dataset import LiverDataset #图片数据
import cv2
# 是否使用cuda
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# job年限轴 - 训练的job年限轴
x_transforms = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])#归一化
])
# 收入轴 - 为mask只需要转换为tensor
y_transforms = transforms.ToTensor()
#训练模型
def train():
batch_size = 1 #批处理图片数
num_epochs=1 #训练轮数
num_classes=2 #分类数
num_workers=0 #进程数
model = Unet(3, num_classes).to(device)
if num_classes>1:
# 当为CrossEntropyLoss时,outputs会自动softmax,不需要手动计算
criterion = nn.CrossEntropyLoss() #损失函数,多分类
else:
# 当为BCEWithLogitsLoss 为bceloss计算sigmoid是因为bceloss不包含sigmoid函数,需要自行在模型中添加
criterion = nn.BCEWithLogitsLoss() #损失函数,单分类
optimizer = optim.Adam(model.parameters()) #优化方法
liver_dataset = LiverDataset("data/train",transform=x_transforms,target_transform=y_transforms)#加载数据集,就可以以liver_dataset[key]形式取值
dataload = DataLoader(liver_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
for epoch in range(num_epochs):
print('Epoch {}/{}'.format(epoch, num_epochs - 1))
print('-' * 10)
dt_size = len(dataload.dataset)
epoch_loss = 0
step = 0
for x, y in dataload:
step += 1
inputs = x.to(device)
labels = y.to(device)
if num_classes>1: # 多分类 - 损失函数,多分类
#这个说明是不对的:https://www.cnblogs.com/dyc99/p/12665778.html
#这个说明是才对的:https://discuss.pytorch.org/t/runtimeerror-only-batches-of-spatial-targets-supported-3d-tensors-but-got-targets-of-dimension-4/82098/9
labels = torch.argmax(labels, dim=1) #输入出要少一维数组,这样转
print('x:',inputs.size()) # x: torch.Size([1, 3, 512, 512]) #网络输出为1个3类,当size是1个数时,返回一个1行3列的张量,其中每个元素又是一个512行512列的张量,最小元素的每一行服从正态分布
print('y:',labels.size()) # y: torch.Size([1, 1, 512, 512])
optimizer.zero_grad()# 清零
outputs = model(inputs) # forward 前向传播
loss=criterion(outputs, labels)
# if num_classes>1: # 多分类 - 损失函数,多分类
# img_size =512 #多分类时设置的图片大小, 这个要跟上面的inputs = x.to(device) 返回的每个元素一样
# #pred=torch.rand(batch_size,num_classes,img_size,img_size) #假设pred是模型的输出,当然这里是:outputs
# target=torch.randint(num_classes,(batch_size,img_size,img_size)) 等于 labels = torch.argmax(labels, dim=1) #输入出要少一维数组,这样转
# print('y:',target.size()) #y: torch.Size([1, 512, 512])
# return
# loss=criterion(outputs,target) #损失函数,多分类
# else:
# loss=criterion(outputs, labels) #损失函数,单分类
loss.backward() #梯度下降,计算出梯度
optimizer.step() #更新参数一次:所有的优化器Optimizer都实现了step()方法来对所有的参数进行更新
epoch_loss += loss.item()
print("%d/%d,train_loss:%0.3f" % (step, (dt_size - 1) // dataload.batch_size + 1, loss.item()))
print("epoch %d loss:%0.3f" % (epoch, epoch_loss/step))
torch.save(model.state_dict(), 'weights_%d.pth' % epoch) # 返回模型的所有内容
return model
#显示模型的输出结果
def test():
train()
ckpt='weights_0.pth'
model = Unet(3, 2)
model.load_state_dict(torch.load(ckpt,map_location='cpu'))
liver_dataset = LiverDataset("data/val", transform=x_transforms,target_transform=y_transforms)
dataloaders = DataLoader(liver_dataset, batch_size=1)
model.eval()
import matplotlib.pyplot as plt
plt.ion()
with torch.no_grad():
for x, _ in dataloaders:
y=model(x).sigmoid()
img_y=torch.squeeze(y).numpy()
plt.imshow(img_y)
plt.pause(0.01)
#plt.savefig(img_y) #linux下使用,只能生成图片
plt.show()
test()
------------------------最后报错:ypeError: Invalid dimensions for image data-----------------------
TypeError Traceback (most recent call last)
<ipython-input-16-d732d69be42b> in <module>
103 plt.show()
104
--> 105 test()
<ipython-input-16-d732d69be42b> in test()
98 y=model(x).sigmoid()
99 img_y=torch.squeeze(y).numpy()
--> 100 plt.imshow(img_y)
101 plt.pause(0.01)
102 #plt.savefig(img_y) #linux下使用,只能生成图片
d:\ProgramData\Anaconda3\lib\site-packages\matplotlib\pyplot.py in imshow(X, cmap, norm, aspect, interpolation, alpha, vmin, vmax, origin, extent, shape, filternorm, filterrad, imlim, resample, url, hold, data, **kwargs)
3203 filternorm=filternorm, filterrad=filterrad,
3204 imlim=imlim, resample=resample, url=url, data=data,
-> 3205 **kwargs)
3206 finally:
3207 ax._hold = washold
d:\ProgramData\Anaconda3\lib\site-packages\matplotlib\__init__.py in inner(ax, *args, **kwargs)
1853 "the Matplotlib list!)" % (label_namer, func.__name__),
1854 RuntimeWarning, stacklevel=2)
-> 1855 return func(ax, *args, **kwargs)
1856
1857 inner.__doc__ = _add_data_doc(inner.__doc__,
d:\ProgramData\Anaconda3\lib\site-packages\matplotlib\axes\_axes.py in imshow(self, X, cmap, norm, aspect, interpolation, alpha, vmin, vmax, origin, extent, shape, filternorm, filterrad, imlim, resample, url, **kwargs)
5485 resample=resample, **kwargs)
5486
-> 5487 im.set_data(X)
5488 im.set_alpha(alpha)
5489 if im.get_clip_path() is None:
d:\ProgramData\Anaconda3\lib\site-packages\matplotlib\image.py in set_data(self, A)
651 if not (self._A.ndim == 2
652 or self._A.ndim == 3 and self._A.shape[-1] in [3, 4]):
--> 653 raise TypeError("Invalid dimensions for image data")
654
655 if self._A.ndim == 3:
TypeError: Invalid dimensions for image data
有没有解决的办法呢?分享交流一下
|
|
Could you share the data and trained weight link of Baidu Netdisk again? The original link has expired.Thank you very much! |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
if num_classes>1: # 多分类 - 损失函数,多分类
#pred=torch.rand(batch_size,num_classes,img_size,img_size)
target=torch.randint(num_classes,(batch_size,img_size,img_size))
loss=criterion(outputs,target)
else:
loss = criterion(outputs, labels) #损失函数,单分类
这个地方不会写呢
The text was updated successfully, but these errors were encountered: