Skip to content

Commit

Permalink
Merge pull request #299 from mindspore-lab/revert-296-master
Browse files Browse the repository at this point in the history
Revert "use GE backend for graph mode"
  • Loading branch information
CaitinZhao authored Jun 18, 2024
2 parents e69d1ef + 77b0719 commit 1ba1ba8
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 16 deletions.
3 changes: 0 additions & 3 deletions demo/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@

import mindspore as ms
from mindspore import Tensor, context, nn
from mindspore._c_expression import ms_ctx_param

from mindyolo.data import COCO80_TO_COCO91_CLASS
from mindyolo.models import create_model
Expand Down Expand Up @@ -54,8 +53,6 @@ def get_parser_infer(parents=None):
def set_default_infer(args):
# Set Context
context.set_context(mode=args.ms_mode, device_target=args.device_target, max_call_depth=2000)
if "jit_config" in ms_ctx_param.__members__ and args.mode == 0:
ms.set_context(jit_config={"jit_level": "O2"})
if args.device_target == "Ascend":
context.set_context(device_id=int(os.getenv("DEVICE_ID", 0)))
elif args.device_target == "GPU" and args.ms_enable_graph_kernel:
Expand Down
17 changes: 7 additions & 10 deletions mindyolo/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,9 @@
import numpy as np

import mindspore as ms
from mindspore import ops, Tensor, nn
from mindspore import context, ops, Tensor, nn
from mindspore.communication.management import get_group_size, get_rank, init
from mindspore import ParallelMode
from mindspore._c_expression import ms_ctx_param
from mindspore.context import ParallelMode

from mindyolo.utils import logger

Expand All @@ -22,24 +21,22 @@ def set_seed(seed=2):

def set_default(args):
# Set Context
ms.set_context(mode=args.ms_mode, device_target=args.device_target, max_call_depth=2000)
if "jit_config" in ms_ctx_param.__members__ and args.mode == 0:
ms.set_context(jit_config={"jit_level": "O2"})
context.set_context(mode=args.ms_mode, device_target=args.device_target, max_call_depth=2000)
if args.device_target == "Ascend":
device_id = int(os.getenv("DEVICE_ID", 0))
ms.set_context(device_id=device_id)
context.set_context(device_id=device_id)
elif args.device_target == "GPU" and args.ms_enable_graph_kernel:
ms.set_context(enable_graph_kernel=True)
context.set_context(enable_graph_kernel=True)
# Set Parallel
if args.is_parallel:
init()
args.rank, args.rank_size, parallel_mode = get_rank(), get_group_size(), ParallelMode.DATA_PARALLEL
ms.set_auto_parallel_context(device_num=args.rank_size, parallel_mode=parallel_mode, gradients_mean=True)
context.set_auto_parallel_context(device_num=args.rank_size, parallel_mode=parallel_mode, gradients_mean=True)
else:
args.rank, args.rank_size = 0, 1
# Set Default
args.total_batch_size = args.per_batch_size * args.rank_size
args.sync_bn = args.sync_bn and ms.get_context("device_target") == "Ascend" and args.rank_size > 1
args.sync_bn = args.sync_bn and context.get_context("device_target") == "Ascend" and args.rank_size > 1
args.accumulate = max(1, np.round(args.nbs / args.total_batch_size)) if args.auto_accumulate else args.accumulate
# optimizer
args.optimizer.warmup_epochs = args.optimizer.get("warmup_epochs", 0)
Expand Down
3 changes: 0 additions & 3 deletions test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
import mindspore as ms
from mindspore import Tensor, context, nn, ParallelMode
from mindspore.communication import init, get_rank, get_group_size
from mindspore._c_expression import ms_ctx_param

from mindyolo.data import COCO80_TO_COCO91_CLASS, COCODataset, create_loader
from mindyolo.models.model_factory import create_model
Expand Down Expand Up @@ -72,8 +71,6 @@ def get_parser_test(parents=None):
def set_default_test(args):
# Set Context
context.set_context(mode=args.ms_mode, device_target=args.device_target, max_call_depth=2000)
if "jit_config" in ms_ctx_param.__members__ and args.mode == 0:
ms.set_context(jit_config={"jit_level": "O2"})
if args.device_target == "Ascend":
context.set_context(device_id=int(os.getenv("DEVICE_ID", 0)))
elif args.device_target == "GPU" and args.ms_enable_graph_kernel:
Expand Down

0 comments on commit 1ba1ba8

Please sign in to comment.