use GE backend for ms2.3
Ash-Lee233 committed Jun 18, 2024
1 parent 1ba1ba8 commit dcb4375
Showing 3 changed files with 22 additions and 16 deletions.
demo/predict.py (10 changes: 6 additions & 4 deletions)
@@ -10,7 +10,7 @@
 from datetime import datetime

 import mindspore as ms
-from mindspore import Tensor, context, nn
+from mindspore import Tensor, nn

 from mindyolo.data import COCO80_TO_COCO91_CLASS
 from mindyolo.models import create_model
@@ -52,11 +52,13 @@ def get_parser_infer(parents=None):

 def set_default_infer(args):
     # Set Context
-    context.set_context(mode=args.ms_mode, device_target=args.device_target, max_call_depth=2000)
+    ms.set_context(mode=args.ms_mode, device_target=args.device_target, max_call_depth=2000)
+    if args.ms_mode == 0:
+        ms.set_context(jit_config={"jit_level": "O2"})
     if args.device_target == "Ascend":
-        context.set_context(device_id=int(os.getenv("DEVICE_ID", 0)))
+        ms.set_context(device_id=int(os.getenv("DEVICE_ID", 0)))
     elif args.device_target == "GPU" and args.ms_enable_graph_kernel:
-        context.set_context(enable_graph_kernel=True)
+        ms.set_context(enable_graph_kernel=True)
     args.rank, args.rank_size = 0, 1
     # Set Data
     args.data.nc = 1 if args.single_cls else int(args.data.nc)  # number of classes
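The same pattern recurs in all three files: calls through the removed mindspore.context module become direct ms.set_context calls, and graph mode (ms_mode == 0) additionally sets jit_config={"jit_level": "O2"}, which, per the commit title, selects the GE backend on MindSpore 2.3. For reference, a minimal standalone sketch of the new-style setup; the helper name and its defaults are illustrative, not from the repository:

import os

import mindspore as ms


def setup_context(ms_mode=0, device_target="Ascend"):
    # ms2.3-style API: top-level ms.set_context replaces the old
    # mindspore.context.set_context.
    ms.set_context(mode=ms_mode, device_target=device_target, max_call_depth=2000)
    if ms_mode == 0:
        # Graph mode: jit level "O2" enables the GE backend on ms2.3,
        # matching the intent of this commit.
        ms.set_context(jit_config={"jit_level": "O2"})
    if device_target == "Ascend":
        ms.set_context(device_id=int(os.getenv("DEVICE_ID", 0)))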
mindyolo/utils/utils.py (16 changes: 9 additions & 7 deletions)
@@ -6,9 +6,9 @@
 import numpy as np

 import mindspore as ms
-from mindspore import context, ops, Tensor, nn
+from mindspore import ops, Tensor, nn
 from mindspore.communication.management import get_group_size, get_rank, init
-from mindspore.context import ParallelMode
+from mindspore import ParallelMode

 from mindyolo.utils import logger

@@ -21,22 +21,24 @@ def set_seed(seed=2):

 def set_default(args):
     # Set Context
-    context.set_context(mode=args.ms_mode, device_target=args.device_target, max_call_depth=2000)
+    ms.set_context(mode=args.ms_mode, device_target=args.device_target, max_call_depth=2000)
+    if args.ms_mode == 0:
+        ms.set_context(jit_config={"jit_level": "O2"})
     if args.device_target == "Ascend":
         device_id = int(os.getenv("DEVICE_ID", 0))
-        context.set_context(device_id=device_id)
+        ms.set_context(device_id=device_id)
     elif args.device_target == "GPU" and args.ms_enable_graph_kernel:
-        context.set_context(enable_graph_kernel=True)
+        ms.set_context(enable_graph_kernel=True)
     # Set Parallel
     if args.is_parallel:
         init()
         args.rank, args.rank_size, parallel_mode = get_rank(), get_group_size(), ParallelMode.DATA_PARALLEL
-        context.set_auto_parallel_context(device_num=args.rank_size, parallel_mode=parallel_mode, gradients_mean=True)
+        ms.set_auto_parallel_context(device_num=args.rank_size, parallel_mode=parallel_mode, gradients_mean=True)
     else:
         args.rank, args.rank_size = 0, 1
     # Set Default
     args.total_batch_size = args.per_batch_size * args.rank_size
-    args.sync_bn = args.sync_bn and context.get_context("device_target") == "Ascend" and args.rank_size > 1
+    args.sync_bn = args.sync_bn and ms.get_context("device_target") == "Ascend" and args.rank_size > 1
     args.accumulate = max(1, np.round(args.nbs / args.total_batch_size)) if args.auto_accumulate else args.accumulate
     # optimizer
     args.optimizer.warmup_epochs = args.optimizer.get("warmup_epochs", 0)
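Beyond the context calls, this file also moves the ParallelMode import from mindspore.context to the mindspore top level and switches set_auto_parallel_context to the ms namespace. A minimal sketch of the data-parallel branch under the new imports (assuming a multi-device launch so init() can set up the communication group; the helper name is illustrative):

import mindspore as ms
from mindspore import ParallelMode  # moved from mindspore.context in this commit
from mindspore.communication.management import get_group_size, get_rank, init


def setup_data_parallel():
    # Initialize the communication backend (HCCL on Ascend, NCCL on GPU).
    init()
    rank, rank_size = get_rank(), get_group_size()
    ms.set_auto_parallel_context(
        device_num=rank_size,
        parallel_mode=ParallelMode.DATA_PARALLEL,
        gradients_mean=True,  # average gradients across devices, as in set_default
    )
    return rank, rank_size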
test.py (12 changes: 7 additions & 5 deletions)
@@ -13,7 +13,7 @@
 from pycocotools.mask import encode

 import mindspore as ms
-from mindspore import Tensor, context, nn, ParallelMode
+from mindspore import Tensor, nn, ParallelMode
 from mindspore.communication import init, get_rank, get_group_size

 from mindyolo.data import COCO80_TO_COCO91_CLASS, COCODataset, create_loader
@@ -70,16 +70,18 @@ def get_parser_test(parents=None):

 def set_default_test(args):
     # Set Context
-    context.set_context(mode=args.ms_mode, device_target=args.device_target, max_call_depth=2000)
+    ms.set_context(mode=args.ms_mode, device_target=args.device_target, max_call_depth=2000)
+    if args.ms_mode == 0:
+        ms.set_context(jit_config={"jit_level": "O2"})
     if args.device_target == "Ascend":
-        context.set_context(device_id=int(os.getenv("DEVICE_ID", 0)))
+        ms.set_context(device_id=int(os.getenv("DEVICE_ID", 0)))
     elif args.device_target == "GPU" and args.ms_enable_graph_kernel:
-        context.set_context(enable_graph_kernel=True)
+        ms.set_context(enable_graph_kernel=True)
     # Set Parallel
     if args.is_parallel:
         init()
         args.rank, args.rank_size, parallel_mode = get_rank(), get_group_size(), ParallelMode.DATA_PARALLEL
-        context.set_auto_parallel_context(device_num=args.rank_size, parallel_mode=parallel_mode)
+        ms.set_auto_parallel_context(device_num=args.rank_size, parallel_mode=parallel_mode)
     else:
         args.rank, args.rank_size = 0, 1
     # Set Data
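Note that jit_config is passed unconditionally whenever ms_mode == 0, so after this commit the scripts assume MindSpore 2.3. If older versions still had to be supported, one hypothetical guard (not part of this commit) would gate the call on the installed version:

import mindspore as ms


def set_jit_level_if_supported():
    # Hypothetical helper, not in this commit: apply the GE-backend jit
    # level only where jit_config is understood (MindSpore >= 2.3).
    major, minor = (int(x) for x in ms.__version__.split(".")[:2])
    if (major, minor) >= (2, 3):
        ms.set_context(jit_config={"jit_level": "O2"})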
