add instructions for CenterNet training and quantization
gau-nernst committed Feb 23, 2022
1 parent b9df9f7 commit 3c3710f
Showing 6 changed files with 156 additions and 82 deletions.
13 changes: 12 additions & 1 deletion centernet_training/README.md
@@ -1,5 +1,7 @@
# CenterNet training

CenterNet with a CSPDarknet YOLOv5m backbone and an FPN neck.

## Environment setup

Repo: [centernet-lightning](https://github.com/gau-nernst/centernet-lightning)
@@ -22,4 +24,13 @@ python train.py fit --config configs/centernet.yaml --config configs/macaque.yaml

## Export weights for Vitis AI

TBD
Extract the state dict from the saved checkpoint and strip the `model.` prefix from its keys:

```python
import torch

# Lightning checkpoints store the weights under 'state_dict', with each key
# prefixed by the LightningModule attribute name ('model.'); strip that prefix
state_dict = torch.load('checkpoint.ckpt', map_location='cpu')['state_dict']
state_dict = {k[len('model.'):]: v for k, v in state_dict.items()}

torch.save(state_dict, 'model_weights.pth')
```
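
To verify the export, check that the renamed keys match what the deployment model expects (a quick sketch; the expected top-level prefixes come from `model.py` in [kv260_centernet](../kv260_centernet/)):

```python
import torch

state_dict = torch.load('model_weights.pth', map_location='cpu')
# expect {'backbone', 'heads', 'neck'} rather than a single 'model.' prefix
print(sorted({k.split('.')[0] for k in state_dict}))
```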
41 changes: 38 additions & 3 deletions kv260_centernet/README.md
@@ -1,10 +1,45 @@
# CenterNet with KV260

`model.py` provides the model definition for CenterNet. Only ResNet backbones are supported. The heads are fixed at 256 channels wide and 3 layers deep. The FPN is also fixed.
Go to [centernet_training](../centernet_training/) for more details about the model and how to train it.

## Model training
Files overview:

Model is trained from this repo: https://github.com/gau-nernst/centernet-lightning
- `model.py`: model definition for CenterNet. Depends on [vision_toolbox](https://github.com/gau-nernst/vision-toolbox) for backbones and the FPN neck.
- `quantize.py`: quantization entry point with `calibrate`, `test`, and `export` commands.
- `data.py`: `CalibrationDataset` for calibration images, plus COCO-format dataset and evaluation helpers.

## Environment setup

Please use [Vitis AI 2.0](https://github.com/Xilinx/Vitis-AI), since it ships PyTorch 1.7.1 by default. The GPU version is recommended.

Inside the Vitis AI 2.0 Docker environment:

```bash
conda activate vitis-ai-pytorch
git clone https://github.com/gau-nernst/macaque-detection
cd macaque-detection/kv260_centernet
pip install -r requirements.txt --user
```
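
To confirm the setup, a quick import check (a minimal sketch; every import below is used by the scripts in this folder):

```python
# all of these should succeed inside the vitis-ai-pytorch environment
import torch
import albumentations
import pycocotools
from pytorch_nndct.apis import torch_quantizer
from vision_toolbox import backbones  # installed from requirements.txt

print(torch.__version__)  # Vitis AI 2.0 ships PyTorch 1.7.1
```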

## Quantization

Calibration (runs sample images through the quantized model so the quantizer can collect activation statistics):

```bash
python quantize.py --weights macaque_centernet_darknet.pth --output_dir "centernet_darknet" --data_dir /datasets/NTU_macaque_videos/images/ --batch_size 64 calibrate
```
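
For reference, a condensed sketch of what the `calibrate` command does internally (it mirrors `get_quantized_model` and `get_dataset` in `quantize.py`; paths and sizes follow the command above):

```python
import torch
from torch.utils.data import DataLoader
import torchvision.transforms as T
from pytorch_nndct.apis import torch_quantizer

from model import build_model
from data import CalibrationDataset

model = build_model("macaque_centernet_darknet.pth", "darknet_yolov5m")
quantizer = torch_quantizer("calib", model, torch.randn((1, 3, 512, 512)),
                            output_dir="centernet_darknet")
quant_model = quantizer.quant_model

transform = T.Compose([T.Resize(512), T.CenterCrop((512, 512)), T.ToTensor()])
ds = CalibrationDataset("/datasets/NTU_macaque_videos/images/", transforms=transform)
for imgs in DataLoader(ds, batch_size=64):
    quant_model(imgs)              # run batches so the quantizer records activation ranges
quantizer.export_quant_config()    # write the calibration results to output_dir
```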

Test (evaluates the quantized model on the validation set and reports COCO metrics):

```bash
python quantize.py --weights macaque_centernet_darknet.pth --output_dir "centernet_darknet" --data_dir /datasets/NTU_macaque_videos/images/ --ann_json /datasets/NTU_macaque_videos/ntu_macaques_val.json --batch_size 64 test
```

Export (writes the quantized network as an `.xmodel` into the output directory):

```bash
python quantize.py --weights macaque_centernet_darknet.pth --output_dir "centernet_darknet" --data_dir /datasets/NTU_macaque_videos/images/ --ann_json /datasets/NTU_macaque_videos/ntu_macaques_val.json --batch_size 64 export
```
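
Before deploying, the float model can be sanity-checked end to end with the helpers from `model.py` (a minimal sketch; the random input stands in for a real preprocessed frame):

```python
import torch
from model import build_model, decode_detections

model = build_model("macaque_centernet_darknet.pth", "darknet_yolov5m")
with torch.no_grad():
    heatmap, box_offsets = model(torch.rand(1, 3, 512, 512))
detections = decode_detections(heatmap, box_offsets)
# note: decoded boxes are in output feature-map coordinates (see model.py),
# so they still need to be scaled by the model's output stride
```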

## Inference

21 changes: 17 additions & 4 deletions kv260_centernet/data.py
@@ -14,15 +14,28 @@
COCOeval = None


class ImageFolder(Dataset):
class CalibrationDataset(Dataset):
def __init__(self, img_dir, transforms=None):
super().__init__()
self.img_dir = img_dir
self.images = [x for x in os.listdir(img_dir) if x.lower().endswith((".jpg", ".jpeg"))]
images = []
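# recursively collect every image file under img_dir, including subfolders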
def get_all_images(dir):
files = os.listdir(dir)
for file in files:
full_path = os.path.join(dir, file)

if file.lower().endswith(('.jpg', '.jpeg', '.png')):
images.append(full_path)

elif os.path.isdir(full_path):
get_all_images(full_path)

get_all_images(img_dir)
print(f'Discovered {len(images)} images')
self.images = images
self.transforms = transforms

def __getitem__(self, idx):
img = Image.open(os.path.join(self.img_dir, self.images[idx]))
img = Image.open(self.images[idx])
if self.transforms is not None:
img = self.transforms(img)
return img
83 changes: 43 additions & 40 deletions kv260_centernet/model.py
@@ -2,14 +2,7 @@
from torch import nn
import torch.nn.functional as F
from torchvision import models


class ConvBnAct(nn.Sequential):
def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):
super().__init__()
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, bias=False)
self.bn = nn.BatchNorm2d(out_channels)
self.act = nn.ReLU(inplace=True)
from vision_toolbox import backbones, ConvBnAct, FPN


class ResNetBackbone(nn.Module):
@@ -25,6 +18,9 @@ def __init__(self, name):
self.out_channels = [x.shape[1] for x in self.forward_features(torch.rand(1,3,224,224))]  # probe forward_features: forward() now returns only the last map

def forward(self, x):
return self.forward_features(x)[-1]

def forward_features(self, x):
out1 = self.feat_extractor.conv1(x)
out1 = self.feat_extractor.bn1(out1)
out1 = self.feat_extractor.maxpool(out1)
@@ -35,29 +31,8 @@ def forward(self, x):
return [out2, out3, out4, out5]


class FPN(nn.Module):
def __init__(self, in_channels, out_channels=256, block=ConvBnAct):
super().__init__()
self.out_channels = out_channels
self.stride = 2**(len(in_channels)-1)

self.lateral_convs = nn.ModuleList([nn.Conv2d(in_c, out_channels, kernel_size=1) for in_c in in_channels])
self.output_convs = nn.ModuleList([block(out_channels, out_channels) for _ in range(len(in_channels)-1)])

def forward(self, x):
laterals = [l_conv(x[i]) for i, l_conv in enumerate(self.lateral_convs)]
outputs = [laterals.pop()]

for o_conv in self.output_convs:
out = F.interpolate(outputs[-1], scale_factor=2., mode="nearest")
out = out + laterals.pop()
outputs.append(o_conv(out))

return outputs[-1]


class Head(nn.Sequential):
def __init__(self, in_channels, out_channels, width=256, depth=3):
def __init__(self, in_channels, out_channels, width, depth):
super().__init__()
for i in range(depth):
in_c = in_channels if i == 0 else width
@@ -66,23 +41,50 @@ def __init__(self, in_channels, out_channels, width=256, depth=3):


class CenterNet(nn.Module):
def __init__(self, num_classes, backbone):
def __init__(self, backbone, neck, heads):
super().__init__()
self.backbone = ResNetBackbone(backbone)
self.neck = FPN(self.backbone.out_channels)
self.heads = nn.ModuleDict({
"heatmap": Head(self.neck.out_channels, num_classes),
"box_2d": Head(self.neck.out_channels, 4)
})
self.backbone = backbone
self.neck = neck
self.heads = heads
self.relu = nn.ReLU(inplace=True)

def forward(self, x):
out = self.backbone(x)
out = self.backbone.forward_features(x)
out = self.neck(out)
heatmap = self.heads.heatmap(out)
box_offsets = self.heads.box_2d(out)
box_offsets = self.relu(box_offsets)
return heatmap, box_offsets


def build_model(pth_path, backbone_name):
state_dict = torch.load(pth_path, map_location='cpu')

# backbone
backbone = backbones.__dict__[backbone_name]()
# backbone = ResNetBackbone(backbone_name)

# neck
neck_out_channels = state_dict['neck.lateral_convs.0.weight'].shape[0]
neck = FPN(backbone.out_channels, out_channels=neck_out_channels)

# heads
num_classes = state_dict['heads.heatmap.out_conv.weight'].shape[0]
head_width = state_dict['heads.heatmap.block_1.conv.weight'].shape[0]
head_depth = len([x for x in state_dict.keys() if x.startswith('heads.heatmap.block_') and x.endswith('.conv.weight')])
heads = nn.ModuleDict({
"heatmap": Head(neck_out_channels, num_classes, head_width, head_depth),
"box_2d": Head(neck_out_channels, 4, head_width, head_depth)
})
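# fold the 16x box scale into the final conv so the exported model's raw
# outputs are already scaled (decode_detections below no longer multiplies by 16)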
state_dict['heads.box_2d.out_conv.weight'] *= 16
state_dict['heads.box_2d.out_conv.bias'] *= 16

model = CenterNet(backbone, neck, heads).eval()
model.load_state_dict(state_dict)
model(torch.rand(1,3,512,512))  # dummy forward pass to verify the assembled model
return model


def decode_detections(heatmap, box_offsets):
batch_size, _, out_h, out_w = heatmap.shape

@@ -103,8 +105,9 @@ def decode_detections(heatmap, box_offsets):
cx = indices % out_w + 0.5
cy = indices // out_w + 0.5

box_offsets = box_offsets.flatten(start_dim=-2) * 16
box_offsets = box_offsets.clamp_min(0)
# box_offsets = box_offsets.flatten(start_dim=-2) * 16
box_offsets = box_offsets.flatten(start_dim=-2)
# box_offsets = box_offsets.clamp_min(0)

# boxes are in output feature maps coordinates
x1 = cx - torch.gather(box_offsets[:,0], dim=-1, index=indices) # x1 = cx - left
77 changes: 43 additions & 34 deletions kv260_centernet/quantize.py
@@ -1,7 +1,7 @@
# standard libraries
import argparse

from pytorch_nndct.apis import torch_quantizer

# 3rd party libraries
import numpy as np
import torch
from torch.utils.data import DataLoader, Subset, ConcatDataset
@@ -11,15 +11,19 @@
import albumentations as A
from albumentations.pytorch import ToTensorV2

from model import CenterNet, decode_detections
from data import ImageFolder, CocoDetection, coco_collate, CocoEvaluator
# local files
from model import build_model, decode_detections
from data import CalibrationDataset, CocoDetection, coco_collate, CocoEvaluator

# Vitis AI
from pytorch_nndct.apis import torch_quantizer


DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


def get_quantized_model(args):
model = CenterNet(1, "resnet34").eval()
model.load_state_dict(torch.load(args.weights, map_location="cpu"))
model = build_model(args.weights, 'darknet_yolov5m')

sample_inputs = torch.randn((1,3,args.img_h,args.img_w))
mode = "calib" if args.command == "calibrate" else "test"
@@ -34,26 +38,27 @@
def get_dataset(data_dir, img_h=512, img_w=512, ann_json=None, detection=False):
if detection:
transform = A.Compose([
A.Resize(img_h, img_w),
A.Normalize(),
A.SmallestMaxSize(max(img_h, img_w)),
A.CenterCrop(img_h, img_w),
A.Normalize(mean=(0,0,0), std=(1,1,1)),
ToTensorV2()
], bbox_params=dict(format="coco", label_fields=["labels"], min_area=1))
ds = CocoDetection(data_dir, ann_json, transforms=transform)

else:
transform = T.Compose([
T.Resize((img_h, img_w)),
T.Resize(max(img_h, img_w)),
T.CenterCrop((img_h, img_w)),
T.ToTensor(),
T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
T.Normalize(mean=(0,0,0), std=(1,1,1))
])
ds = ImageFolder(data_dir, transforms=transform)
ds = CalibrationDataset(data_dir, transforms=transform)

return ds



@torch.no_grad()
def validate(model: CenterNet, dataloader, num_classes):
def validate(model, dataloader, num_classes):
model.eval()
evaluator = CocoEvaluator(num_classes)

@@ -75,29 +80,9 @@ def validate(model: CenterNet, dataloader, num_classes):
evaluator.update(detections, targets)

return evaluator.get_metrics()


def get_args_parser():
parser = argparse.ArgumentParser()
parser.add_argument("command", type=str)

parser.add_argument("--num_classes", type=int, default=1)
parser.add_argument("--weights")
parser.add_argument("--output_dir", default="./centernet")
parser.add_argument("--data_dir")
parser.add_argument("--ann_json")

parser.add_argument("--img_w", type=int, default=512)
parser.add_argument("--img_h", type=int, default=512)

parser.add_argument("--batch_size", type=int, default=4)
parser.add_argument("--num_samples", type=int, default=0)

return parser


if __name__ == "__main__":
args = get_args_parser().parse_args()
def main(args):
assert args.command in ("calibrate", "test", "export")
if args.command == 'export':
args.num_samples = 1
@@ -137,3 +122,27 @@ def get_args_parser():
quant_model(img)

quantizer.export_xmodel(output_dir=args.output_dir)


def get_args_parser():
parser = argparse.ArgumentParser()
parser.add_argument("command", type=str)

parser.add_argument("--num_classes", type=int, default=1)
parser.add_argument("--weights")
parser.add_argument("--output_dir", default="./centernet")
parser.add_argument("--data_dir")
parser.add_argument("--ann_json")

parser.add_argument("--img_w", type=int, default=512)
parser.add_argument("--img_h", type=int, default=512)

parser.add_argument("--batch_size", type=int, default=4)
parser.add_argument("--num_samples", type=int, default=0)

return parser


if __name__ == "__main__":
args = get_args_parser().parse_args()
main(args)
3 changes: 3 additions & 0 deletions kv260_centernet/requirements.txt
@@ -0,0 +1,3 @@
albumentations
pycocotools
git+https://github.com/gau-nernst/vision-toolbox.git
