import torch
from torch import nn


class inceptionBlock(nn.Module):
    """Small sequential CNN classifier: three 1x1 convs, a pool+conv stage, a linear head.

    Despite the name, this is not an Inception module: there are no parallel
    branches, and every layer is applied in sequence.

    Shape trace for the default configuration and an ``(N, 1, 28, 28)`` input::

        conv1   (1x1, stride 2)              -> (N,  16, 14, 14)
        conv2   (1x1)                        -> (N, 128, 14, 14)
        conv3   (1x1, stride 2)              -> (N, 256,  7,  7)
        b4      (max-pool stride 7, 1x1 conv) -> (N, 256,  1,  1)
        flatten                              -> (N, 256)
        fc1                                  -> (N, num_classes)

    ``fc1`` expects exactly 256 flattened features, so the spatial size after
    ``b4`` must be 1x1. That holds for any input with ``H, W <= 28``. Larger
    inputs (e.g. 32x32, which leaves a 2x2 map) fail with a shape error in
    ``fc1``.

    Args:
        in_channels: Number of channels of the input image. Defaults to 1.
        num_classes: Number of output logits. Defaults to 10.
    """

    def __init__(self, *, in_channels: int = 1, num_classes: int = 10):
        super().__init__()
        # Attribute names are kept stable so existing state_dict checkpoints
        # (conv1.*, conv2.*, conv3.*, b4.1.*, fc1.*) still load.

        # NOTE(review): a 1x1 kernel with stride 2 reads only every other
        # row/column of its input (conv1 and conv3), discarding 3/4 of the
        # pixels at each of those layers - confirm this is intended.
        self.conv1 = nn.Conv2d(
            in_channels=in_channels,  # channels of the input image
            out_channels=16,
            kernel_size=1,
            stride=2,  # halves height/width (28 -> 14)
        )
        self.relu1 = nn.ReLU()
        self.conv2 = nn.Conv2d(in_channels=16, out_channels=128, kernel_size=1)
        self.relu2 = nn.ReLU()
        self.conv3 = nn.Conv2d(
            in_channels=128,
            out_channels=256,
            kernel_size=1,
            stride=2,  # halves height/width (14 -> 7)
        )
        self.relu3 = nn.ReLU()
        self.b4 = nn.Sequential(
            # This pooling is NOT size-preserving: with stride 7 it collapses
            # a 7x7 map to 1x1.
            # NOTE(review): on a 7x7 input the single pooling window
            # (kernel 3, padding 1) only covers the top-left 2x2 region, so
            # the rest of the feature map is ignored - confirm this is intended.
            nn.MaxPool2d(kernel_size=3, padding=1, stride=7),
            # Pooling keeps the channel count, so this conv still sees 256.
            nn.Conv2d(in_channels=256, out_channels=256, kernel_size=1),
        )
        self.relu4 = nn.ReLU()
        self.flatten = nn.Flatten()
        # 256 = channels x 1 x 1; valid only when b4 outputs a 1x1 map.
        self.fc1 = nn.Linear(in_features=256, out_features=num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return logits of shape ``(N, num_classes)``.

        Args:
            x: Input batch of shape ``(N, in_channels, H, W)`` with
                ``H, W <= 28`` (see the class docstring).
        """
        x = self.relu1(self.conv1(x))
        x = self.relu2(self.conv2(x))
        x = self.relu3(self.conv3(x))
        x = self.relu4(self.b4(x))
        x = self.flatten(x)
        return self.fc1(x)


if __name__ == "__main__":
    x = torch.rand(size=(1, 1, 28, 28))
    block = inceptionBlock()
    out = block(x)
    print(out.shape)  # expected: torch.Size([1, 10])