-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathpreprocess_mask_rcnn.py
64 lines (57 loc) · 2.98 KB
/
preprocess_mask_rcnn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
from detectron2.utils.logger import setup_logger
setup_logger()
import argparse
from pathlib import Path
import numpy as np
import cv2
from detectron2 import model_zoo
from detectron2.engine import DefaultPredictor
from detectron2.config import get_cfg
import matplotlib.image as mpimg
def _build_predictor(score_thresh=0.5):
    """Build a DefaultPredictor for detectron2's COCO Mask R-CNN R50-FPN 3x model.

    Args:
        score_thresh: detections scoring below this are dropped by the model.

    Returns:
        A ``DefaultPredictor`` loaded with the model-zoo COCO checkpoint.
    """
    config_name = "COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml"
    cfg = get_cfg()
    # Add project-specific config (e.g., TensorMask) here if not using a model
    # from detectron2's core library.
    cfg.merge_from_file(model_zoo.get_config_file(config_name))
    cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = score_thresh
    # Pretrained weights for the same config from detectron2's model zoo.
    cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url(config_name)
    return DefaultPredictor(cfg)


def _select_mask(instances, class_name, thing_classes, frame_idx):
    """Pick the mask to save for one frame, or return None if nothing matches.

    Args:
        instances: detectron2 ``Instances`` for the frame, already on the CPU.
        class_name: wanted class name, or ``'anything'`` to take the first
            instance in the predictor's output regardless of class.
        thing_classes: list mapping class ids to class names.
        frame_idx: frame position, used only in the log message.

    Returns:
        The selected instance's mask as a numpy array, or None.
    """
    if len(instances) == 0:
        return None
    if class_name == 'anything':
        return instances.pred_masks[0].numpy()
    for j, class_id in enumerate(instances.pred_classes.tolist()):
        found = thing_classes[class_id]
        if found == class_name:
            return instances.pred_masks[j].numpy()
        # Report every non-matching instance seen before the first match.
        print("Frame %d: Did not find %s, found %s" % (frame_idx, class_name, found))
    return None


def _write_mask(path, mask):
    """Save a boolean mask as an 8-bit PNG with 0 for background and 255 for foreground.

    Raises:
        IOError: if OpenCV reports that the file could not be written.
    """
    # cv2.imwrite signals failure through its return value rather than raising.
    if not cv2.imwrite(str(path), mask.astype(np.uint8) * 255):
        raise IOError(f"Failed to write mask to {path}")


def preprocess(args):
    """Write one Mask R-CNN foreground mask per video frame.

    Reads ``<data_path>/video_frames/*.jpg`` in sorted order and runs a
    COCO-pretrained Mask R-CNN on each frame. Writes the chosen instance mask to
    ``<data_path>/masks/<object_name>/NNNNN.png``, where ``NNNNN`` is the frame's
    zero-padded position in the sorted list. Frames with no matching detection
    get an all-black mask.

    Args:
        args: namespace with ``data_path`` (Path), ``class_name`` (a COCO class
            name, or ``'anything'``) and ``object_name`` (output subfolder;
            falls back to ``class_name`` when empty).

    Raises:
        IOError: if a frame cannot be decoded or a mask cannot be written.
    """
    images = sorted((args.data_path / 'video_frames').glob('*.jpg'))
    obj_name = args.object_name or args.class_name
    out_mask_dir = args.data_path / 'masks' / obj_name
    # parents=True so a missing 'masks/' folder is created too.
    out_mask_dir.mkdir(parents=True, exist_ok=True)

    predictor = _build_predictor()
    thing_classes = predictor.metadata.thing_classes

    for i, image_path in enumerate(images):
        # cv2.imread yields BGR, which DefaultPredictor expects by default
        # (cfg.INPUT.FORMAT == "BGR"); matplotlib's imread would yield RGB.
        im = cv2.imread(str(image_path))
        if im is None:
            raise IOError(f"Could not read image {image_path}")
        # Move all predictions to the CPU once instead of per instance.
        instances = predictor(im)["instances"].to("cpu")
        mask = _select_mask(instances, args.class_name, thing_classes, i)
        if mask is None:
            mask = np.zeros(im.shape[:2], dtype=bool)
        # Build the path with pathlib so a '%' in the directory name is harmless.
        _write_mask(out_mask_dir / f"{i:05d}.png", mask)
if __name__ == '__main__':
    # Command-line entry point. Every flag accepts both hyphen and underscore
    # spellings so existing invocations keep working.
    # --class_name may also be 'anything': use the first detected instance of any class.
    parser = argparse.ArgumentParser(description='Preprocess image sequence')
    parser.add_argument(
        '--data-path', '--data_path', dest='data_path', type=Path,
        default=Path('./data/'), help='folder to process')
    parser.add_argument(
        '--class_name', '--class-name', dest='class_name', type=str,
        default='person', help='The foreground object class')
    parser.add_argument(
        '--object_name', '--object-name', dest='object_name', type=str,
        default='',
        help='Name of the foreground object, default is the same as the class name')
    args = parser.parse_args()
    preprocess(args=args)