From 3c3710fa4d2e8f9a545e8c6a2560e2c40261f88c Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Wed, 23 Feb 2022 22:03:30 +0800 Subject: [PATCH] add instructions for CenterNet training and quantization --- centernet_training/README.md | 13 ++++- kv260_centernet/README.md | 41 ++++++++++++++-- kv260_centernet/data.py | 21 ++++++-- kv260_centernet/model.py | 83 +++++++++++++++++--------------- kv260_centernet/quantize.py | 77 ++++++++++++++++------------- kv260_centernet/requirements.txt | 3 ++ 6 files changed, 156 insertions(+), 82 deletions(-) create mode 100644 kv260_centernet/requirements.txt diff --git a/centernet_training/README.md b/centernet_training/README.md index 312c735..71c0e88 100644 --- a/centernet_training/README.md +++ b/centernet_training/README.md @@ -1,5 +1,7 @@ # CenterNet training +CenterNet with CSPDarknet YOLOv5m backbone and FPN neck + ## Environment setup Repo: [centernet-lightning](https://github.com/gau-nernst/centernet-lightning) @@ -22,4 +24,13 @@ python train.py fit --configs/centernet.yaml --configs configs/macaque.yaml ## Export weights for Vitis AI -TBD +Extract state dict from the saved checkpoint and rename the keys: + +```python +import torch + +state_dict = torch.load('checkpoint.ckpt', map_location='cpu')['state_dict'] +state_dict = {k[len('model.'):]: v for k, v in state_dict.items()} + +torch.save(state_dict, 'model_weights.pth') +``` diff --git a/kv260_centernet/README.md b/kv260_centernet/README.md index 7541368..7994258 100644 --- a/kv260_centernet/README.md +++ b/kv260_centernet/README.md @@ -1,10 +1,45 @@ # CenterNet with KV260 -`model.py` provides model definition for CenterNet. Only ResNet backbones are supported. The heads are fixed at 256-channels wide and 3-layer deep. FPN is also fixed. +Go to [centernet_training](../centernet_training/) for more details about the model and how to train it. -## Model training +Files overview: -Model is trained from this repo: https://github.com/gau-nernst/centernet-lightning +- `model.py`: model definition for CenterNet. Depends on [vision_toolbox](https://github.com/gau-nernst/vision-toolbox) for backbones and FPN neck. +- `quantize.py` +- `data.py` + +## Environment setup + +Please use [Vitis AI 2.0](https://github.com/Xilinx/Vitis-AI) since it uses PyTorch 1.7.1 by default. GPU version is recommended. + +Inside Vitis AI 2.0 Docker environment: + +```bash +conda activate vitis-ai-pytorch +git clone https://github.com/gau-nernst/macaque-detection +cd macaque-detection/kv260_centernet +pip install -r requirements.txt --user +``` + +## Quantization + +Calibration + +```bash +python quantize.py --weights macaque_centernet_darknet.pth --output_dir "centernet_darknet" --data_dir /datasets/NTU_macaque_videos/images/ --batch_size 64 calibrate +``` + +Test (validation) + +```bash +python quantize.py --weights macaque_centernet_darknet.pth --output_dir "centernet_darknet" --data_dir /datasets/NTU_macaque_videos/images/ --ann_json /datasets/NTU_macaque_videos/ntu_macaques_val.json --batch_size 64 test +``` + +Export + +```bash +python quantize.py --weights macaque_centernet_darknet.pth --output_dir "centernet_darknet" --data_dir /datasets/NTU_macaque_videos/images/ --ann_json /datasets/NTU_macaque_videos/ntu_macaques_val.json --batch_size 64 export +``` ## Inference diff --git a/kv260_centernet/data.py b/kv260_centernet/data.py index 990d162..95f2eab 100644 --- a/kv260_centernet/data.py +++ b/kv260_centernet/data.py @@ -14,15 +14,28 @@ COCOeval = None -class ImageFolder(Dataset): +class CalibrationDataset(Dataset): def __init__(self, img_dir, transforms=None): super().__init__() - self.img_dir = img_dir - self.images = [x for x in os.listdir(img_dir) if x.lower().endswith((".jpg", ".jpeg"))] + images = [] + def get_all_images(dir): + files = os.listdir(dir) + for file in files: + full_path = os.path.join(dir, file) + + if file.lower().endswith(('.jpg', '.jpeg', '.png')): + images.append(full_path) + + elif os.path.isdir(full_path): + get_all_images(full_path) + + get_all_images(img_dir) + print(f'Discovered {len(images)} images') + self.images = images self.transforms = transforms def __getitem__(self, idx): - img = Image.open(os.path.join(self.img_dir, self.images[idx])) + img = Image.open(self.images[idx]) if self.transforms is not None: img = self.transforms(img) return img diff --git a/kv260_centernet/model.py b/kv260_centernet/model.py index 8bd2171..2332c80 100644 --- a/kv260_centernet/model.py +++ b/kv260_centernet/model.py @@ -2,14 +2,7 @@ from torch import nn import torch.nn.functional as F from torchvision import models - - -class ConvBnAct(nn.Sequential): - def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1): - super().__init__() - self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, bias=False) - self.bn = nn.BatchNorm2d(out_channels) - self.act = nn.ReLU(inplace=True) +from vision_toolbox import backbones, ConvBnAct, FPN class ResNetBackbone(nn.Module): @@ -25,6 +18,9 @@ def __init__(self, name): self.out_channels = [x.shape[1] for x in self(torch.rand(1,3,224,224))] def forward(self, x): + return self.forward_features(x)[-1] + + def forward_features(self, x): out1 = self.feat_extractor.conv1(x) out1 = self.feat_extractor.bn1(out1) out1 = self.feat_extractor.maxpool(out1) @@ -35,29 +31,8 @@ def forward(self, x): return [out2, out3, out4, out5] -class FPN(nn.Module): - def __init__(self, in_channels, out_channels=256, block=ConvBnAct): - super().__init__() - self.out_channels = out_channels - self.stride = 2**(len(in_channels)-1) - - self.lateral_convs = nn.ModuleList([nn.Conv2d(in_c, out_channels, kernel_size=1) for in_c in in_channels]) - self.output_convs = nn.ModuleList([block(out_channels, out_channels) for _ in range(len(in_channels)-1)]) - - def forward(self, x): - laterals = [l_conv(x[i]) for i, l_conv in enumerate(self.lateral_convs)] - outputs = [laterals.pop()] - - for o_conv in self.output_convs: - out = F.interpolate(outputs[-1], scale_factor=2., mode="nearest") - out = out + laterals.pop() - outputs.append(o_conv(out)) - - return outputs[-1] - - class Head(nn.Sequential): - def __init__(self, in_channels, out_channels, width=256, depth=3): + def __init__(self, in_channels, out_channels, width, depth): super().__init__() for i in range(depth): in_c = in_channels if i == 0 else width @@ -66,23 +41,50 @@ def __init__(self, in_channels, out_channels, width=256, depth=3): class CenterNet(nn.Module): - def __init__(self, num_classes, backbone): + def __init__(self, backbone, neck, heads): super().__init__() - self.backbone = ResNetBackbone(backbone) - self.neck = FPN(self.backbone.out_channels) - self.heads = nn.ModuleDict({ - "heatmap": Head(self.neck.out_channels, num_classes), - "box_2d": Head(self.neck.out_channels, 4) - }) + self.backbone = backbone + self.neck = neck + self.heads = heads + self.relu = nn.ReLU(inplace=True) def forward(self, x): - out = self.backbone(x) + out = self.backbone.forward_features(x) out = self.neck(out) heatmap = self.heads.heatmap(out) box_offsets = self.heads.box_2d(out) + box_offsets = self.relu(box_offsets) return heatmap, box_offsets +def build_model(pth_path, backbone_name): + state_dict = torch.load(pth_path, map_location='cpu') + + # backbone + backbone = backbones.__dict__[backbone_name]() + # backbone = ResNetBackbone(backbone_name) + + # neck + neck_out_channels = state_dict['neck.lateral_convs.0.weight'].shape[0] + neck = FPN(backbone.out_channels, out_channels=neck_out_channels) + + # heads + num_classes = state_dict['heads.heatmap.out_conv.weight'].shape[0] + head_width = state_dict['heads.heatmap.block_1.conv.weight'].shape[0] + head_depth = len([x for x in state_dict.keys() if x.startswith('heads.heatmap.block_') and x.endswith('.conv.weight')]) + heads = nn.ModuleDict({ + "heatmap": Head(neck_out_channels, num_classes, head_width, head_depth), + "box_2d": Head(neck_out_channels, 4, head_width, head_depth) + }) + state_dict['heads.box_2d.out_conv.weight'] *= 16 + state_dict['heads.box_2d.out_conv.bias'] *= 16 + + model = CenterNet(backbone, neck, heads).eval() + model.load_state_dict(state_dict) + model(torch.rand(1,3,512,512)) + return model + + def decode_detections(heatmap, box_offsets): batch_size, _, out_h, out_w = heatmap.shape @@ -103,8 +105,9 @@ def decode_detections(heatmap, box_offsets): cx = indices % out_w + 0.5 cy = indices // out_w + 0.5 - box_offsets = box_offsets.flatten(start_dim=-2) * 16 - box_offsets = box_offsets.clamp_min(0) + # box_offsets = box_offsets.flatten(start_dim=-2) * 16 + box_offsets = box_offsets.flatten(start_dim=-2) + # box_offsets = box_offsets.clamp_min(0) # boxes are in output feature maps coordinates x1 = cx - torch.gather(box_offsets[:,0], dim=-1, index=indices) # x1 = cx - left diff --git a/kv260_centernet/quantize.py b/kv260_centernet/quantize.py index d80149c..5b972c5 100644 --- a/kv260_centernet/quantize.py +++ b/kv260_centernet/quantize.py @@ -1,7 +1,7 @@ +# standard libraries import argparse -from pytorch_nndct.apis import torch_quantizer - +# 3rd party libraries import numpy as np import torch from torch.utils.data import DataLoader, Subset, ConcatDataset @@ -11,15 +11,19 @@ import albumentations as A from albumentations.pytorch import ToTensorV2 -from model import CenterNet, decode_detections -from data import ImageFolder, CocoDetection, coco_collate, CocoEvaluator +# local files +from model import build_model, decode_detections +from data import CalibrationDataset, CocoDetection, coco_collate, CocoEvaluator + +# Vitis AI +from pytorch_nndct.apis import torch_quantizer + DEVICE = "cuda" if torch.cuda.is_available() else "cpu" def get_quantized_model(args): - model = CenterNet(1, "resnet34").eval() - model.load_state_dict(torch.load(args.weights, map_location="cpu")) + model = build_model(args.weights, 'darknet_yolov5m') sample_inputs = torch.randn((1,3,args.img_h,args.img_w)) mode = "calib" if args.command == "calibrate" else "test" @@ -34,26 +38,27 @@ def get_quantized_model(args): def get_dataset(data_dir, img_h=512, img_w=512, ann_json=None, detection=False): if detection: transform = A.Compose([ - A.Resize(img_h, img_w), - A.Normalize(), + A.SmallestMaxSize(max(args.img_h, args.img_w)), + A.CenterCrop(args.img_h, args.img_w), + A.Normalize(mean=(0,0,0), std=(1,1,1)), ToTensorV2() ], bbox_params=dict(format="coco", label_fields=["labels"], min_area=1)) ds = CocoDetection(data_dir, ann_json, transforms=transform) else: transform = T.Compose([ - T.Resize((img_h, img_w)), + T.Resize(max(args.img_h, args.img_w)), + T.CenterCrop((args.img_h, args.img_w)), T.ToTensor(), - T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)) + T.Normalize(mean=(0,0,0), std=(1,1,1)) ]) - ds = ImageFolder(data_dir, transforms=transform) + ds = CalibrationDataset(data_dir, transforms=transform) return ds - @torch.no_grad() -def validate(model: CenterNet, dataloader, num_classes): +def validate(model, dataloader, num_classes): model.eval() evaluator = CocoEvaluator(num_classes) @@ -75,29 +80,9 @@ def validate(model: CenterNet, dataloader, num_classes): evaluator.update(detections, targets) return evaluator.get_metrics() - - -def get_args_parser(): - parser = argparse.ArgumentParser() - parser.add_argument("command", type=str) - - parser.add_argument("--num_classes", type=int, default=1) - parser.add_argument("--weights") - parser.add_argument("--output_dir", default="./centernet") - parser.add_argument("--data_dir") - parser.add_argument("--ann_json") - - parser.add_argument("--img_w", type=int, default=512) - parser.add_argument("--img_h", type=int, default=512) - - parser.add_argument("--batch_size", type=int, default=4) - parser.add_argument("--num_samples", type=int, default=0) - - return parser -if __name__ == "__main__": - args = get_args_parser().parse_args() +def main(args): assert args.command in ("calibrate", "test", "export") if args.command == 'export': args.num_samples = 1 @@ -137,3 +122,27 @@ def get_args_parser(): quant_model(img) quantizer.export_xmodel(output_dir=args.output_dir) + + +def get_args_parser(): + parser = argparse.ArgumentParser() + parser.add_argument("command", type=str) + + parser.add_argument("--num_classes", type=int, default=1) + parser.add_argument("--weights") + parser.add_argument("--output_dir", default="./centernet") + parser.add_argument("--data_dir") + parser.add_argument("--ann_json") + + parser.add_argument("--img_w", type=int, default=512) + parser.add_argument("--img_h", type=int, default=512) + + parser.add_argument("--batch_size", type=int, default=4) + parser.add_argument("--num_samples", type=int, default=0) + + return parser + + +if __name__ == "__main__": + args = get_args_parser().parse_args() + main(args) diff --git a/kv260_centernet/requirements.txt b/kv260_centernet/requirements.txt new file mode 100644 index 0000000..53a4f30 --- /dev/null +++ b/kv260_centernet/requirements.txt @@ -0,0 +1,3 @@ +albumentations +pycocotools +git+https://github.com/gau-nernst/vision-toolbox.git