From 67f0c2ec54a397c0b7d8f7b55c72c39d5a11e05b Mon Sep 17 00:00:00 2001 From: WongGawa Date: Mon, 23 Dec 2024 15:18:35 +0800 Subject: [PATCH] remove gpu's description in md --- GETTING_STARTED.md | 13 +- GETTING_STARTED_CN.md | 13 +- configs/yolov10/README.md | 5 +- configs/yolov3/README.md | 5 +- configs/yolov4/README.md | 5 +- configs/yolov5/README.md | 5 +- configs/yolov7/README.md | 5 +- configs/yolov8/README.md | 5 +- configs/yolov9/README.md | 5 +- configs/yolox/README.md | 5 +- demo/__init__.py | 3 + demo/predict.py | 340 ++++++++++++++++++ docs/en/modelzoo/yolov3.md | 5 +- docs/en/modelzoo/yolov4.md | 5 +- docs/en/modelzoo/yolov5.md | 5 +- docs/en/modelzoo/yolov7.md | 5 +- docs/en/modelzoo/yolov8.md | 5 +- docs/en/modelzoo/yolox.md | 5 +- docs/en/tutorials/configuration.md | 2 +- docs/en/tutorials/finetune.md | 4 +- docs/en/tutorials/quick_start.md | 13 +- docs/zh/modelzoo/yolov3.md | 6 +- docs/zh/modelzoo/yolov4.md | 6 +- docs/zh/modelzoo/yolov5.md | 6 +- docs/zh/modelzoo/yolov7.md | 6 +- docs/zh/modelzoo/yolov8.md | 6 +- docs/zh/modelzoo/yolox.md | 6 +- docs/zh/tutorials/configuration.md | 2 +- docs/zh/tutorials/finetune.md | 4 +- docs/zh/tutorials/quick_start.md | 13 +- examples/finetune_SHWD/README.md | 4 +- examples/finetune_car_detection/README.md | 10 +- .../finetune_single_class_dataset/README.md | 2 +- tutorials/configuration_CN.md | 2 +- 34 files changed, 423 insertions(+), 108 deletions(-) create mode 100644 demo/__init__.py create mode 100644 demo/predict.py diff --git a/GETTING_STARTED.md b/GETTING_STARTED.md index a80ba9cc..7b61508f 100644 --- a/GETTING_STARTED.md +++ b/GETTING_STARTED.md @@ -13,9 +13,6 @@ This document provides a brief introduction to the usage of built-in command-lin ``` # Run with Ascend (By default) python demo/predict.py --config ./configs/yolov7/yolov7.yaml --weight=/path_to_ckpt/WEIGHT.ckpt --image_path /path_to_image/IMAGE.jpg - -# Run with GPU -python demo/predict.py --config ./configs/yolov7/yolov7.yaml 
--weight=/path_to_ckpt/WEIGHT.ckpt --image_path /path_to_image/IMAGE.jpg --device_target=GPU ``` @@ -48,23 +45,23 @@ to understand their behavior. Some common arguments are: ``` -* To train a model on 1 NPU/GPU/CPU: +* To train a model on 1 NPU/CPU: ``` python train.py --config ./configs/yolov7/yolov7.yaml ``` -* To train a model on 8 NPUs/GPUs: +* To train a model on 8 NPUs: ``` msrun --worker_num=8 --local_worker_num=8 --bind_core=True --log_dir=./yolov7_log python train.py --config ./configs/yolov7/yolov7.yaml --is_parallel True ``` -* To evaluate a model's performance on 1 NPU/GPU/CPU: +* To evaluate a model's performance on 1 NPU/CPU: ``` python test.py --config ./configs/yolov7/yolov7.yaml --weight /path_to_ckpt/WEIGHT.ckpt ``` -* To evaluate a model's performance 8 NPUs/GPUs: +* To evaluate a model's performance on 8 NPUs: ``` msrun --worker_num=8 --local_worker_num=8 --bind_core=True --log_dir=./yolov7_log python test.py --config ./configs/yolov7/yolov7.yaml --weight /path_to_ckpt/WEIGHT.ckpt --is_parallel True ``` -*Notes: (1) The default hyper-parameter is used for 8-card training, and some parameters need to be adjusted in the case of a single card. (2) The default device is Ascend, and you can modify it by specifying 'device_target' as Ascend/GPU/CPU, as these are currently supported.* +*Notes: (1) The default hyper-parameter is used for 8-card training, and some parameters need to be adjusted in the case of a single card. (2) The default device is Ascend, and you can modify it by specifying 'device_target' as Ascend/CPU, as these are currently supported.* * For more options, see `train/test.py -h`. * Notice that if you are using `msrun` startup with 2 devices, please add `--bind_core=True` to improve performance. 
For example: diff --git a/GETTING_STARTED_CN.md b/GETTING_STARTED_CN.md index 93327d45..2bd14958 100644 --- a/GETTING_STARTED_CN.md +++ b/GETTING_STARTED_CN.md @@ -11,9 +11,6 @@ ```shell # NPU (默认) python demo/predict.py --config ./configs/yolov7/yolov7.yaml --weight=/path_to_ckpt/WEIGHT.ckpt --image_path /path_to_image/IMAGE.jpg - -# GPU -python demo/predict.py --config ./configs/yolov7/yolov7.yaml --weight=/path_to_ckpt/WEIGHT.ckpt --image_path /path_to_image/IMAGE.jpg --device_target=GPU ``` 有关命令行参数的详细信息,请参阅`demo/predict.py -h`,或查看其[源代码](https://github.com/mindspore-lab/mindyolo/blob/master/deploy/predict.py)。 @@ -45,24 +42,24 @@ python demo/predict.py --config ./configs/yolov7/yolov7.yaml --weight=/path_to_c ``` -* 在单卡NPU/GPU/CPU上训练模型: +* 在单卡NPU/CPU上训练模型: ```shell python train.py --config ./configs/yolov7/yolov7.yaml ``` -* 在多卡NPU/GPU上进行分布式模型训练,以8卡为例: +* 在多卡NPU上进行分布式模型训练,以8卡为例: ```shell msrun --worker_num=8 --local_worker_num=8 --bind_core=True --log_dir=./yolov7_log python train.py --config ./configs/yolov7/yolov7.yaml --is_parallel True ``` -* 在单卡NPU/GPU/CPU上评估模型的精度: +* 在单卡NPU/CPU上评估模型的精度: ```shell python test.py --config ./configs/yolov7/yolov7.yaml --weight /path_to_ckpt/WEIGHT.ckpt ``` -* 在多卡NPU/GPU上进行分布式评估模型的精度: +* 在多卡NPU上进行分布式评估模型的精度: ```shell msrun --worker_num=8 --local_worker_num=8 --bind_core=True --log_dir=./yolov7_log python test.py --config ./configs/yolov7/yolov7.yaml --weight /path_to_ckpt/WEIGHT.ckpt --is_parallel True ``` -*注意:默认超参为8卡训练,单卡情况需调整部分参数。 默认设备为Ascend,您可以指定'device_target'的值为Ascend/GPU/CPU。* +*注意:默认超参为8卡训练,单卡情况需调整部分参数。 默认设备为Ascend,您可以指定'device_target'的值为Ascend/CPU。* * 有关更多选项,请参阅 `train/test.py -h`. 
* 在云脑上进行训练,请在[这里](./tutorials/cloud/modelarts_CN.md)查看 diff --git a/configs/yolov10/README.md b/configs/yolov10/README.md index d1dc6a47..6eec9841 100644 --- a/configs/yolov10/README.md +++ b/configs/yolov10/README.md @@ -48,11 +48,10 @@ Please refer to the [GETTING_STARTED](https://github.com/mindspore-lab/mindyolo/ It is easy to reproduce the reported results with the pre-defined training recipe. For distributed training on multiple Ascend 910 devices, please run ```shell -# distributed training on multiple GPU/Ascend devices +# distributed training on multiple Ascend devices msrun --worker_num=8 --local_worker_num=8 --bind_core=True --log_dir=./yolov10_log python train.py --config ./configs/yolov10/yolov10n.yaml --device_target Ascend --is_parallel True ``` -Similarly, you can train the model on multiple GPU devices with the above msrun command. **Note:** For more information about msrun configuration, please refer to [here](https://www.mindspore.cn/tutorials/experts/zh-CN/r2.3.1/parallel/msrun_launcher.html). For detailed illustration of all hyper-parameters, please refer to [config.py](https://github.com/mindspore-lab/mindyolo/blob/master/mindyolo/utils/config.py). @@ -64,7 +63,7 @@ For detailed illustration of all hyper-parameters, please refer to [config.py](h If you want to train or finetune the model on a smaller dataset without distributed training, please run: ```shell -# standalone training on a CPU/GPU/Ascend device +# standalone training on a CPU/Ascend device python train.py --config ./configs/yolov19/yolov10n.yaml --device_target Ascend ``` diff --git a/configs/yolov3/README.md b/configs/yolov3/README.md index f0b2336d..66311bf8 100644 --- a/configs/yolov3/README.md +++ b/configs/yolov3/README.md @@ -37,11 +37,10 @@ python mindyolo/utils/convert_weight_darknet53.py It is easy to reproduce the reported results with the pre-defined training recipe. 
For distributed training on multiple Ascend 910 devices, please run ```shell -# distributed training on multiple GPU/Ascend devices +# distributed training on multiple Ascend devices msrun --worker_num=8 --local_worker_num=8 --bind_core=True --log_dir=./yolov3_log python train.py --config ./configs/yolov3/yolov3.yaml --device_target Ascend --is_parallel True ``` -Similarly, you can train the model on multiple GPU devices with the above msrun command. **Note:** For more information about msrun configuration, please refer to [here](https://www.mindspore.cn/tutorials/experts/zh-CN/r2.3.1/parallel/msrun_launcher.html). For detailed illustration of all hyper-parameters, please refer to [config.py](https://github.com/mindspore-lab/mindyolo/blob/master/mindyolo/utils/config.py). @@ -53,7 +52,7 @@ For detailed illustration of all hyper-parameters, please refer to [config.py](h If you want to train or finetune the model on a smaller dataset without distributed training, please run: ```shell -# standalone training on a CPU/GPU/Ascend device +# standalone training on a CPU/Ascend device python train.py --config ./configs/yolov3/yolov3.yaml --device_target Ascend ``` diff --git a/configs/yolov4/README.md b/configs/yolov4/README.md index c97c8893..9cc75c69 100644 --- a/configs/yolov4/README.md +++ b/configs/yolov4/README.md @@ -51,11 +51,10 @@ python mindyolo/utils/convert_weight_cspdarknet53.py It is easy to reproduce the reported results with the pre-defined training recipe. For distributed training on multiple Ascend 910 devices, please run ```shell -# distributed training on multiple GPU/Ascend devices +# distributed training on multiple Ascend devices msrun --worker_num=8 --local_worker_num=8 --bind_core=True --log_dir=./yolov4_log python train.py --config ./configs/yolov4/yolov4-silu.yaml --device_target Ascend --is_parallel True --epochs 320 ``` -Similarly, you can train the model on multiple GPU devices with the above msrun command. 
**Note:** For more information about msrun configuration, please refer to [here](https://www.mindspore.cn/tutorials/experts/zh-CN/r2.3.1/parallel/msrun_launcher.html). For detailed illustration of all hyper-parameters, please refer to [config.py](https://github.com/mindspore-lab/mindyolo/blob/master/mindyolo/utils/config.py). @@ -72,7 +71,7 @@ multiprocessing/semaphore_tracker.py: 144 UserWarning: semaphore_tracker: There If you want to train or finetune the model on a smaller dataset without distributed training, please run: ```shell -# standalone training on a CPU/GPU/Ascend device +# standalone training on a CPU/Ascend device python train.py --config ./configs/yolov4/yolov4-silu.yaml --device_target Ascend --epochs 320 ``` diff --git a/configs/yolov5/README.md b/configs/yolov5/README.md index d59520b1..b9b7720b 100644 --- a/configs/yolov5/README.md +++ b/configs/yolov5/README.md @@ -25,11 +25,10 @@ Please refer to the [GETTING_STARTED](https://github.com/mindspore-lab/mindyolo/ It is easy to reproduce the reported results with the pre-defined training recipe. For distributed training on multiple Ascend 910 devices, please run ```shell -# distributed training on multiple GPU/Ascend devices +# distributed training on multiple Ascend devices msrun --worker_num=8 --local_worker_num=8 --bind_core=True --log_dir=./yolov5_log python train.py --config ./configs/yolov5/yolov5n.yaml --device_target Ascend --is_parallel True ``` -Similarly, you can train the model on multiple GPU devices with the above msrun command. **Note:** For more information about msrun configuration, please refer to [here](https://www.mindspore.cn/tutorials/experts/zh-CN/r2.3.1/parallel/msrun_launcher.html). For detailed illustration of all hyper-parameters, please refer to [config.py](https://github.com/mindspore-lab/mindyolo/blob/master/mindyolo/utils/config.py). 
@@ -41,7 +40,7 @@ For detailed illustration of all hyper-parameters, please refer to [config.py](h If you want to train or finetune the model on a smaller dataset without distributed training, please run: ```shell -# standalone training on a CPU/GPU/Ascend device +# standalone training on a CPU/Ascend device python train.py --config ./configs/yolov5/yolov5n.yaml --device_target Ascend ``` diff --git a/configs/yolov7/README.md b/configs/yolov7/README.md index baf932a2..628bbca8 100644 --- a/configs/yolov7/README.md +++ b/configs/yolov7/README.md @@ -28,11 +28,10 @@ Please refer to the [GETTING_STARTED](https://github.com/mindspore-lab/mindyolo/ It is easy to reproduce the reported results with the pre-defined training recipe. For distributed training on multiple Ascend 910 devices, please run ```shell -# distributed training on multiple GPU/Ascend devices +# distributed training on multiple Ascend devices msrun --worker_num=8 --local_worker_num=8 --bind_core=True --log_dir=./yolov7_log python train.py --config ./configs/yolov7/yolov7.yaml --device_target Ascend --is_parallel True ``` -Similarly, you can train the model on multiple GPU devices with the above msrun command. **Note:** For more information about msrun configuration, please refer to [here](https://www.mindspore.cn/tutorials/experts/zh-CN/r2.3.1/parallel/msrun_launcher.html). For detailed illustration of all hyper-parameters, please refer to [config.py](https://github.com/mindspore-lab/mindyolo/blob/master/mindyolo/utils/config.py). 
@@ -44,7 +43,7 @@ For detailed illustration of all hyper-parameters, please refer to [config.py](h If you want to train or finetune the model on a smaller dataset without distributed training, please run: ```shell -# standalone training on a CPU/GPU/Ascend device +# standalone training on a CPU/Ascend device python train.py --config ./configs/yolov7/yolov7.yaml --device_target Ascend ``` diff --git a/configs/yolov8/README.md b/configs/yolov8/README.md index 46652ab8..02348d33 100644 --- a/configs/yolov8/README.md +++ b/configs/yolov8/README.md @@ -26,11 +26,10 @@ Please refer to the [GETTING_STARTED](https://github.com/mindspore-lab/mindyolo/ It is easy to reproduce the reported results with the pre-defined training recipe. For distributed training on multiple Ascend 910 devices, please run ```shell -# distributed training on multiple GPU/Ascend devices +# distributed training on multiple Ascend devices msrun --worker_num=8 --local_worker_num=8 --bind_core=True --log_dir=./yolov8_log python train.py --config ./configs/yolov8/yolov8n.yaml --device_target Ascend --is_parallel True ``` -Similarly, you can train the model on multiple GPU devices with the above msrun command. **Note:** For more information about msrun configuration, please refer to [here](https://www.mindspore.cn/tutorials/experts/zh-CN/r2.3.1/parallel/msrun_launcher.html). For detailed illustration of all hyper-parameters, please refer to [config.py](https://github.com/mindspore-lab/mindyolo/blob/master/mindyolo/utils/config.py). 
@@ -42,7 +41,7 @@ For detailed illustration of all hyper-parameters, please refer to [config.py](h If you want to train or finetune the model on a smaller dataset without distributed training, please run: ```shell -# standalone training on a CPU/GPU/Ascend device +# standalone training on a CPU/Ascend device python train.py --config ./configs/yolov8/yolov8n.yaml --device_target Ascend ``` diff --git a/configs/yolov9/README.md b/configs/yolov9/README.md index 2fa06f90..c9c5fb0f 100644 --- a/configs/yolov9/README.md +++ b/configs/yolov9/README.md @@ -56,11 +56,10 @@ Please refer to the [GETTING_STARTED](https://github.com/mindspore-lab/mindyolo/ It is easy to reproduce the reported results with the pre-defined training recipe. For distributed training on multiple Ascend 910 devices, please run ```shell -# distributed training on multiple GPU/Ascend devices +# distributed training on multiple Ascend devices msrun --worker_num=8 --local_worker_num=8 --bind_core=True --log_dir=./yolov9_log python train.py --config ./configs/yolov9/yolov9-t.yaml --device_target Ascend --is_parallel True ``` -Similarly, you can train the model on multiple GPU devices with the above msrun command. **Note:** For more information about msrun configuration, please refer to [here](https://www.mindspore.cn/tutorials/experts/zh-CN/r2.3.1/parallel/msrun_launcher.html). For detailed illustration of all hyper-parameters, please refer to [config.py](https://github.com/mindspore-lab/mindyolo/blob/master/mindyolo/utils/config.py). 
@@ -72,7 +71,7 @@ For detailed illustration of all hyper-parameters, please refer to [config.py](h If you want to train or finetune the model on a smaller dataset without distributed training, please run: ```shell -# standalone training on a CPU/GPU/Ascend device +# standalone training on a CPU/Ascend device python train.py --config ./configs/yolov9/yolov9-t.yaml --device_target Ascend ``` diff --git a/configs/yolox/README.md b/configs/yolox/README.md index 1ab33411..971210c4 100644 --- a/configs/yolox/README.md +++ b/configs/yolox/README.md @@ -25,11 +25,10 @@ Please refer to the [GETTING_STARTED](https://github.com/mindspore-lab/mindyolo/ It is easy to reproduce the reported results with the pre-defined training recipe. For distributed training on multiple Ascend 910 devices, please run ```shell -# distributed training on multiple GPU/Ascend devices +# distributed training on multiple Ascend devices msrun --worker_num=8 --local_worker_num=8 --bind_core=True --log_dir=./yolox_log python train.py --config ./configs/yolox/yolox-s.yaml --device_target Ascend --is_parallel True ``` -Similarly, you can train the model on multiple GPU devices with the above msrun command. **Note:** For more information about msrun configuration, please refer to [here](https://www.mindspore.cn/tutorials/experts/zh-CN/r2.3.1/parallel/msrun_launcher.html). For detailed illustration of all hyper-parameters, please refer to [config.py](https://github.com/mindspore-lab/mindyolo/blob/master/mindyolo/utils/config.py). 
@@ -41,7 +40,7 @@ For detailed illustration of all hyper-parameters, please refer to [config.py](h If you want to train or finetune the model on a smaller dataset without distributed training, please firstly run: ```shell -# standalone 1st stage training on a CPU/GPU/Ascend device +# standalone 1st stage training on a CPU/Ascend device python train.py --config ./configs/yolox/yolox-s.yaml --device_target Ascend ```
diff --git a/demo/__init__.py b/demo/__init__.py
new file mode 100644
index 00000000..b7744fda
--- /dev/null
+++ b/demo/__init__.py
@@ -0,0 +1,3 @@
+from .predict import detect
+
+__all__ = ['detect']
\ No newline at end of file
diff --git a/demo/predict.py b/demo/predict.py
new file mode 100644
index 00000000..872d018f
--- /dev/null
+++ b/demo/predict.py
@@ -0,0 +1,363 @@
+import argparse
+import ast
+import math
+import os
+import sys
+import time
+import cv2
+import numpy as np
+import yaml
+from datetime import datetime
+
+import mindspore as ms
+from mindspore import Tensor, nn
+
+from mindyolo.data import COCO80_TO_COCO91_CLASS
+from mindyolo.models import create_model
+from mindyolo.utils import logger
+from mindyolo.utils.config import parse_args
+from mindyolo.utils.metrics import non_max_suppression, scale_coords, xyxy2xywh, process_mask_upsample, scale_image
+from mindyolo.utils.utils import draw_result, set_seed
+
+
+def get_parser_infer(parents=None):
+    """Build the command-line parser for the inference demo, optionally extending `parents`."""
+    parser = argparse.ArgumentParser(description="Infer", parents=[parents] if parents else [])
+    parser.add_argument("--task", type=str, default="detect", choices=["detect", "segment"])
+    parser.add_argument("--device_target", type=str, default="Ascend", help="device target, Ascend/GPU/CPU")
+    parser.add_argument("--ms_mode", type=int, default=0, help="train mode, graph/pynative")
+    parser.add_argument("--ms_amp_level", type=str, default="O0", help="amp level, O0/O1/O2")
+    parser.add_argument(
+        "--ms_enable_graph_kernel", type=ast.literal_eval, default=False, help="use enable_graph_kernel or not"
+    )
+    parser.add_argument(
+        "--precision_mode", type=str, default=None, help="set accuracy mode of network model"
+    )
+    parser.add_argument("--weight", type=str, default="yolov7_300.ckpt", help="model.ckpt path(s)")
+    parser.add_argument("--img_size", type=int, default=640, help="inference size (pixels)")
+    parser.add_argument(
+        "--single_cls", type=ast.literal_eval, default=False, help="train multi-class data as single-class"
+    )
+    parser.add_argument("--nms_time_limit", type=float, default=60.0, help="time limit for NMS")
+    parser.add_argument("--conf_thres", type=float, default=0.25, help="object confidence threshold")
+    parser.add_argument("--iou_thres", type=float, default=0.65, help="IOU threshold for NMS")
+    parser.add_argument(
+        "--conf_free", type=ast.literal_eval, default=False, help="Whether the prediction result include conf"
+    )
+    parser.add_argument("--seed", type=int, default=2, help="set global seed")
+    parser.add_argument("--log_level", type=str, default="INFO", help="log level to print")
+    parser.add_argument("--save_dir", type=str, default="./runs_infer", help="save dir")
+
+    parser.add_argument("--image_path", type=str, help="path to image")
+    parser.add_argument("--save_result", type=ast.literal_eval, default=True, help="whether save the inference result")
+
+    return parser
+
+
+def set_default_infer(args):
+    """Set the MindSpore context, dataset class info, save directory and logging from `args` (mutated in place)."""
+    # Set Context
+    ms.set_context(mode=args.ms_mode, device_target=args.device_target, max_call_depth=2000)
+    if args.precision_mode is not None:
+        ms.set_context(ascend_config={"precision_mode":args.precision_mode})
+    if args.ms_mode == 0:
+        ms.set_context(jit_config={"jit_level": "O2"})
+    if args.device_target == "Ascend":
+        ms.set_context(device_id=int(os.getenv("DEVICE_ID", 0)))
+    elif args.device_target == "GPU" and args.ms_enable_graph_kernel:
+        ms.set_context(enable_graph_kernel=True)
+    args.rank, args.rank_size = 0, 1
+    # Set Data
+    args.data.nc = 1 if args.single_cls else int(args.data.nc)  # number of classes
+    # Check args.data.names (the list validated below); the previous args.names looks like a typo.
+    args.data.names = ["item"] if args.single_cls and len(args.data.names) != 1 else args.data.names  # class names
+    assert len(args.data.names) == args.data.nc, "%g names found for nc=%g dataset in %s" % (
+        len(args.data.names),
+        args.data.nc,
+        args.config,
+    )
+    # Directories and Save run settings
+    platform = sys.platform
+    if platform == "win32":
+        args.save_dir = os.path.join(args.save_dir, datetime.now().strftime("%Y.%m.%d-%H.%M.%S"))
+    else:
+        args.save_dir = os.path.join(args.save_dir, datetime.now().strftime("%Y.%m.%d-%H:%M:%S"))
+    os.makedirs(args.save_dir, exist_ok=True)
+    if args.rank % args.rank_size == 0:
+        with open(os.path.join(args.save_dir, "cfg.yaml"), "w") as f:
+            yaml.dump(vars(args), f, sort_keys=False)
+    # Set Logger
+    logger.setup_logging(logger_name="MindYOLO", log_level=args.log_level, rank_id=args.rank, device_per_servers=args.rank_size)
+    logger.setup_logging_file(log_dir=os.path.join(args.save_dir, "logs"))
+
+
+def detect(
+    network: nn.Cell,
+    img: np.ndarray,
+    conf_thres: float = 0.25,
+    iou_thres: float = 0.65,
+    conf_free: bool = False,
+    nms_time_limit: float = 60.0,
+    img_size: int = 640,
+    stride: int = 32,
+    num_class: int = 80,
+    is_coco_dataset: bool = True,
+):
+    """Run detection on a single image and return the results in COCO-style fields.
+
+    The image (H, W, C array; the channel axis is reversed before inference) is resized so
+    its longer side equals `img_size`, padded to a multiple of `stride` when a side is
+    shorter than `img_size`, run through `network` and filtered by non-max suppression.
+
+    Returns:
+        dict with "category_id", "bbox" ([x, y, w, h] with top-left x, y) and "score" lists.
+    """
+    # Resize
+    h_ori, w_ori = img.shape[:2]  # orig hw
+    r = img_size / max(h_ori, w_ori)  # resize image to img_size
+    if r != 1:  # always resize down, only resize up if training with augmentation
+        interp = cv2.INTER_AREA if r < 1 else cv2.INTER_LINEAR
+        img = cv2.resize(img, (int(w_ori * r), int(h_ori * r)), interpolation=interp)
+    h, w = img.shape[:2]
+    if h < img_size or w < img_size:
+        new_h, new_w = math.ceil(h / stride) * stride, math.ceil(w / stride) * stride
+        dh, dw = (new_h - h) / 2, (new_w - w) / 2
+        top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
+        left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
+        img = cv2.copyMakeBorder(
+            img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=(114, 114, 114)
+        )  # add border
+
+    # Transpose Norm
+    img = img[:, :, ::-1].transpose(2, 0, 1) / 255.0
+    imgs_tensor = Tensor(img[None], ms.float32)
+
+    # Run infer
+    _t = time.time()
+    out = network(imgs_tensor)  # inference and training outputs
+    out = out[0] if isinstance(out, (tuple, list)) else out
+    infer_times = time.time() - _t
+
+    # Run NMS
+    t = time.time()
+    out = out.asnumpy()
+    out = non_max_suppression(
+        out,
+        conf_thres=conf_thres,
+        iou_thres=iou_thres,
+        conf_free=conf_free,
+        multi_label=True,
+        time_limit=nms_time_limit,
+    )
+    nms_times = time.time() - t
+
+    result_dict = {"category_id": [], "bbox": [], "score": []}
+    total_category_ids, total_bboxes, total_scores = [], [], []
+    for si, pred in enumerate(out):
+        if len(pred) == 0:
+            continue
+
+        # Predictions
+        predn = np.copy(pred)
+        scale_coords(img.shape[1:], predn[:, :4], (h_ori, w_ori))  # native-space pred
+
+        box = xyxy2xywh(predn[:, :4])  # xywh
+        box[:, :2] -= box[:, 2:] / 2  # xy center to top-left corner
+        category_ids, bboxes, scores = [], [], []
+        for p, b in zip(pred.tolist(), box.tolist()):
+            category_ids.append(COCO80_TO_COCO91_CLASS[int(p[5])] if is_coco_dataset else int(p[5]))
+            bboxes.append([round(x, 3) for x in b])
+            scores.append(round(p[4], 5))
+
+        total_category_ids.extend(category_ids)
+        total_bboxes.extend(bboxes)
+        total_scores.extend(scores)
+
+    result_dict["category_id"].extend(total_category_ids)
+    result_dict["bbox"].extend(total_bboxes)
+    result_dict["score"].extend(total_scores)
+
+    t = tuple(x * 1e3 for x in (infer_times, nms_times, infer_times + nms_times)) + (img_size, img_size, 1)  # tuple
+    logger.info(f"Predict result is: {result_dict}")
+    logger.info("Speed: %.1f/%.1f/%.1f ms inference/NMS/total per %gx%g image at batch-size %g;" % t)
+    logger.info("Detect a image success.")
+
+    return result_dict
+
+
+def segment(
+    network: nn.Cell,
+    img: np.ndarray,
+    conf_thres: float = 0.25,
+    iou_thres: float = 0.65,
+    conf_free: bool = False,
+    nms_time_limit: float = 60.0,
+    img_size: int = 640,
+    stride: int = 32,
+    num_class: int = 80,
+    is_coco_dataset: bool = True,
+):
+    """Run instance segmentation on a single image and return the results in COCO-style fields.
+
+    Pre-processing matches `detect`; the network output is split into box predictions and
+    mask coefficients, and masks are produced from the prototypes with `process_mask_upsample` after NMS.
+
+    Returns:
+        dict with "category_id", "bbox" ([x, y, w, h] with top-left x, y), "score" and
+        "segmentation" (one mask array per instance) lists.
+    """
+    # Resize
+    h_ori, w_ori = img.shape[:2]  # orig hw
+    r = img_size / max(h_ori, w_ori)  # resize image to img_size
+    if r != 1:  # always resize down, only resize up if training with augmentation
+        interp = cv2.INTER_AREA if r < 1 else cv2.INTER_LINEAR
+        img = cv2.resize(img, (int(w_ori * r), int(h_ori * r)), interpolation=interp)
+    h, w = img.shape[:2]
+    if h < img_size or w < img_size:
+        new_h, new_w = math.ceil(h / stride) * stride, math.ceil(w / stride) * stride
+        dh, dw = (new_h - h) / 2, (new_w - w) / 2
+        top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
+        left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
+        img = cv2.copyMakeBorder(
+            img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=(114, 114, 114)
+        )  # add border
+
+    # Transpose Norm
+    img = img[:, :, ::-1].transpose(2, 0, 1) / 255.0
+    imgs_tensor = Tensor(img[None], ms.float32)
+
+    # Run infer
+    _t = time.time()
+    out, (_, _, prototypes) = network(imgs_tensor)  # inference and training outputs
+    infer_times = time.time() - _t
+
+    # Run NMS
+    t = time.time()
+    _c = num_class + 4 if conf_free else num_class + 5
+    out = out.asnumpy()
+    bboxes, mask_coefficient = out[:, :, :_c], out[:, :, _c:]
+    out = non_max_suppression(
+        bboxes,
+        mask_coefficient,
+        conf_thres=conf_thres,
+        iou_thres=iou_thres,
+        conf_free=conf_free,
+        multi_label=True,
+        time_limit=nms_time_limit,
+    )
+    nms_times = time.time() - t
+
+    prototypes = prototypes.asnumpy()
+
+    result_dict = {"category_id": [], "bbox": [], "score": [], "segmentation": []}
+    total_category_ids, total_bboxes, total_scores, total_seg = [], [], [], []
+    for si, (pred, proto) in enumerate(zip(out, prototypes)):
+        if len(pred) == 0:
+            continue
+
+        # Predictions
+        pred_masks = process_mask_upsample(proto, pred[:, 6:], pred[:, :4], shape=imgs_tensor[si].shape[1:])
+        pred_masks = pred_masks.astype(np.float32)
+        pred_masks = scale_image((pred_masks.transpose(1, 2, 0)), (h_ori, w_ori))
+        predn = np.copy(pred)
+        scale_coords(img.shape[1:], predn[:, :4], (h_ori, w_ori))  # native-space pred
+
+        box = xyxy2xywh(predn[:, :4])  # xywh
+        box[:, :2] -= box[:, 2:] / 2  # xy center to top-left corner
+        category_ids, bboxes, scores, segs = [], [], [], []
+        for ii, (p, b) in enumerate(zip(pred.tolist(), box.tolist())):
+            category_ids.append(COCO80_TO_COCO91_CLASS[int(p[5])] if is_coco_dataset else int(p[5]))
+            bboxes.append([round(x, 3) for x in b])
+            scores.append(round(p[4], 5))
+            segs.append(pred_masks[:, :, ii])
+
+        total_category_ids.extend(category_ids)
+        total_bboxes.extend(bboxes)
+        total_scores.extend(scores)
+        total_seg.extend(segs)
+
+    result_dict["category_id"].extend(total_category_ids)
+    result_dict["bbox"].extend(total_bboxes)
+    result_dict["score"].extend(total_scores)
+    result_dict["segmentation"].extend(total_seg)
+
+    t = tuple(x * 1e3 for x in (infer_times, nms_times, infer_times + nms_times)) + (img_size, img_size, 1)  # tuple
+    logger.info("Predict result is:")
+    for k, v in result_dict.items():
+        if k == "segmentation":
+            # No detections leaves this list empty, so v[0] would raise IndexError.
+            logger.info(f"{k} shape: {v[0].shape if v else None}")
+        else:
+            logger.info(f"{k}: {v}")
+    logger.info("Speed: %.1f/%.1f/%.1f ms inference/NMS/total per %gx%g image at batch-size %g;" % t)
+    logger.info("Detect a image success.")
+
+    return result_dict
+
+
+def infer(args):
+    """Load the model and the image at `args.image_path`, run the selected task and optionally save the drawing."""
+    # Init
+    set_seed(args.seed)
+    set_default_infer(args)
+
+    # Create Network
+    network = create_model(
+        model_name=args.network.model_name,
+        model_cfg=args.network,
+        num_classes=args.data.nc,
+        sync_bn=False,
+        checkpoint_path=args.weight,
+    )
+    network.set_train(False)
+    ms.amp.auto_mixed_precision(network, amp_level=args.ms_amp_level)
+
+    # Load Image
+    if not (isinstance(args.image_path, str) and os.path.isfile(args.image_path)):
+        raise ValueError("Detect: input image file not available.")
+    img = cv2.imread(args.image_path)
+    if img is None:  # cv2.imread signals an unreadable/undecodable file by returning None
+        raise ValueError("Detect: input image file not available.")
+
+    # Detect
+    is_coco_dataset = "coco" in args.data.dataset_name
+    if args.task == "detect":
+        result_dict = detect(
+            network=network,
+            img=img,
+            conf_thres=args.conf_thres,
+            iou_thres=args.iou_thres,
+            conf_free=args.conf_free,
+            nms_time_limit=args.nms_time_limit,
+            img_size=args.img_size,
+            stride=max(max(args.network.stride), 32),
+            num_class=args.data.nc,
+            is_coco_dataset=is_coco_dataset,
+        )
+        if args.save_result:
+            save_path = os.path.join(args.save_dir, "detect_results")
+            draw_result(args.image_path, result_dict, args.data.names, is_coco_dataset=is_coco_dataset, save_path=save_path)
+    elif args.task == "segment":
+        result_dict = segment(
+            network=network,
+            img=img,
+            conf_thres=args.conf_thres,
+            iou_thres=args.iou_thres,
+            conf_free=args.conf_free,
+            nms_time_limit=args.nms_time_limit,
+            img_size=args.img_size,
+            stride=max(max(args.network.stride), 32),
+            num_class=args.data.nc,
+            is_coco_dataset=is_coco_dataset,
+        )
+        if args.save_result:
+            save_path = os.path.join(args.save_dir, "segment_results")
+            draw_result(args.image_path, result_dict, args.data.names, is_coco_dataset=is_coco_dataset, save_path=save_path)
+
+    logger.info("Infer completed.")
+
+
+if __name__ == "__main__":
+    parser = get_parser_infer()
+    args = parse_args(parser)
+    infer(args)
\ No newline at end of file
diff --git a/docs/en/modelzoo/yolov3.md b/docs/en/modelzoo/yolov3.md index 66db64b7..776297ab 100644 --- a/docs/en/modelzoo/yolov3.md +++ b/docs/en/modelzoo/yolov3.md @@ -59,11 +59,10 @@ python mindyolo/utils/convert_weight_darknet53.py It is easy to reproduce the reported results with the pre-defined training recipe. For distributed training on multiple Ascend 910 devices, please run ```shell -# distributed training on multiple GPU/Ascend devices +# distributed training on multiple Ascend devices msrun --worker_num=8 --local_worker_num=8 --bind_core=True --log_dir=./yolov3_log python train.py --config ./configs/yolov3/yolov3.yaml --device_target Ascend --is_parallel True ``` -Similarly, you can train the model on multiple GPU devices with the above msrun command.
**Note:** For more information about msrun configuration, please refer to [here](https://www.mindspore.cn/tutorials/experts/en/r2.3.1/parallel/msrun_launcher.html). For detailed illustration of all hyper-parameters, please refer to [config.py](https://github.com/mindspore-lab/mindyolo/blob/master/mindyolo/utils/config.py). @@ -75,7 +74,7 @@ For detailed illustration of all hyper-parameters, please refer to [config.py](h If you want to train or finetune the model on a smaller dataset without distributed training, please run: ```shell -# standalone training on a CPU/GPU/Ascend device +# standalone training on a CPU/Ascend device python train.py --config ./configs/yolov3/yolov3.yaml --device_target Ascend ``` diff --git a/docs/en/modelzoo/yolov4.md b/docs/en/modelzoo/yolov4.md index d854d2d0..5f792976 100644 --- a/docs/en/modelzoo/yolov4.md +++ b/docs/en/modelzoo/yolov4.md @@ -73,11 +73,10 @@ python mindyolo/utils/convert_weight_cspdarknet53.py It is easy to reproduce the reported results with the pre-defined training recipe. For distributed training on multiple Ascend 910 devices, please run ```shell -# distributed training on multiple GPU/Ascend devices +# distributed training on multiple Ascend devices msrun --worker_num=8 --local_worker_num=8 --bind_core=True --log_dir=./yolov4_log python train.py --config ./configs/yolov4/yolov4-silu.yaml --device_target Ascend --is_parallel True --epochs 320 ``` -Similarly, you can train the model on multiple GPU devices with the above msrun command. **Note:** For more information about msrun configuration, please refer to [here](https://www.mindspore.cn/tutorials/experts/en/r2.3.1/parallel/msrun_launcher.html). For detailed illustration of all hyper-parameters, please refer to [config.py](https://github.com/mindspore-lab/mindyolo/blob/master/mindyolo/utils/config.py). 
@@ -94,7 +93,7 @@ multiprocessing/semaphore_tracker.py: 144 UserWarning: semaphore_tracker: There If you want to train or finetune the model on a smaller dataset without distributed training, please run: ```shell -# standalone training on a CPU/GPU/Ascend device +# standalone training on a CPU/Ascend device python train.py --config ./configs/yolov4/yolov4-silu.yaml --device_target Ascend --epochs 320 ``` diff --git a/docs/en/modelzoo/yolov5.md b/docs/en/modelzoo/yolov5.md index b4ff3280..4c79347c 100644 --- a/docs/en/modelzoo/yolov5.md +++ b/docs/en/modelzoo/yolov5.md @@ -53,11 +53,10 @@ Please refer to the [QUICK START](../tutorials/quick_start.md) in MindYOLO for d It is easy to reproduce the reported results with the pre-defined training recipe. For distributed training on multiple Ascend 910 devices, please run ```shell -# distributed training on multiple GPU/Ascend devices +# distributed training on multiple Ascend devices msrun --worker_num=8 --local_worker_num=8 --bind_core=True --log_dir=./yolov5_log python train.py --config ./configs/yolov5/yolov5n.yaml --device_target Ascend --is_parallel True ``` -Similarly, you can train the model on multiple GPU devices with the above msrun command. **Note:** For more information about msrun configuration, please refer to [here](https://www.mindspore.cn/tutorials/experts/en/r2.3.1/parallel/msrun_launcher.html). For detailed illustration of all hyper-parameters, please refer to [config.py](https://github.com/mindspore-lab/mindyolo/blob/master/mindyolo/utils/config.py). 
@@ -69,7 +68,7 @@ For detailed illustration of all hyper-parameters, please refer to [config.py](h If you want to train or finetune the model on a smaller dataset without distributed training, please run: ```shell -# standalone training on a CPU/GPU/Ascend device +# standalone training on a CPU/Ascend device python train.py --config ./configs/yolov5/yolov5n.yaml --device_target Ascend ``` diff --git a/docs/en/modelzoo/yolov7.md b/docs/en/modelzoo/yolov7.md index 8c46fd9e..742a2e00 100644 --- a/docs/en/modelzoo/yolov7.md +++ b/docs/en/modelzoo/yolov7.md @@ -53,11 +53,10 @@ Please refer to the [QUICK START](../tutorials/quick_start.md) in MindYOLO for d It is easy to reproduce the reported results with the pre-defined training recipe. For distributed training on multiple Ascend 910 devices, please run ```shell -# distributed training on multiple GPU/Ascend devices +# distributed training on multiple Ascend devices msrun --worker_num=8 --local_worker_num=8 --bind_core=True --log_dir=./yolov7_log python train.py --config ./configs/yolov7/yolov7.yaml --device_target Ascend --is_parallel True ``` -Similarly, you can train the model on multiple GPU devices with the above msrun command. **Note:** For more information about msrun configuration, please refer to [here](https://www.mindspore.cn/tutorials/experts/en/r2.3.1/parallel/msrun_launcher.html). For detailed illustration of all hyper-parameters, please refer to [config.py](https://github.com/mindspore-lab/mindyolo/blob/master/mindyolo/utils/config.py). 
@@ -69,7 +68,7 @@ For detailed illustration of all hyper-parameters, please refer to [config.py](h If you want to train or finetune the model on a smaller dataset without distributed training, please run: ```shell -# standalone training on a CPU/GPU/Ascend device +# standalone training on a CPU/Ascend device python train.py --config ./configs/yolov7/yolov7.yaml --device_target Ascend ``` diff --git a/docs/en/modelzoo/yolov8.md b/docs/en/modelzoo/yolov8.md index babe0771..0aa20899 100644 --- a/docs/en/modelzoo/yolov8.md +++ b/docs/en/modelzoo/yolov8.md @@ -63,11 +63,10 @@ Please refer to the [QUICK START](../tutorials/quick_start.md) in MindYOLO for d It is easy to reproduce the reported results with the pre-defined training recipe. For distributed training on multiple Ascend 910 devices, please run ```shell -# distributed training on multiple GPU/Ascend devices +# distributed training on multiple Ascend devices msrun --worker_num=8 --local_worker_num=8 --bind_core=True --log_dir=./yolov8_log python train.py --config ./configs/yolov8/yolov8n.yaml --device_target Ascend --is_parallel True ``` -Similarly, you can train the model on multiple GPU devices with the above msrun command. **Note:** For more information about msrun configuration, please refer to [here](https://www.mindspore.cn/tutorials/experts/en/r2.3.1/parallel/msrun_launcher.html). For detailed illustration of all hyper-parameters, please refer to [config.py](https://github.com/mindspore-lab/mindyolo/blob/master/mindyolo/utils/config.py). 
@@ -79,7 +78,7 @@ For detailed illustration of all hyper-parameters, please refer to [config.py](h If you want to train or finetune the model on a smaller dataset without distributed training, please run: ```shell -# standalone training on a CPU/GPU/Ascend device +# standalone training on a CPU/Ascend device python train.py --config ./configs/yolov8/yolov8n.yaml --device_target Ascend ``` diff --git a/docs/en/modelzoo/yolox.md b/docs/en/modelzoo/yolox.md index 7b90ca86..ecb30a98 100644 --- a/docs/en/modelzoo/yolox.md +++ b/docs/en/modelzoo/yolox.md @@ -52,11 +52,10 @@ Please refer to the [QUICK START](../tutorials/quick_start.md) in MindYOLO for d It is easy to reproduce the reported results with the pre-defined training recipe. For distributed training on multiple Ascend 910 devices, please run ```shell -# distributed training on multiple GPU/Ascend devices +# distributed training on multiple Ascend devices msrun --worker_num=8 --local_worker_num=8 --bind_core=True --log_dir=./yolox_log python train.py --config ./configs/yolox/yolox-s.yaml --device_target Ascend --is_parallel True ``` -Similarly, you can train the model on multiple GPU devices with the above msrun command. **Note:** For more information about msrun configuration, please refer to [here](https://www.mindspore.cn/tutorials/experts/en/r2.3.1/parallel/msrun_launcher.html). For detailed illustration of all hyper-parameters, please refer to [config.py](https://github.com/mindspore-lab/mindyolo/blob/master/mindyolo/utils/config.py). 
@@ -68,7 +67,7 @@ For detailed illustration of all hyper-parameters, please refer to [config.py](h If you want to train or finetune the model on a smaller dataset without distributed training, please firstly run: ```shell -# standalone 1st stage training on a CPU/GPU/Ascend device +# standalone 1st stage training on a CPU/Ascend device python train.py --config ./configs/yolox/yolox-s.yaml --device_target Ascend ``` diff --git a/docs/en/tutorials/configuration.md b/docs/en/tutorials/configuration.md index 997a4f5f..2c087002 100644 --- a/docs/en/tutorials/configuration.md +++ b/docs/en/tutorials/configuration.md @@ -27,7 +27,7 @@ __BASE__: [ ## Basic Parameters ### Parameter Description - - device_target: device used, Ascend/GPU/CPU + - device_target: device used, Ascend/CPU - save_dir: the path to save the running results, the default is ./runs - log_interval: step interval to print logs, the default is 100 - is_parallel: whether to perform distributed training, the default is False diff --git a/docs/en/tutorials/finetune.md b/docs/en/tutorials/finetune.md index 29f8bcc2..febe66f1 100644 --- a/docs/en/tutorials/finetune.md +++ b/docs/en/tutorials/finetune.md @@ -114,13 +114,13 @@ During the process of model fine-tuning, you can first train according to the de * Anchor can be adjusted according to the actual object size Since the SHWD training set only has about 6,000 images, the yolov7-tiny model was selected for training. 
-* Distributed model training on multi-card NPU/GPU, taking 8 cards as an example: +* Distributed model training on multi-card NPU, taking 8 cards as an example: ```shell msrun --worker_num=8 --local_worker_num=8 --bind_core=True --log_dir=./yolov7-tiny_log python train.py --config ./examples/finetune_SHWD/yolov7-tiny_shwd.yaml --is_parallel True ``` -* Train the model on a single card NPU/GPU/CPU: +* Train the model on a single card NPU/CPU: ```shell python train.py --config ./examples/finetune_SHWD/yolov7-tiny_shwd.yaml diff --git a/docs/en/tutorials/quick_start.md b/docs/en/tutorials/quick_start.md index 0e855e0f..523c15d9 100644 --- a/docs/en/tutorials/quick_start.md +++ b/docs/en/tutorials/quick_start.md @@ -16,9 +16,6 @@ This document provides a brief introduction to the usage of built-in command-lin ```shell # Run with Ascend (By default) python demo/predict.py --config ./configs/yolov7/yolov7.yaml --weight=/path_to_ckpt/WEIGHT.ckpt --image_path /path_to_image/IMAGE.jpg - -# Run with GPU -python demo/predict.py --config ./configs/yolov7/yolov7.yaml --weight=/path_to_ckpt/WEIGHT.ckpt --image_path /path_to_image/IMAGE.jpg --device_target=GPU ``` @@ -51,21 +48,21 @@ to understand their behavior. 
Some common arguments are: ``` -* To train a model on 8 NPUs/GPUs: +* To train a model on 8 NPUs: ```shell msrun --worker_num=8 --local_worker_num=8 --bind_core=True --log_dir=./yolov7_log python train.py --config ./configs/yolov7/yolov7.yaml --is_parallel True ``` -* To train a model on 1 NPU/GPU/CPU: +* To train a model on 1 NPU/CPU: ```shell python train.py --config ./configs/yolov7/yolov7.yaml ``` -* To evaluate a model's performance on 1 NPU/GPU/CPU: +* To evaluate a model's performance on 1 NPU/CPU: ```shell python test.py --config ./configs/yolov7/yolov7.yaml --weight /path_to_ckpt/WEIGHT.ckpt ``` -* To evaluate a model's performance 8 NPUs/GPUs: +* To evaluate a model's performance on 8 NPUs: ```shell msrun --worker_num=8 --local_worker_num=8 --bind_core=True --log_dir=./yolov7_log python test.py --config ./configs/yolov7/yolov7.yaml --weight /path_to_ckpt/WEIGHT.ckpt --is_parallel True ``` @@ -73,7 +70,7 @@ to understand their behavior. Some common arguments are: *(1) The default hyper-parameter is used for 8-card training, and some parameters need to be adjusted in the case of a single card.* -*(2) The default device is Ascend, and you can modify it by specifying 'device_target' as Ascend/GPU/CPU, as these are currently supported.* +*(2) The default device is Ascend, and you can modify it by specifying 'device_target' as Ascend/CPU, as these are currently supported.* *(3) For more options, see `train/test.py -h`.* diff --git a/docs/zh/modelzoo/yolov3.md b/docs/zh/modelzoo/yolov3.md index bd7d7880..a11c4d1e 100644 --- a/docs/zh/modelzoo/yolov3.md +++ b/docs/zh/modelzoo/yolov3.md @@ -58,11 +58,11 @@ python mindyolo/utils/convert_weight_darknet53.py 使用预置的训练配方可以轻松重现报告的结果。如需在多台Ascend 910设备上进行分布式训练,请运行 ```shell -# 在多台GPU/Ascend设备上进行分布式训练 +# 在多台Ascend设备上进行分布式训练 msrun --worker_num=8 --local_worker_num=8 --bind_core=True --log_dir=./yolov3_log python train.py --config ./configs/yolov3/yolov3.yaml --device_target Ascend --is_parallel True ```
-同样的,您可以使用上述msrun命令在多台GPU设备上训练模型。**注意:** 更多关于msrun配置的信息,请参考[这里](https://www.mindspore.cn/tutorials/experts/zh-CN/r2.3.1/parallel/msrun_launcher.html)。 +**注意:** 更多关于msrun配置的信息,请参考[这里](https://www.mindspore.cn/tutorials/experts/zh-CN/r2.3.1/parallel/msrun_launcher.html)。 有关所有超参数的详细说明,请参阅[config.py](https://github.com/mindspore-lab/mindyolo/blob/master/mindyolo/utils/config.py)。 @@ -73,7 +73,7 @@ msrun --worker_num=8 --local_worker_num=8 --bind_core=True --log_dir=./yolov3_lo 如果您想在较小的数据集上训练或微调模型而不进行分布式训练,请运行: ```shell -# 在 CPU/GPU/Ascend 设备上进行单卡训练 +# 在 CPU/Ascend 设备上进行单卡训练 python train.py --config ./configs/yolov3/yolov3.yaml --device_target Ascend ``` diff --git a/docs/zh/modelzoo/yolov4.md b/docs/zh/modelzoo/yolov4.md index 864c3cda..614fae97 100644 --- a/docs/zh/modelzoo/yolov4.md +++ b/docs/zh/modelzoo/yolov4.md @@ -66,11 +66,11 @@ python mindyolo/utils/convert_weight_cspdarknet53.py 使用预置的训练配方可以轻松重现报告的结果。如需在多台Ascend 910设备上进行分布式训练,请运行 ```shell -# distributed training on multiple GPU/Ascend devices +# 在多台Ascend设备上进行分布式训练 msrun --worker_num=8 --local_worker_num=8 --bind_core=True --log_dir=./yolov4_log python train.py --config ./configs/yolov4/yolov4-silu.yaml --device_target Ascend --is_parallel True --epochs 320 ``` -同样的,您可以使用上述msrun命令在多台GPU设备上训练模型。**注意:** 更多关于msrun配置的信息,请参考[这里](https://www.mindspore.cn/tutorials/experts/zh-CN/r2.3.1/parallel/msrun_launcher.html)。 +**注意:** 更多关于msrun配置的信息,请参考[这里](https://www.mindspore.cn/tutorials/experts/zh-CN/r2.3.1/parallel/msrun_launcher.html)。 有关所有超参数的详细说明,请参阅[config.py](https://github.com/mindspore-lab/mindyolo/blob/master/mindyolo/utils/config.py)。 @@ -86,7 +86,7 @@ multiprocessing/semaphore_tracker.py: 144 UserWarning: semaphore_tracker: There 如果您想在较小的数据集上训练或微调模型而不进行分布式训练,请运行: ```shell -# 在 CPU/GPU/Ascend 设备上进行单卡训练 +# 在 CPU/Ascend 设备上进行单卡训练 python train.py --config ./configs/yolov4/yolov4-silu.yaml --device_target Ascend --epochs 320 ``` diff --git a/docs/zh/modelzoo/yolov5.md
b/docs/zh/modelzoo/yolov5.md index d6b2b025..ecf1ddbf 100644 --- a/docs/zh/modelzoo/yolov5.md +++ b/docs/zh/modelzoo/yolov5.md @@ -51,11 +51,11 @@ YOLOv5 是在 COCO 数据集上预训练的一系列对象检测架构和模型 使用预置的训练配方可以轻松重现报告的结果。如需在多台Ascend 910设备上进行分布式训练,请运行 ```shell -# 在多台GPU/Ascend设备上进行分布式训练 +# 在多台Ascend设备上进行分布式训练 msrun --worker_num=8 --local_worker_num=8 --bind_core=True --log_dir=./yolov5_log python train.py --config ./configs/yolov5/yolov5n.yaml --device_target Ascend --is_parallel True ``` -同样的,您可以使用上述msrun命令在多台GPU设备上训练模型。**注意:** 更多关于msrun配置的信息,请参考[这里](https://www.mindspore.cn/tutorials/experts/zh-CN/r2.3.1/parallel/msrun_launcher.html)。 +**注意:** 更多关于msrun配置的信息,请参考[这里](https://www.mindspore.cn/tutorials/experts/zh-CN/r2.3.1/parallel/msrun_launcher.html)。 有关所有超参数的详细说明,请参阅[config.py](https://github.com/mindspore-lab/mindyolo/blob/master/mindyolo/utils/config.py)。 @@ -66,7 +66,7 @@ msrun --worker_num=8 --local_worker_num=8 --bind_core=True --log_dir=./yolov5_lo 如果您想在较小的数据集上训练或微调模型而不进行分布式训练,请运行: ```shell -# 在 CPU/GPU/Ascend 设备上进行单卡训练 +# 在 CPU/Ascend 设备上进行单卡训练 python train.py --config ./configs/yolov5/yolov5n.yaml --device_target Ascend ``` diff --git a/docs/zh/modelzoo/yolov7.md b/docs/zh/modelzoo/yolov7.md index 83464a8e..fbf2e03a 100644 --- a/docs/zh/modelzoo/yolov7.md +++ b/docs/zh/modelzoo/yolov7.md @@ -52,11 +52,11 @@ YOLOv7在5FPS到 160 FPS 范围内的速度和准确度都超过了所有已知 使用预置的训练配方可以轻松重现报告的结果。如需在多台Ascend 910设备上进行分布式训练,请运行 ```shell -# 在多台GPU/Ascend设备上进行分布式训练 +# 在多台Ascend设备上进行分布式训练 msrun --worker_num=8 --local_worker_num=8 --bind_core=True --log_dir=./yolov7_log python train.py --config ./configs/yolov7/yolov7.yaml --device_target Ascend --is_parallel True ``` -同样的,您可以使用上述msrun命令在多台GPU设备上训练模型。**注意:** 更多关于msrun配置的信息,请参考[这里](https://www.mindspore.cn/tutorials/experts/zh-CN/r2.3.1/parallel/msrun_launcher.html)。 +**注意:** 更多关于msrun配置的信息,请参考[这里](https://www.mindspore.cn/tutorials/experts/zh-CN/r2.3.1/parallel/msrun_launcher.html)。 
有关所有超参数的详细说明,请参阅[config.py](https://github.com/mindspore-lab/mindyolo/blob/master/mindyolo/utils/config.py)。 @@ -67,7 +67,7 @@ msrun --worker_num=8 --local_worker_num=8 --bind_core=True --log_dir=./yolov7_lo 如果您想在较小的数据集上训练或微调模型而不进行分布式训练,请运行: ```shell -# 在 CPU/GPU/Ascend 设备上进行单卡训练 +# 在 CPU/Ascend 设备上进行单卡训练 python train.py --config ./configs/yolov7/yolov7.yaml --device_target Ascend ``` diff --git a/docs/zh/modelzoo/yolov8.md b/docs/zh/modelzoo/yolov8.md index 1f5852eb..b9a43188 100644 --- a/docs/zh/modelzoo/yolov8.md +++ b/docs/zh/modelzoo/yolov8.md @@ -63,11 +63,11 @@ Ultralytics YOLOv8 由 Ultralytics 开发,是一款尖端的、最先进的 (S 使用预置的训练配方可以轻松重现报告的结果。如需在多台Ascend 910设备上进行分布式训练,请运行 ```shell -# 在多台GPU/Ascend设备上进行分布式训练 +# 在多台Ascend设备上进行分布式训练 msrun --worker_num=8 --local_worker_num=8 --bind_core=True --log_dir=./yolov8_log python train.py --config ./configs/yolov8/yolov8n.yaml --device_target Ascend --is_parallel True ``` -同样的,您可以使用上述msrun命令在多台GPU设备上训练模型。**注意:** 更多关于msrun配置的信息,请参考[这里](https://www.mindspore.cn/tutorials/experts/zh-CN/r2.3.1/parallel/msrun_launcher.html)。 +**注意:** 更多关于msrun配置的信息,请参考[这里](https://www.mindspore.cn/tutorials/experts/zh-CN/r2.3.1/parallel/msrun_launcher.html)。 有关所有超参数的详细说明,请参阅[config.py](https://github.com/mindspore-lab/mindyolo/blob/master/mindyolo/utils/config.py)。 @@ -78,7 +78,7 @@ msrun --worker_num=8 --local_worker_num=8 --bind_core=True --log_dir=./yolov8_lo 如果您想在较小的数据集上训练或微调模型而不进行分布式训练,请运行: ```shell -# 在 CPU/GPU/Ascend 设备上进行单卡训练 +# 在 CPU/Ascend 设备上进行单卡训练 python train.py --config ./configs/yolov8/yolov8n.yaml --device_target Ascend ``` diff --git a/docs/zh/modelzoo/yolox.md b/docs/zh/modelzoo/yolox.md index f770d277..6c6f7692 100644 --- a/docs/zh/modelzoo/yolox.md +++ b/docs/zh/modelzoo/yolox.md @@ -52,11 +52,11 @@ YOLOX 是一款新型高性能检测模型,在 YOLO 系列的基础上进行 使用预置的训练配方可以轻松重现报告的结果。如需在多台Ascend 910设备上进行分布式训练,请运行 ```shell -# 在多台GPU/Ascend设备上进行分布式训练 +# 在多台Ascend设备上进行分布式训练 msrun --worker_num=8 --local_worker_num=8 --bind_core=True --log_dir=./yolox_log python 
train.py --config ./configs/yolox/yolox-s.yaml --device_target Ascend --is_parallel True ``` -同样的,您可以使用上述msrun命令在多台GPU设备上训练模型。**注意:** 更多关于msrun配置的信息,请参考[这里](https://www.mindspore.cn/tutorials/experts/zh-CN/r2.3.1/parallel/msrun_launcher.html)。 +**注意:** 更多关于msrun配置的信息,请参考[这里](https://www.mindspore.cn/tutorials/experts/zh-CN/r2.3.1/parallel/msrun_launcher.html)。 有关所有超参数的详细说明,请参阅[config.py](https://github.com/mindspore-lab/mindyolo/blob/master/mindyolo/utils/config.py)。 @@ -67,7 +67,7 @@ msrun --worker_num=8 --local_worker_num=8 --bind_core=True --log_dir=./yolox_log 如果您想在较小的数据集上训练或微调模型而不进行分布式训练,请运行: ```shell -# 在 CPU/GPU/Ascend 设备上进行单卡训练 +# 在 CPU/Ascend 设备上进行单卡训练 python train.py --config ./configs/yolox/yolox-s.yaml --device_target Ascend ``` diff --git a/docs/zh/tutorials/configuration.md b/docs/zh/tutorials/configuration.md index e42b8d2e..42f53206 100644 --- a/docs/zh/tutorials/configuration.md +++ b/docs/zh/tutorials/configuration.md @@ -26,7 +26,7 @@ __BASE__: [ ## 基础参数 ### 参数说明 - - device_target: 所用设备,Ascend/GPU/CPU + - device_target: 所用设备,Ascend/CPU - save_dir: 运行结果保存路径,默认为./runs - log_interval: 打印日志step间隔,默认为100 - is_parallel: 是否分布式训练,默认为False diff --git a/docs/zh/tutorials/finetune.md b/docs/zh/tutorials/finetune.md index 24bc02f8..003d004b 100644 --- a/docs/zh/tutorials/finetune.md +++ b/docs/zh/tutorials/finetune.md @@ -115,13 +115,13 @@ optimizer: * anchor可根据实际物体大小进行调整 由于SHWD训练集只有约6000张图片,选用yolov7-tiny模型进行训练。 -* 在多卡NPU/GPU上进行分布式模型训练,以8卡为例: +* 在多卡NPU上进行分布式模型训练,以8卡为例: ```shell msrun --worker_num=8 --local_worker_num=8 --bind_core=True --log_dir=./yolov7-tiny_log python train.py --config ./examples/finetune_SHWD/yolov7-tiny_shwd.yaml --is_parallel True ``` -* 在单卡NPU/GPU/CPU上训练模型: +* 在单卡NPU/CPU上训练模型: ```shell python train.py --config ./examples/finetune_SHWD/yolov7-tiny_shwd.yaml diff --git a/docs/zh/tutorials/quick_start.md b/docs/zh/tutorials/quick_start.md index 6d5a6d19..efc06ac6 100644 --- a/docs/zh/tutorials/quick_start.md +++ 
b/docs/zh/tutorials/quick_start.md @@ -13,9 +13,6 @@ ```shell # NPU (默认) python demo/predict.py --config ./configs/yolov7/yolov7.yaml --weight=/path_to_ckpt/WEIGHT.ckpt --image_path /path_to_image/IMAGE.jpg - -# GPU -python demo/predict.py --config ./configs/yolov7/yolov7.yaml --weight=/path_to_ckpt/WEIGHT.ckpt --image_path /path_to_image/IMAGE.jpg --device_target=GPU ``` 有关命令行参数的详细信息,请参阅`demo/predict.py -h`,或查看其[源代码](https://github.com/mindspore-lab/mindyolo/blob/master/deploy/predict.py)。 @@ -47,24 +44,24 @@ python demo/predict.py --config ./configs/yolov7/yolov7.yaml --weight=/path_to_c ``` -* 在多卡NPU/GPU上进行分布式模型训练,以8卡为例: +* 在多卡NPU上进行分布式模型训练,以8卡为例: ```shell msrun --worker_num=8 --local_worker_num=8 --bind_core=True --log_dir=./yolov7_log python train.py --config ./configs/yolov7/yolov7.yaml --is_parallel True ``` -* 在单卡NPU/GPU/CPU上训练模型: +* 在单卡NPU/CPU上训练模型: ```shell python train.py --config ./configs/yolov7/yolov7.yaml ``` -* 在单卡NPU/GPU/CPU上评估模型的精度: +* 在单卡NPU/CPU上评估模型的精度: ```shell python test.py --config ./configs/yolov7/yolov7.yaml --weight /path_to_ckpt/WEIGHT.ckpt ``` -* 在多卡NPU/GPU上进行分布式评估模型的精度: +* 在多卡NPU上进行分布式评估模型的精度: ```shell msrun --worker_num=8 --local_worker_num=8 --bind_core=True --log_dir=./yolov7_log python test.py --config ./configs/yolov7/yolov7.yaml --weight /path_to_ckpt/WEIGHT.ckpt --is_parallel True @@ -74,7 +71,7 @@ python demo/predict.py --config ./configs/yolov7/yolov7.yaml --weight=/path_to_c *(1) 默认超参为8卡训练,单卡情况需调整部分参数。* -*(2) 默认设备为Ascend,您可以指定'device_target'的值为Ascend/GPU/CPU。* +*(2) 默认设备为Ascend,您可以指定'device_target'的值为Ascend/CPU。* *(3) 有关更多选项,请参阅 `train/test.py -h`。* diff --git a/examples/finetune_SHWD/README.md b/examples/finetune_SHWD/README.md index 8f32a3b8..055cf7f1 100644 --- a/examples/finetune_SHWD/README.md +++ b/examples/finetune_SHWD/README.md @@ -117,13 +117,13 @@ optimizer: * anchor可根据实际物体大小进行调整 由于SHWD训练集只有约6000张图片,选用yolov7-tiny模型进行训练。 -* 在多卡NPU/GPU上进行分布式模型训练,以8卡为例: +* 在多卡NPU上进行分布式模型训练,以8卡为例: ```shell msrun --worker_num=8 
--local_worker_num=8 --bind_core=True --log_dir=./yolov7-tiny_log python train.py --config ./examples/finetune_SHWD/yolov7-tiny_shwd.yaml --is_parallel True ``` -* 在单卡NPU/GPU/CPU上训练模型: +* 在单卡NPU/CPU上训练模型: ```shell python train.py --config ./examples/finetune_SHWD/yolov7-tiny_shwd.yaml diff --git a/examples/finetune_car_detection/README.md b/examples/finetune_car_detection/README.md index e568dd27..3c5b1e20 100644 --- a/examples/finetune_car_detection/README.md +++ b/examples/finetune_car_detection/README.md @@ -75,12 +75,12 @@ MindYOLO支持yaml文件继承机制,因此新编写的配置文件只需要 这里要注意的是云平台的模型存放路径较以前有变化,若是找不到路径可以利用云平台提供的C2Net库得到预训练模型路径,在train.py中修改weight和ckpt_url参数的值。 也可以选择在终端用命令行进行训练: -* 在多卡NPU/GPU上进行分布式模型训练,以8卡为例: +* 在多卡NPU上进行分布式模型训练,以8卡为例: ```shell msrun --worker_num=8 --local_worker_num=8 --bind_core=True --log_dir=./yolov7-tiny_log python train.py --config ./yolov7-tiny_ud.yaml --is_parallel True ``` -* 在单卡NPU/GPU/CPU上训练模型: +* 在单卡NPU/CPU上训练模型: ```shell python train.py --config ./yolov7-tiny_ud.yaml ``` @@ -88,12 +88,12 @@ MindYOLO支持yaml文件继承机制,因此新编写的配置文件只需要 由于yolov7-tiny模型比较小,直接用单卡训练就可以满足需求,所以这里选择用单卡进行训练。 ### yolov7-tiny的最终精度: 保存训练得到的权重参数的ckpt文件,用来测试精度和推理。 -* 在单卡NPU/GPU/CPU上评估模型的精度: +* 在单卡NPU/CPU上评估模型的精度: ```shell python test.py --config ./yolov7-tiny_ud.yaml --weight /path_to_ckpt/WEIGHT.ckpt ``` -* 在多卡NPU/GPU上进行分布式评估模型的精度: +* 在多卡NPU上进行分布式评估模型的精度: ```shell msrun --worker_num=8 --local_worker_num=8 --bind_core=True --log_dir=./yolov7-tiny_log python test.py --config ./yolov7-tiny_ud.yaml --weight /path_to_ckpt/WEIGHT.ckpt --is_parallel True @@ -120,8 +120,6 @@ MindYOLO支持yaml文件继承机制,因此新编写的配置文件只需要 # NPU (默认) python demo/predict.py --config ./yolov7l_ud.yaml --weight=/path_to_ckpt/WEIGHT.ckpt --image_path /path_to_image/IMAGE.jpg -# GPU -python demo/predict.py --config ./yolov7l_ud.yaml --weight=/path_to_ckpt/WEIGHT.ckpt --image_path /path_to_image/IMAGE.jpg --device_target=GPU ``` predict predict2 diff --git a/examples/finetune_single_class_dataset/README.md 
b/examples/finetune_single_class_dataset/README.md index 3f3cb3e3..bdc40808 100644 --- a/examples/finetune_single_class_dataset/README.md +++ b/examples/finetune_single_class_dataset/README.md @@ -51,7 +51,7 @@ optimizer: ``` #### 模型训练 选用yolov8n模型进行训练。 -* 在多卡NPU/GPU上进行分布式模型训练,以8卡为例: +* 在多卡NPU上进行分布式模型训练,以8卡为例: ```shell mpirun --allow-run-as-root -n 8 python train.py --config ./examples/finetune_single_class_dataset/yolov8n_single_class_dataset.yaml --is_parallel True diff --git a/tutorials/configuration_CN.md b/tutorials/configuration_CN.md index 18aa4150..6a1caa4c 100644 --- a/tutorials/configuration_CN.md +++ b/tutorials/configuration_CN.md @@ -24,7 +24,7 @@ __BASE__: [ ## 基础参数 ### 参数说明 - - device_target: 所用设备,Ascend/GPU/CPU + - device_target: 所用设备,Ascend/CPU - save_dir: 运行结果保存路径,默认为./runs - log_interval: 打印日志step间隔,默认为100 - is_parallel: 是否分布式训练,默认为False