Skip to content

Commit

Permalink
Change to support zero length tensor specifically in format_transform.
Browse files (browse the repository at this point in the history)
  • Loading branch information
liuliu committed Nov 7, 2023
1 parent e921d4d commit 8a11643
Show file tree
Hide file tree
Showing 5 changed files with 18 additions and 10 deletions.
12 changes: 6 additions & 6 deletions lib/nnc/ccv_cnnp_model_addons.c
Original file line number Diff line number Diff line change
Expand Up @@ -396,21 +396,21 @@ static void _ccv_cnnp_concat_build(ccv_cnnp_model_t* const super, ccv_nnc_symbol
// Format transform is more flexible.
ccv_nnc_graph_exec_symbol_new(graph, CMD_FORMAT_TRANSFORM_FORWARD(), inputs, input_size, aliases, input_size, "concat");
} else {
-			int contiguous_input_size = 0;
-			ccv_nnc_tensor_symbol_t contiguous_inputs[input_size];
 			ccv_nnc_tensor_symbol_t aliases[input_size];
 			for (i = 0; i < input_size; i++)
 			{
 				const ccv_nnc_tensor_param_t input_params = ccv_nnc_tensor_symbol_params(graph, inputs[i]);
 				if (input_params.dim[0] == 0)
+				{
+					// Create a new alias anyway, but not going to use it, in this way, the alias count will match during absorb.
+					aliases[i] = ccv_nnc_tensor_symbol_alias_new(graph, outputs[0], ofs, stride, input_params, 0);
 					continue;
-				contiguous_inputs[contiguous_input_size] = inputs[i];
-				aliases[contiguous_input_size] = ccv_nnc_tensor_symbol_alias_new(graph, outputs[0], ofs, stride, input_params, 0);
+				}
+				aliases[i] = ccv_nnc_tensor_symbol_alias_new(graph, outputs[0], ofs, stride, input_params, 0);
 				ofs[axis] += input_params.dim[axis];
-				++contiguous_input_size;
 			}
 			// Format transform is more flexible.
-			ccv_nnc_graph_exec_symbol_new(graph, CMD_FORMAT_TRANSFORM_FORWARD(), contiguous_inputs, contiguous_input_size, aliases, contiguous_input_size, "concat");
+			ccv_nnc_graph_exec_symbol_new(graph, CMD_FORMAT_TRANSFORM_FORWARD(), inputs, input_size, aliases, input_size, "concat");
}
}

