Skip to content

support Resnet50 on devices #413

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 17 commits into
base: main
Choose a base branch
from
2 changes: 0 additions & 2 deletions Vision/classification/image/resnet50/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,6 @@ bash examples/train_graph_distributed_fp16.sh
Train resnet50 with graph mode.
--use-fp16
Whether to enable amp training.
--use-gpu-decode
Use gpu to decode the data packed in ofrecord, only supported in graph mode.
--scale-grad
Whether to scale gradient when training in fp32 with graph mode.
--skip-eval
Expand Down
30 changes: 24 additions & 6 deletions Vision/classification/image/resnet50/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,14 @@ def parse_args(ignore_unknown_args=False):
parser = argparse.ArgumentParser(
description="OneFlow ResNet50 Arguments", allow_abbrev=False
)
parser.add_argument("--device", type=str, default="cuda", help="device: cpu, cuda...")
parser.add_argument(
"--data-loading-device",
type=str,
default="cuda",
choices=["cpu", "cuda"],
help="Specify the device for data loading: 'cpu' or 'cuda' (default: 'cuda')."
)
parser.add_argument(
"--save",
type=str,
Expand Down Expand Up @@ -60,12 +68,6 @@ def parse_args(ignore_unknown_args=False):
dest="ofrecord_part_num",
help="ofrecord data part number",
)
parser.add_argument(
"--use-gpu-decode",
action="store_true",
dest="use_gpu_decode",
help="Use gpu decode.",
)
parser.add_argument(
"--synthetic-data",
action="store_true",
Expand All @@ -86,6 +88,22 @@ def parse_args(ignore_unknown_args=False):
dest="fuse_bn_add_relu",
help="Whether to use use fuse batch_normalization, add and relu.",
)
parser.add_argument(
"--disable-fuse-add-to-output",
action="store_false",
dest="fuse_add_to_output",
help="Disable fusion of the add operation into the output (enabled by default). \n"
"For more details, see `graph_config.py` in the OneFlow repository: \n"
"https://github.com/Oneflow-Inc/oneflow",
)
parser.add_argument(
"--disable-fuse-model-update-ops",
action="store_false",
dest="fuse_model_update_ops",
help="Disable fusion of the model update operations (enabled by default). \n"
"For more details, see `graph_config.py` in the OneFlow repository: \n"
"https://github.com/Oneflow-Inc/oneflow",
)

