# -*- coding: utf-8 -*-
# @Time : 2019/8/23 21:57
# @Author : zhoujun

import torch
from torch import nn
import torch.nn.functional as F
from models.modules import *

# Registry of supported backbones: maps each name to its constructor and to
# the channel widths of the four feature stages (c2..c5) that it outputs.
backbone_dict = {
    name: {'models': model, 'out': out}
    for name, model, out in (
        ('resnet18', resnet18, [64, 128, 256, 512]),
        ('resnet34', resnet34, [64, 128, 256, 512]),
        ('resnet50', resnet50, [256, 512, 1024, 2048]),
        ('resnet101', resnet101, [256, 512, 1024, 2048]),
        ('resnet152', resnet152, [256, 512, 1024, 2048]),
        ('resnext50_32x4d', resnext50_32x4d, [256, 512, 1024, 2048]),
        ('resnext101_32x8d', resnext101_32x8d, [256, 512, 1024, 2048]),
    )
}


# 'MobileNetV3_Large': {'models': MobileNetV3_Large, 'out': [24, 40, 160, 160]},
# 'MobileNetV3_Small': {'models': MobileNetV3_Small, 'out': [16, 24, 48, 96]},
# 'shufflenetv2': {'models': shufflenet_v2_x1_0, 'out': [24, 116, 232, 464]}}

# Shared flag forwarded to nn.ReLU(inplace=...) inside the model.
inplace = True
class PSENet(nn.Module):
    """FPN-style segmentation head for PSE text detection.

    Wraps one of the backbones registered in ``backbone_dict`` with a
    top-down feature pyramid and predicts ``result_num`` kernel maps at
    ``1/scale`` of the input resolution.
    """

    def __init__(self, backbone, result_num=6, scale: int = 1, pretrained=False):
        """
        Args:
            backbone: key into ``backbone_dict`` selecting the feature extractor.
            result_num: number of output channels (kernel maps).
            scale: output maps are produced at ``input_size // scale``.
            pretrained: forwarded to the backbone constructor.
        """
        super(PSENet, self).__init__()
        assert backbone in backbone_dict, 'backbone must in: {}'.format(backbone_dict)
        self.name = backbone
        self.scale = scale
        conv_out = 256
        cfg = backbone_dict[backbone]
        self.backbone = cfg['models'](pretrained=pretrained)
        c2_ch, c3_ch, c4_ch, c5_ch = cfg['out']

        # 1x1 projections: reduce every backbone stage to conv_out channels.
        self.toplayer = nn.Conv2d(c5_ch, conv_out, kernel_size=1, stride=1, padding=0)
        self.latlayer1 = nn.Conv2d(c4_ch, conv_out, kernel_size=1, stride=1, padding=0)
        self.latlayer2 = nn.Conv2d(c3_ch, conv_out, kernel_size=1, stride=1, padding=0)
        self.latlayer3 = nn.Conv2d(c2_ch, conv_out, kernel_size=1, stride=1, padding=0)

        # 3x3 convs that smooth the merged (upsample-and-add) pyramid levels.
        self.smooth1 = nn.Conv2d(conv_out, conv_out, kernel_size=3, stride=1, padding=1)
        self.smooth2 = nn.Conv2d(conv_out, conv_out, kernel_size=3, stride=1, padding=1)
        self.smooth3 = nn.Conv2d(conv_out, conv_out, kernel_size=3, stride=1, padding=1)

        # Fuse the four concatenated pyramid levels back down to conv_out channels.
        self.conv = nn.Sequential(
            nn.Conv2d(conv_out * 4, conv_out, kernel_size=3, padding=1, stride=1),
            nn.BatchNorm2d(conv_out),
            nn.ReLU(inplace=inplace)
        )
        self.out_conv = nn.Conv2d(conv_out, result_num, kernel_size=1, stride=1)

    def forward(self, input: torch.Tensor):
        """Run backbone + FPN; return (N, result_num, H // scale, W // scale) logits."""
        _, _, H, W = input.size()
        c2, c3, c4, c5 = self.backbone(input)

        # Top-down pathway: start at the deepest stage and merge downwards.
        p5 = self.toplayer(c5)
        p4 = self._upsample_add(p5, self.latlayer1(c4))
        p3 = self._upsample_add(p4, self.latlayer2(c3))
        p2 = self._upsample_add(p3, self.latlayer3(c2))

        # Smooth each merged level (p5 is used as-is, as in the FPN paper).
        p4, p3, p2 = self.smooth1(p4), self.smooth2(p3), self.smooth3(p2)

        fused = self.out_conv(self.conv(self._upsample_cat(p2, p3, p4, p5)))

        # Resize the logits to 1/scale of the network input size.
        return F.interpolate(fused, size=(H // self.scale, W // self.scale),
                             mode='bilinear', align_corners=True)

    def _upsample_add(self, x, y):
        """Bilinearly resize ``x`` to ``y``'s spatial size, then add it to ``y``."""
        resized = F.interpolate(x, size=y.size()[2:], mode='bilinear', align_corners=False)
        return resized + y

    def _upsample_cat(self, p2, p3, p4, p5):
        """Resize p3..p5 to p2's spatial size and concatenate all four on channels."""
        target = p2.size()[2:]
        upsampled = [
            F.interpolate(p, size=target, mode='bilinear', align_corners=False)
            for p in (p3, p4, p5)
        ]
        return torch.cat([p2] + upsampled, dim=1)
if __name__ == '__main__':
    # Smoke test: push a dummy batch through the network and print the output
    # shape (expected [1, 6, 640, 640] with the default result_num/scale).
    device = torch.device('cpu')
    x = torch.zeros(1, 3, 640, 640).to(device)

    # Bug fix: this file defines PSENet, not PAN (and PSENet has no
    # `fpem_repeat` parameter), so the original call raised a NameError.
    model = PSENet(backbone='resnet18', pretrained=True).to(device)
    y = model(x)
    print(y.shape)
    # torch.save(model.state_dict(), 'PAN.pth')