Expand Down
10 changes: 6 additions & 4 deletions lib/nnc/ccv_nnc_dynamic_graph.c
Original file line number Diff line number Diff line change
Expand Up @@ -317,8 +317,9 @@ ccv_nnc_tensor_t* ccv_nnc_tensor_from_variable_impl(ccv_nnc_dynamic_graph_t* con
if (ccv_nnc_is_tensor_auto(tensor_variable->info))
return 0;
void* ptr = 0;
-	if (CCV_TENSOR_GET_MEMORY(tensor_variable->info.type) == CCV_TENSOR_GPU_MEMORY)
-		ptr = ccv_nnc_xpu_alloc(&graph->xpu_alloc, CCV_TENSOR_GET_DEVICE_ID(tensor_variable->info.type), stream_context, ccv_nnc_tensor_data_size(tensor_variable->info));
+	const size_t data_size = ccv_nnc_tensor_data_size(tensor_variable->info);
+	if (CCV_TENSOR_GET_MEMORY(tensor_variable->info.type) == CCV_TENSOR_GPU_MEMORY && data_size > 0)
+		ptr = ccv_nnc_xpu_alloc(&graph->xpu_alloc, CCV_TENSOR_GET_DEVICE_ID(tensor_variable->info.type), stream_context, data_size);
tensor_variable->tensor_view = (ccv_nnc_tensor_view_t*)ccv_nnc_tensor_new(ptr, tensor_variable->info, 0);
if (tensor_variable->info.dim[0] > 0)
{ assert(tensor_variable->tensor_view->data.u8); }
Expand All @@ -335,8 +336,9 @@ ccv_nnc_tensor_t* ccv_nnc_tensor_from_variable_impl(ccv_nnc_dynamic_graph_t* con
return 0;
void* ptr = 0;
assert(variable_to->info.type == tensor_variable->info.type);
-	if (CCV_TENSOR_GET_MEMORY(variable_to->info.type) == CCV_TENSOR_GPU_MEMORY)
-		ptr = ccv_nnc_xpu_alloc(&graph->xpu_alloc, CCV_TENSOR_GET_DEVICE_ID(variable_to->info.type), stream_context, ccv_nnc_tensor_data_size(variable_to->info));
+	const size_t data_size = ccv_nnc_tensor_data_size(variable_to->info);
+	if (CCV_TENSOR_GET_MEMORY(variable_to->info.type) == CCV_TENSOR_GPU_MEMORY && data_size > 0)
+		ptr = ccv_nnc_xpu_alloc(&graph->xpu_alloc, CCV_TENSOR_GET_DEVICE_ID(variable_to->info.type), stream_context, data_size);
variable_to->tensor_view = (ccv_nnc_tensor_view_t*)ccv_nnc_tensor_new(ptr, variable_to->info, 0);
assert(variable_to->tensor_view->data.u8);
}
Expand Down
2 changes: 2 additions & 0 deletions lib/nnc/cmd/util/ccv_nnc_util_cpu_ref.c
Original file line number Diff line number Diff line change
Expand Up @@ -1003,6 +1003,8 @@ static int _ccv_nnc_format_transform(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint
ccv_nnc_tensor_view_t* b = (ccv_nnc_tensor_view_t*)outputs[i];
assert(a != b); // Cannot do inplace transform.
assert(a->info.datatype == b->info.datatype);
+		if (a->info.dim[0] == 0 || b->info.dim[0] == 0)
+			continue;
if (a->info.datatype == CCV_32F || a->info.datatype == CCV_32S)
{
if (a->info.format == b->info.format) {
Expand Down
2 changes: 2 additions & 0 deletions lib/nnc/cmd/util/gpu/ccv_nnc_util_gpu_cudnn.cu
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ static int _ccv_nnc_format_transform(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint
cudnnHandle_t cudnn = ccv_nnc_stream_context_get_cudnn(stream_context);
for (i = 0; i < output_size; i++)
{
+		if (inputs[i]->info.dim[0] == 0 || outputs[i]->info.dim[0] == 0)
+			continue;
const ccv_nnc_cudnn_tensor_view_descriptor_t a = ccv_nnc_cudnn_get_tensor_view_descriptor(stream_context, (const ccv_nnc_tensor_view_t*)inputs[i]);
const ccv_nnc_cudnn_tensor_view_descriptor_t b = ccv_nnc_cudnn_get_tensor_view_descriptor(stream_context, (const ccv_nnc_tensor_view_t*)outputs[i]);
assert(inputs[i]->info.datatype == outputs[i]->info.datatype);
Expand Down
2 changes: 2 additions & 0 deletions lib/nnc/cmd/util/mps/ccv_nnc_util_mps.m
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,8 @@ static int _ccv_nnc_format_transform(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint
memcpy(bt.info.dim, bdim, sizeof(bdim));
memcpy(bt.stride, bstride, sizeof(bstride));
}
+			if (adim[0] == 0 || bdim[0] == 0)
+				continue;
MPSGraphTensor* mps_input_a;
MPSGraphTensor* mps_a = ccv_nnc_mps_graph_tensor_input(graph, a, adim, astride, &mps_input_a);
MPSGraphTensorData* data_a = ccv_nnc_mps_graph_tensor_data(a, adim, astride);
Expand Down

0 comments on commit 8a11643

Please sign in to comment.