# training hyper-parameters
parser.add_argument(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,5 @@ python3 $SRC_DIR/train.py \
--save $CHECKPOINT_SAVE_PATH \
--samples-per-epoch 50 \
--val-samples-per-epoch 50 \
--use-gpu-decode \
--scale-grad \
--graph \
Original file line number Diff line number Diff line change
Expand Up @@ -50,5 +50,4 @@ python3 -m oneflow.distributed.launch \
--metric-train-acc True \
--fuse-bn-relu \
--fuse-bn-add-relu \
--use-gpu-decode \
--channel-last \
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@ python3 -m oneflow.distributed.launch \
--num-epochs $EPOCH \
--train-batch-size $TRAIN_BATCH_SIZE \
--val-batch-size $VAL_BATCH_SIZE \
--use-gpu-decode \
--scale-grad \
--graph \
--fuse-bn-relu \
Expand Down
16 changes: 9 additions & 7 deletions Vision/classification/image/resnet50/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@ def __init__(
elif args.scale_grad:
self.set_grad_scaler(make_static_grad_scaler())

self.config.allow_fuse_add_to_output(True)
self.config.allow_fuse_model_update_ops(True)
self.config.allow_fuse_add_to_output(args.fuse_add_to_output)
self.config.allow_fuse_model_update_ops(args.fuse_model_update_ops)

# Disable cudnn_conv_heuristic_search_algo will open dry-run.
# Dry-run is better with single device, but has no effect with multiple device.
Expand All @@ -51,11 +51,12 @@ def __init__(
self.cross_entropy = cross_entropy
self.data_loader = data_loader
self.add_optimizer(optimizer, lr_sch=lr_scheduler)
self.device = args.device

def build(self):
image, label = self.data_loader()
image = image.to("cuda")
label = label.to("cuda")
image = image.to(self.device)
label = label.to(self.device)
logits = self.model(image)
loss = self.cross_entropy(logits, label)
if self.return_pred_and_label:
Expand All @@ -75,15 +76,16 @@ def __init__(self, model, data_loader):
if args.use_fp16:
self.config.enable_amp(True)

self.config.allow_fuse_add_to_output(True)
self.config.allow_fuse_add_to_output(args.fuse_add_to_output)

self.data_loader = data_loader
self.model = model
self.device = args.device

def build(self):
image, label = self.data_loader()
image = image.to("cuda")
label = label.to("cuda")
image = image.to(self.device)
label = label.to(self.device)
logits = self.model(image)
pred = logits.softmax()
return pred, label
10 changes: 8 additions & 2 deletions Vision/classification/image/resnet50/infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@ def _parse_args():
dest="image_path",
help="input image path",
)
parser.add_argument(
"--device", type=str, default="cuda", choices=["cuda", "cpu", "npu"], help="device"
)
parser.add_argument("--graph", action="store_true", help="Run model in graph mode.")
return parser.parse_args()

Expand All @@ -52,10 +55,13 @@ def build(self, image):
def main(args):
start_t = time.perf_counter()

if args.device == "npu":
import oneflow_npu

print("***** Model Init *****")
model = resnet50()
model.load_state_dict(flow.load(args.model_path))
model = model.to("cuda")
model = model.to(args.device)
model.eval()
end_t = time.perf_counter()
print(f"***** Model Init Finish, time escapled {end_t - start_t:.6f} s *****")
Expand All @@ -65,7 +71,7 @@ def main(args):

start_t = end_t
image = load_image(args.image_path)
image = flow.Tensor(image, device=flow.device("cuda"))
image = flow.Tensor(image, device=flow.device(args.device))
if args.graph:
pred = model_graph(image)
else:
Expand Down
22 changes: 12 additions & 10 deletions Vision/classification/image/resnet50/models/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,9 @@ def make_data_loader(args, mode, is_global=False, synthetic=False):
placement=placement,
sbp=sbp,
channel_last=args.channel_last,
device=args.device,
)
return data_loader.to("cuda")
return data_loader.to(args.device)

ofrecord_data_loader = OFRecordDataLoader(
ofrecord_dir=args.ofrecord_path,
Expand All @@ -44,7 +45,7 @@ def make_data_loader(args, mode, is_global=False, synthetic=False):
channel_last=args.channel_last,
placement=placement,
sbp=sbp,
use_gpu_decode=args.use_gpu_decode,
device=args.data_loading_device,
)
return ofrecord_data_loader

Expand All @@ -61,7 +62,7 @@ def __init__(
channel_last=False,
placement=None,
sbp=None,
use_gpu_decode=False,
device="cuda",
):
super().__init__()

Expand All @@ -71,6 +72,7 @@ def __init__(
self.total_batch_size = total_batch_size
self.dataset_size = dataset_size
self.mode = mode
self.device = device

random_shuffle = True if mode == "train" else False
shuffle_after_epoch = True if mode == "train" else False
Expand Down Expand Up @@ -101,9 +103,8 @@ def __init__(
rgb_mean = [123.68, 116.779, 103.939]
rgb_std = [58.393, 57.12, 57.375]

self.use_gpu_decode = use_gpu_decode
if self.mode == "train":
if self.use_gpu_decode:
if self.device == "cuda":
self.bytesdecoder_img = flow.nn.OFRecordBytesDecoder("encoded")
self.image_decoder = flow.nn.OFRecordImageGpuDecoderRandomCropResize(
target_width=image_width,
Expand Down Expand Up @@ -153,17 +154,17 @@ def __len__(self):
def forward(self):
if self.mode == "train":
record = self.ofrecord_reader()
if self.use_gpu_decode:
if self.device == "cuda":
encoded = self.bytesdecoder_img(record)
image = self.image_decoder(encoded)
else:
image_raw_bytes = self.image_decoder(record)
image = self.resize(image_raw_bytes)[0]
image = image.to("cuda")

label = self.label_decoder(record)
flip_code = self.flip()
flip_code = flip_code.to("cuda")
if self.device == "cuda":
flip_code = flip_code.to(self.device)
image = self.crop_mirror_norm(image, flip_code)
else:
record = self.ofrecord_reader()
Expand All @@ -184,6 +185,7 @@ def __init__(
placement=None,
sbp=None,
channel_last=False,
device="cuda",
):
super().__init__()

Expand Down Expand Up @@ -220,10 +222,10 @@ def __init__(
)
else:
self.image = flow.randint(
0, high=256, size=self.image_shape, dtype=flow.float32, device="cuda"
0, high=256, size=self.image_shape, dtype=flow.float32, device=device,
)
self.label = flow.randint(
0, high=self.num_classes, size=self.label_shape, device="cuda",
0, high=self.num_classes, size=self.label_shape, device=device,
).to(dtype=flow.int32)

def forward(self):
Expand Down
3 changes: 2 additions & 1 deletion Vision/classification/image/resnet50/models/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,5 +83,6 @@ def forward(self, input, label):
# log_prob = input.softmax(dim=-1).log()
# onehot_label = flow.F.cast(onehot_label, log_prob.dtype)
# loss = flow.mul(log_prob * -1, onehot_label).sum(dim=-1).mean()
loss = flow._C.softmax_cross_entropy(input, onehot_label.to(dtype=input.dtype))
#loss = flow._C.softmax_cross_entropy(input, onehot_label.to(dtype=input.dtype))
loss = flow._C.cross_entropy(input, onehot_label.to(dtype=input.dtype), reduction='none')
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

loss这块确定要改吗😂(cross_entropy内部貌似包含2个oplog_softmaxnll,可能效率不一定有softmax_cross_entropy好)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

临时改的,为了能跑通,等志鹏那个开发好了,就改回来。

return loss.mean()
20 changes: 13 additions & 7 deletions Vision/classification/image/resnet50/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,12 @@
class Trainer(object):
def __init__(self):
args = get_args()
self.device = args.device.lower()
if self.device == "npu":
import oneflow_npu
elif self.device == "xpu":
import oneflow_xpu

for k, v in args.__dict__.items():
setattr(self, k, v)

Expand Down Expand Up @@ -89,12 +95,12 @@ def init_model(self):
start_t = time.perf_counter()

if self.is_global:
placement = flow.env.all_device_placement("cuda")
placement = flow.env.all_device_placement(self.device)
self.model = self.model.to_global(
placement=placement, sbp=flow.sbp.broadcast
)
else:
self.model = self.model.to("cuda")
self.model = self.model.to(self.device)

if self.load_path is None:
self.legacy_init_parameters()
Expand Down Expand Up @@ -276,7 +282,7 @@ def train_eager(self):
param.grad /= self.world_size
else:
loss.backward()
loss = loss / self.world_size
#loss = loss / self.world_size

self.optimizer.step()
self.optimizer.zero_grad()
Expand Down Expand Up @@ -311,8 +317,8 @@ def eval(self):

def forward(self):
image, label = self.train_data_loader()
image = image.to("cuda")
label = label.to("cuda")
image = image.to(self.device)
label = label.to(self.device)
logits = self.model(image)
loss = self.cross_entropy(logits, label)
if self.metric_train_acc:
Expand All @@ -323,8 +329,8 @@ def forward(self):

def inference(self):
image, label = self.val_data_loader()
image = image.to("cuda")
label = label.to("cuda")
image = image.to(self.device)
label = label.to(self.device)
with flow.no_grad():
logits = self.model(image)
pred = logits.softmax()
Expand Down