目录
1. 我搭建的网络：VGG16（13 个卷积层 + 3 个全连接层）
2. 方法一：用一个 Sequential 逐层配置
3. 方法二：用 cfg 列表进行配置
1. 我搭建的网络：VGG16（13 个卷积层 + 3 个全连接层）
2. 方法一：用一个 Sequential 逐层配置
import torch
from torch import nn
from torchinfo import summary
class Net(nn.Module):
    """VGG16 image classifier written out layer by layer.

    ``self.layer`` holds the 13 3x3 convolutions (stride 1, padding 1), each
    followed by a ReLU, grouped into five stages that each end in a 2x2
    max-pool; a 224x224 input leaves it as a 512x7x7 feature map.
    ``self.linear`` is the fully connected classifier
    (512*7*7 -> 4096 -> 4096 -> ``num_classes``) ending in a softmax over
    the class dimension.

    Args:
        num_classes: Size of the final layer (default 1000, the value the
            original version hard-coded).

    NOTE: ``forward`` returns probabilities, not logits. ``nn.CrossEntropyLoss``
    applies log-softmax itself, so drop the final ``nn.Softmax`` if you train
    with that loss.
    """

    def __init__(self, num_classes=1000):
        super().__init__()
        # Without the ReLUs, consecutive conv/linear layers would collapse
        # into a single linear map; the activations are part of VGG16.
        # Spatial sizes in the stage comments assume a 224x224 input.
        self.layer = nn.Sequential(
            # stage 1: 3 -> 64 channels, 224 -> 112
            nn.Conv2d(3, 64, 3, 1, 1), nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, 3, 1, 1), nn.ReLU(inplace=True),
            nn.MaxPool2d(2),
            # stage 2: 64 -> 128 channels, 112 -> 56
            nn.Conv2d(64, 128, 3, 1, 1), nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, 3, 1, 1), nn.ReLU(inplace=True),
            nn.MaxPool2d(2),
            # stage 3: 128 -> 256 channels, 56 -> 28
            nn.Conv2d(128, 256, 3, 1, 1), nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, 3, 1, 1), nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, 3, 1, 1), nn.ReLU(inplace=True),
            nn.MaxPool2d(2),
            # stage 4: 256 -> 512 channels, 28 -> 14
            nn.Conv2d(256, 512, 3, 1, 1), nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, 3, 1, 1), nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, 3, 1, 1), nn.ReLU(inplace=True),
            nn.MaxPool2d(2),
            # stage 5: 512 channels, 14 -> 7
            nn.Conv2d(512, 512, 3, 1, 1), nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, 3, 1, 1), nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, 3, 1, 1), nn.ReLU(inplace=True),
            nn.MaxPool2d(2),
        )
        # Averages any spatial size down to 7x7 so inputs other than 224x224
        # also fit the classifier; an already-7x7 map passes through unchanged.
        self.pool = nn.AdaptiveAvgPool2d((7, 7))
        self.linear = nn.Sequential(
            nn.Flatten(1),
            nn.Linear(512 * 7 * 7, 4096), nn.ReLU(inplace=True), nn.Dropout(0.5),
            nn.Linear(4096, 4096), nn.ReLU(inplace=True), nn.Dropout(0.5),
            nn.Linear(4096, num_classes),
            nn.Softmax(dim=1),
        )

    def forward(self, x):
        """Map an image batch ``(batch, 3, H, W)`` to class probabilities."""
        x = self.layer(x)
        x = self.pool(x)
        return self.linear(x)
if __name__ == "__main__":
    # Smoke test: push one random 224x224 RGB image through the network,
    # print the output shape, then show the per-layer summary.
    input_shape = (1, 3, 224, 224)
    sample = torch.randn(*input_shape)
    model = Net()
    print(model(sample).shape)
    summary(model, input_shape)
summary（来自 torchinfo 库）用于打印网络结构信息：每一层的输出形状和参数量。
3. 方法二：用 cfg 列表进行配置
import cv2
import torch
from torch import nn
from torchinfo import summary
# VGG16 feature-extractor layout: an int is the output channel count of a
# 3x3 conv (stride 1, padding 1, followed by ReLU); 'M' is a 2x2 max-pool.
cfg = {
    'vgg': [
        64, 64, 'M',
        128, 128, 'M',
        256, 256, 256, 'M',
        512, 512, 512, 'M',
        512, 512, 512, 'M',
    ]
}


class Net(nn.Module):
    """VGG classifier whose convolutional stack is generated from a layer list.

    Args:
        layers: Layer spec in the format of ``cfg`` (ints are conv output
            channels, ``'M'`` is a max-pool). Defaults to ``cfg['vgg']``
            (VGG16).
        num_classes: Size of the output layer (default 1000).
        in_channels: Channel count of the input images (default 3, RGB).

    ``forward`` returns raw logits of shape ``(batch, num_classes)``.
    """

    def __init__(self, layers=None, num_classes=1000, in_channels=3):
        super().__init__()
        if layers is None:
            layers = cfg['vgg']
        modules = []
        channels = in_channels  # channel count entering the next conv
        for spec in layers:
            if spec == 'M':
                modules.append(nn.MaxPool2d(2))
            else:
                modules += [nn.Conv2d(channels, spec, 3, 1, 1), nn.ReLU(inplace=True)]
                # Easy to get wrong: the next conv consumes what this one produced.
                channels = spec
        self.layer = nn.Sequential(*modules)
        # Averages any spatial size down to 7x7 so inputs other than 224x224
        # also fit the classifier; an already-7x7 map passes through unchanged.
        self.pool = nn.AdaptiveAvgPool2d((7, 7))
        self.linear = nn.Sequential(
            nn.Flatten(),
            # `channels` is the last conv's width (512 for VGG16).
            nn.Linear(channels * 7 * 7, 4096), nn.ReLU(inplace=True), nn.Dropout(0.5),
            nn.Linear(4096, 4096), nn.ReLU(inplace=True), nn.Dropout(0.5),
            nn.Linear(4096, num_classes),
        )

    def forward(self, x):
        """Map an image batch ``(batch, in_channels, H, W)`` to class logits."""
        x = self.layer(x)
        x = self.pool(x)
        return self.linear(x)
if __name__ == "__main__":
    # Smoke test: one random 224x224 RGB image in, print the output shape.
    sample = torch.randn(1, 3, 224, 224)
    model = Net()
    logits = model(sample)
    print(logits.shape)