
Commit 05d2a0e
Merge branch 'ote' into feature/dp/update_ov_wrapper
druzhkov-paul authored Apr 28, 2021
2 parents 6805a68 + b6479f2 commit 05d2a0e
Showing 23 changed files with 200 additions and 185 deletions.
5 changes: 2 additions & 3 deletions mmdet/apis/train.py
@@ -10,7 +10,7 @@

from mmdet.core import (DistEvalHook, DistEvalPlusBeforeRunHook, EvalHook,
EvalPlusBeforeRunHook)
from mmdet.integration.nncf import CompressionHook, wrap_nncf_model
from mmdet.integration.nncf import CompressionHook, CheckpointHookBeforeTraining, wrap_nncf_model
from mmdet.parallel import MMDataCPU
from mmcv.utils import build_from_cfg

@@ -121,8 +121,6 @@ def train_detector(model,
else:
model = MMDataCPU(model)

if nncf_enable_compression and distributed:
compression_ctrl.distributed()

# build runner
optimizer = build_optimizer(model, cfg.optimizer)
@@ -191,6 +189,7 @@ def train_detector(model,

if nncf_enable_compression:
runner.register_hook(CompressionHook(compression_ctrl=compression_ctrl))
runner.register_hook(CheckpointHookBeforeTraining())
# user-defined hooks
if cfg.get('custom_hooks', None):
custom_hooks = cfg.custom_hooks
19 changes: 10 additions & 9 deletions mmdet/core/post_processing/bbox_nms.py
@@ -1,7 +1,7 @@
import torch
from torch.onnx import is_in_onnx_export

from mmdet.integration.nncf import no_nncf_trace
from mmdet.integration.nncf import no_nncf_trace, is_in_nncf_tracing

from mmdet.ops.nms import batched_nms

@@ -63,13 +63,14 @@ def multiclass_nms_core(multi_bboxes, multi_scores, score_thr, nms_cfg, max_num=
bboxes = multi_bboxes[:, None].expand(multi_scores.size(0), num_classes, 4)
scores = multi_scores

if is_in_onnx_export():
labels = torch.arange(num_classes, dtype=torch.long, device=scores.device) \
.unsqueeze(0) \
.expand_as(scores) \
.reshape(-1)
bboxes = bboxes.reshape(-1, 4)
scores = scores.reshape(-1)
if is_in_onnx_export() or is_in_nncf_tracing():
with no_nncf_trace():
labels = torch.arange(num_classes, dtype=torch.long, device=scores.device) \
.unsqueeze(0) \
.expand_as(scores) \
.reshape(-1)
bboxes = bboxes.reshape(-1, 4)
scores = scores.reshape(-1)

assert nms_cfg['type'] == 'nms', 'Only vanilla NMS is compatible with ONNX export'
nms_cfg['score_thr'] = score_thr
@@ -91,7 +92,7 @@ def multiclass_nms_core(multi_bboxes, multi_scores, score_thr, nms_cfg, max_num=
labels = labels[keep]
dets = torch.cat([dets, labels.to(dets.dtype).unsqueeze(-1)], dim=1)

if not is_in_onnx_export() and max_num > 0:
if not (is_in_onnx_export() or is_in_nncf_tracing()) and max_num > 0:
dets = dets[:max_num]

if return_inds:
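
The guard above recurs throughout this branch: code paths that previously ran only under torch.onnx.is_in_onnx_export() now also run while NNCF traces the model, and the pure shape bookkeeping is kept out of the traced graph with no_nncf_trace(). A minimal sketch of the pattern, assuming the imports added in this file (the helper name flatten_for_nms is illustrative, not part of the commit):

import torch

from mmdet.integration.nncf import is_in_nncf_tracing, no_nncf_trace


def flatten_for_nms(bboxes, scores, num_classes):
    # Reshaping needed only for export/tracing; no_nncf_trace() keeps these
    # ops out of the NNCF compression graph.
    if torch.onnx.is_in_onnx_export() or is_in_nncf_tracing():
        with no_nncf_trace():
            labels = torch.arange(
                num_classes, dtype=torch.long, device=scores.device)
            labels = labels.unsqueeze(0).expand_as(scores).reshape(-1)
            return bboxes.reshape(-1, 4), scores.reshape(-1), labels
    return bboxes, scores, None
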
14 changes: 9 additions & 5 deletions mmdet/core/utils/misc.py
@@ -4,6 +4,8 @@
import torch
from six.moves import map, zip

from mmdet.integration.nncf.utils import no_nncf_trace, is_in_nncf_tracing

from ..mask.structures import BitmapMasks, PolygonMasks


@@ -50,7 +52,7 @@ def arange(start=0,
layout=torch.strided,
device=None,
requires_grad=False):
if torch.onnx.is_in_onnx_export():
if torch.onnx.is_in_onnx_export() or is_in_nncf_tracing():
if end is None:
raise ValueError('End of range must be defined.')
assert out is None
@@ -93,20 +95,22 @@ def topk(x, k, dim=None, **kwargs):
if dim is None:
dim = x.dim() - 1

if is_in_onnx_export():
if is_in_onnx_export() or is_in_nncf_tracing():
n = operators.shape_as_tensor(x)[dim].unsqueeze(0)
if not isinstance(k, torch.Tensor):
k = torch.tensor([k], dtype=torch.long)
# Workaround for ONNXRuntime: convert values to int to get minimum.
n = torch.min(torch.cat((k, n), dim=0).int()).long()
with no_nncf_trace():
# Workaround for ONNXRuntime: convert values to int to get minimum.
n = torch.min(torch.cat((k, n), dim=0).int()).long()
# ONNX OpSet 10 does not support non-floating point input for TopK.
original_dtype = x.dtype
require_cast = original_dtype not in {
torch.float16, torch.float32, torch.float64
}
if require_cast:
x = x.to(torch.float32)
values, keep = torch.topk(x, n, dim=dim, **kwargs)
with no_nncf_trace():
values, keep = torch.topk(x, n, dim=dim, **kwargs)
if require_cast:
values = values.to(original_dtype)
else:
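
A hedged usage sketch for the patched topk helper, assuming it returns the same (values, indices) pair as torch.topk and using this repository's import path:

import torch

from mmdet.core.utils.misc import topk

scores = torch.rand(1000)
# Outside ONNX export and NNCF tracing this reduces to plain torch.topk;
# during tracing, k is clamped to the tensor length and the topk call itself
# is kept out of the NNCF graph by no_nncf_trace().
values, keep = topk(scores, 200)
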
5 changes: 3 additions & 2 deletions mmdet/integration/nncf/__init__.py
@@ -1,18 +1,19 @@
from .compression import (check_nncf_is_enabled, get_nncf_config_from_meta,
get_nncf_metadata, get_uncompressed_model,
is_checkpoint_nncf, wrap_nncf_model)
from .compression_hooks import CompressionHook
from .compression_hooks import CompressionHook, CheckpointHookBeforeTraining
from .utils import get_nncf_version, is_in_nncf_tracing, no_nncf_trace

__all__ = [
'check_nncf_is_enabled',
'CompressionHook',
'CheckpointHookBeforeTraining',
'get_nncf_config_from_meta',
'get_nncf_metadata',
'get_nncf_version',
'get_uncompressed_model',
'is_checkpoint_nncf',
'is_in_nncf_tracing'
'is_in_nncf_tracing',
'no_nncf_trace',
'wrap_nncf_model',
]
67 changes: 13 additions & 54 deletions mmdet/integration/nncf/compression.py
@@ -1,7 +1,6 @@
import os
import pathlib
import tempfile
from functools import partial

import mmcv
import torch
@@ -35,6 +34,7 @@ def is_checkpoint_nncf(path):
except FileNotFoundError:
return False


def get_nncf_config_from_meta(path):
"""
The function uses metadata stored in a checkpoint to restore the nncf
@@ -46,7 +46,7 @@ def get_nncf_config_from_meta(path):

nncf_enable_compression = meta.get('nncf_enable_compression', False)
assert nncf_enable_compression, \
'get_nncf_config_from_meta should be run for NNCF-compressed checkpoints only'
'get_nncf_config_from_meta should be run for NNCF-compressed checkpoints only'

config_text = meta['config']

@@ -60,14 +60,14 @@ def get_nncf_config_from_meta(path):
nncf_config = cfg.get('nncf_config')

assert isinstance(nncf_config, dict), (
f'Wrong nncf_config part of the config saved in the metainfo'
f' of the snapshot {path}:'
f' nncf_config={nncf_config}')
f'Wrong nncf_config part of the config saved in the metainfo'
f' of the snapshot {path}:'
f' nncf_config={nncf_config}')

nncf_config_part = {
'nncf_config': nncf_config,
'find_unused_parameters': True
}
'nncf_config': nncf_config,
'find_unused_parameters': True
}
if nncf_config_part['nncf_config'].get('log_dir'):
# TODO(LeonidBeynenson): improve work with log dir
log_dir = tempfile.mkdtemp(prefix='nncf_output_')
@@ -76,6 +76,7 @@
logger.info(f'Read nncf config from meta nncf_config_part={nncf_config_part}')
return nncf_config_part


def wrap_nncf_model(model,
cfg,
data_loader_for_init=None,
@@ -91,6 +92,7 @@
from nncf import (NNCFConfig, create_compressed_model,
register_default_init_args)
from nncf.dynamic_graph.io_handling import nncf_model_input
from nncf.dynamic_graph.trace_tensor import TracedTensor
from nncf.initialization import InitializingDataLoader

class MMInitializeDataLoader(InitializingDataLoader):
@@ -177,14 +179,15 @@ def dummy_forward(model):
def wrap_inputs(args, kwargs):
# during dummy_forward
if not len(kwargs):
args[0][0] = nncf_model_input(args[0][0])
if not isinstance(args[0][0], TracedTensor):
args[0][0] = nncf_model_input(args[0][0])
return args, kwargs

# during building original graph
if not kwargs.get('return_loss') and kwargs.get('forward_export'):
return args, kwargs

# during model's forward
# during model's forward in export
assert 'img' in kwargs, 'During model forward img must be in kwargs'
img = kwargs['img']
if isinstance(img, list):
@@ -207,55 +210,11 @@ def wrap_inputs(args, kwargs):
dummy_forward_fn=dummy_forward,
wrap_inputs_fn=wrap_inputs,
resuming_state_dict=resuming_state_dict)
model = change_export_func_first_conv(model)
model.export = export_method.__get__(model)

return compression_ctrl, model


def change_export_func_first_conv(model):
""" To avoid saturation issue
At the moment works only for mobilenet
"""

def run_hacked_export_quantization(self, x):
from nncf.quantization.layers import (
ExportQuantizeToFakeQuantize, ExportQuantizeToONNXQuantDequant,
QuantizerExportMode, get_scale_zp_from_input_low_input_high)
from nncf.utils import no_jit_trace
with no_jit_trace():
input_range = abs(self.scale) + self.eps
# todo: take bias into account during input_low/input_high calculation
input_low = input_range * self.level_low / self.level_high
input_high = input_range

if self._export_mode == QuantizerExportMode.ONNX_QUANTIZE_DEQUANTIZE_PAIRS:
y_scale, y_zero_point = get_scale_zp_from_input_low_input_high(self.level_low,
self.level_high,
input_low,
input_high)

if self._export_mode == QuantizerExportMode.ONNX_QUANTIZE_DEQUANTIZE_PAIRS:
return ExportQuantizeToONNXQuantDequant.apply(x, y_scale, y_zero_point)
if self._export_mode == QuantizerExportMode.FAKE_QUANTIZE:
x = x / 2.0
return ExportQuantizeToFakeQuantize.apply(x, self.levels, input_low, input_high, input_low * 2,
input_high * 2)
raise RuntimeError

logger = get_root_logger()
orig_model = model.get_nncf_wrapped_model()
try:
# pylint: disable=protected-access
module_ = orig_model.backbone.features.init_block.conv.pre_ops._modules['0']
except (AttributeError, KeyError) as e:
logger.info(f'Cannot change an export function for the first Conv due {e}')
return model
module_.op.run_export_quantization = partial(run_hacked_export_quantization, module_.op)
logger.info('Change an export function for the first Conv to avoid saturation issue on AVX2, AVX512')
return model


def get_uncompressed_model(module):
if not is_nncf_enabled():
return module
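
A condensed sketch of the guard added to wrap_inputs above, using the same NNCF imports that wrap_nncf_model now pulls in; the standalone helper name is illustrative. The point of the check is that the dummy forward can reach this code with an input NNCF has already wrapped, and a TracedTensor must not be wrapped a second time:

from nncf.dynamic_graph.io_handling import nncf_model_input
from nncf.dynamic_graph.trace_tensor import TracedTensor


def wrap_dummy_forward_input(args):
    # Wrap the image tensor for NNCF only if it is not already a TracedTensor.
    if not isinstance(args[0][0], TracedTensor):
        args[0][0] = nncf_model_input(args[0][0])
    return args
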
33 changes: 32 additions & 1 deletion mmdet/integration/nncf/compression_hooks.py
@@ -1,6 +1,8 @@
from mmcv.runner.hooks.hook import Hook
from mmcv.runner.hooks.hook import HOOKS, Hook
from mmcv.runner.dist_utils import master_only


@HOOKS.register_module()
class CompressionHook(Hook):
def __init__(self, compression_ctrl=None):
self.compression_ctrl = compression_ctrl
@@ -16,6 +18,35 @@ def before_run(self, runner):
print_statistics(self.compression_ctrl.statistics(), runner.logger)


@HOOKS.register_module()
class CheckpointHookBeforeTraining(Hook):
"""Save checkpoints before training.
Args:
save_optimizer (bool): Whether to save optimizer state_dict in the
checkpoint. It is usually used for resuming experiments.
Default: True.
out_dir (str, optional): The directory to save checkpoints. If not
specified, ``runner.work_dir`` will be used by default.
"""

def __init__(self,
save_optimizer=True,
out_dir=None,
**kwargs):
self.save_optimizer = save_optimizer
self.out_dir = out_dir
self.args = kwargs

@master_only
def before_run(self, runner):
runner.logger.info(f'Saving checkpoint before training')
if not self.out_dir:
self.out_dir = runner.work_dir
runner.save_checkpoint(
self.out_dir, filename_tmpl='before_training.pth', save_optimizer=self.save_optimizer, **self.args)


def print_statistics(stats, logger):
try:
from texttable import Texttable
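
Because the new hook is registered in mmcv's HOOKS registry, it could also be enabled from a config's custom_hooks list (which train_detector already consumes via build_from_cfg) rather than through the hard-wired registration shown earlier; a hypothetical config fragment:

custom_hooks = [
    dict(type='CheckpointHookBeforeTraining', save_optimizer=True),
]
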
3 changes: 2 additions & 1 deletion mmdet/models/dense_heads/fcos_head.py
@@ -5,6 +5,7 @@
from mmcv.runner import force_fp32

from mmdet.core import distance2bbox, multi_apply, multiclass_nms, reduce_mean
from mmdet.integration.nncf.utils import is_in_nncf_tracing
from ..builder import HEADS, build_loss
from .anchor_free_head import AnchorFreeHead

@@ -387,7 +388,7 @@ def _get_bboxes_single(self,

# Set max number of box to be feed into nms in deployment
deploy_nms_pre = cfg.get('deploy_nms_pre', -1)
if deploy_nms_pre > 0 and torch.onnx.is_in_onnx_export():
if deploy_nms_pre > 0 and (torch.onnx.is_in_onnx_export() or is_in_nncf_tracing()):
max_scores, _ = (mlvl_scores * mlvl_centerness[:, None]).max(dim=1)
_, topk_inds = max_scores.topk(deploy_nms_pre)
mlvl_scores = mlvl_scores[topk_inds, :]
3 changes: 2 additions & 1 deletion mmdet/models/dense_heads/rpn_head.py
@@ -5,6 +5,7 @@
from torch.onnx import is_in_onnx_export

from mmdet.ops.nms import batched_nms
from mmdet.integration.nncf.utils import is_in_nncf_tracing
from ...core.utils.misc import topk

from ..builder import HEADS
@@ -157,7 +158,7 @@ def _get_bboxes_single(self,
(w >= cfg.min_bbox_size)
& (h >= cfg.min_bbox_size),
as_tuple=False).squeeze()
if valid_inds.sum().item() != len(proposals) or is_in_onnx_export():
if valid_inds.sum().item() != len(proposals) or (is_in_onnx_export() or is_in_nncf_tracing()):
proposals = proposals[valid_inds, :]
scores = scores[valid_inds]
ids = ids[valid_inds]
5 changes: 4 additions & 1 deletion mmdet/models/dense_heads/rpn_test_mixin.py
@@ -1,6 +1,8 @@
import sys

from mmdet.core import merge_aug_proposals
from mmdet.integration.nncf.utils import no_nncf_trace


if sys.version_info >= (3, 7):
from mmdet.utils.contextmanagers import completed
@@ -33,7 +35,8 @@ def simple_test_rpn(self, x, img_metas):
list[Tensor]: Proposals of each image.
"""
rpn_outs = self(x)
proposal_list = self.get_bboxes(*rpn_outs, img_metas)
with no_nncf_trace():
proposal_list = self.get_bboxes(*rpn_outs, img_metas)
return proposal_list

def aug_test_rpn(self, feats, img_metas):
3 changes: 2 additions & 1 deletion mmdet/models/dense_heads/yolo_head.py
@@ -11,6 +11,7 @@
from mmdet.core import (build_anchor_generator, build_assigner,
build_bbox_coder, build_sampler, images_to_levels,
multi_apply, multiclass_nms)
from mmdet.integration.nncf.utils import is_in_nncf_tracing
from ..builder import HEADS, build_loss
from .base_dense_head import BaseDenseHead
from .dense_test_mixins import BBoxTestMixin
@@ -281,7 +282,7 @@ def _get_bboxes_single(self,
# Get top-k prediction
nms_pre = cfg.get('nms_pre', -1)
if 0 < nms_pre < conf_pred.size(0) and (
not torch.onnx.is_in_onnx_export()):
not torch.onnx.is_in_onnx_export() or not is_in_nncf_tracing()):
_, topk_inds = conf_pred.topk(nms_pre)
bbox_pred = bbox_pred[topk_inds, :]
cls_pred = cls_pred[topk_inds, :]