
Implement ConditionallySpeculatable for ConvolutionOp #2228

Merged 4 commits from conv-cs into openxla:main on Apr 17, 2024

Conversation

mlevesquedion
Contributor

Relevant constraints from the spec:

```
(C10) dim(lhs, input_batch_dimension) % batch_group_count = 0.
(C11) dim(lhs, input_feature_dimension) % feature_group_count = 0.

(C14) dim(rhs, kernel_input_feature_dimension) = dim(lhs, input_feature_dimension) / feature_group_count.
(C15) dim(rhs, kernel_output_feature_dimension) % batch_group_count = 0.
(C16) dim(rhs, kernel_output_feature_dimension) % feature_group_count = 0.

(C25) dim(result, result_dim) is defined as:
* dim(lhs, input_batch_dimension) / batch_group_count if result_dim = output_batch_dimension.
* dim(rhs, kernel_output_feature_dimension) if result_dim = output_feature_dimension.
* num_windows otherwise, where:
  * output_spatial_dimensions[spatial_dim] = result_dim.
  * lhs_dim = input_spatial_dimensions[spatial_dim].
  * rhs_dim = kernel_spatial_dimensions[spatial_dim].
  * dilated_input_shape[lhs_dim] = dim(lhs, lhs_dim) = 0 ? 0 : (dim(lhs, lhs_dim) - 1) * lhs_dilation[spatial_dim] + 1.
  * padded_input_shape[lhs_dim] = padding[spatial_dim, 0] + dilated_input_shape[lhs_dim] + padding[spatial_dim, 1].
  * dilated_window_shape[lhs_dim] = dim(rhs, rhs_dim) = 0 ? 0 : (dim(rhs, rhs_dim) - 1) * rhs_dilation[spatial_dim] + 1.
  * is_empty_window[lhs_dim] = padded_input_shape[lhs_dim] = 0 || dilated_window_shape[lhs_dim] > padded_input_shape[lhs_dim].
  * num_windows = is_empty_window[lhs_dim] ? 0 : floor((padded_input_shape[lhs_dim] - dilated_window_shape[lhs_dim]) / window_strides[spatial_dim]) + 1.
```

Because of (C14), input_feature_dimension and kernel_input_feature_dimension must be static. input_batch_dimension must be static if batch_group_count > 1 (C10) or if output_batch_dimension is static (C25, first bullet). kernel_output_feature_dimension must be static if batch_group_count > 1 (C15) or feature_group_count > 1 (C16) or if output_feature_dimension is static (C25, second bullet).
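The static-dimension requirements above could be sketched roughly as follows. This is a hypothetical Python illustration, not the actual StableHLO implementation (which lives in C++ in StablehloOps.cpp); all parameter names are invented for clarity:

```python
# Hypothetical sketch of the static-dimension requirements derived from
# (C10), (C14), (C15), (C16), and (C25). Each *_static flag says whether
# the corresponding dimension has a known (non-dynamic) size.
def conv_static_dim_requirements_ok(
    input_batch_static: bool,
    input_feature_static: bool,
    kernel_input_feature_static: bool,
    kernel_output_feature_static: bool,
    output_batch_static: bool,
    output_feature_static: bool,
    batch_group_count: int,
    feature_group_count: int,
) -> bool:
    # (C14): both feature dimensions must be static so the relation
    # between them can be checked before execution.
    if not (input_feature_static and kernel_input_feature_static):
        return False
    # (C10) and (C25, first bullet): input_batch_dimension must be static
    # when batch_group_count > 1 or when the output batch dim is static.
    if (batch_group_count > 1 or output_batch_static) and not input_batch_static:
        return False
    # (C15), (C16), and (C25, second bullet): kernel_output_feature_dimension
    # must be static under any of these conditions.
    if ((batch_group_count > 1 or feature_group_count > 1 or output_feature_static)
            and not kernel_output_feature_static):
        return False
    return True
```

For example, with all dimensions static and both group counts equal to 1, the check passes; making input_feature_dimension dynamic fails it immediately because of (C14).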

Because of (C25), each spatial dimension in the output can depend on the spatial dimensions in the inputs (input + kernel), so if a spatial dimension is static in the output, the corresponding spatial dimensions in both inputs must also be static; otherwise mismatches could occur at runtime.
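The num_windows computation from (C25) can be written out as a small Python helper (a sketch for illustration only; names are invented, and the spec's `cond ? a : b` expressions are spelled out as conditionals):

```python
def output_spatial_dim_size(lhs_dim_size: int, rhs_dim_size: int,
                            lhs_dilation: int, rhs_dilation: int,
                            padding_lo: int, padding_hi: int,
                            stride: int) -> int:
    """Sketch of the num_windows formula from (C25) for one spatial dim."""
    # dilated_input_shape[lhs_dim]
    dilated_input = 0 if lhs_dim_size == 0 else (lhs_dim_size - 1) * lhs_dilation + 1
    # padded_input_shape[lhs_dim]
    padded_input = padding_lo + dilated_input + padding_hi
    # dilated_window_shape[lhs_dim]
    dilated_window = 0 if rhs_dim_size == 0 else (rhs_dim_size - 1) * rhs_dilation + 1
    # is_empty_window[lhs_dim]
    is_empty = padded_input == 0 or dilated_window > padded_input
    # num_windows
    return 0 if is_empty else (padded_input - dilated_window) // stride + 1
```

For instance, a size-5 input convolved with a size-3 kernel (no dilation, no padding, stride 1) yields 3 output positions, which matches the usual valid-convolution size formula. A dynamic size on either input side makes this value uncomputable ahead of time, which is why a static output spatial dimension forces the corresponding input dimensions to be static.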

@mlevesquedion mlevesquedion marked this pull request as draft April 17, 2024 17:59
@mlevesquedion
Contributor Author

Converting to draft while I fix conflicts.

@mlevesquedion mlevesquedion marked this pull request as ready for review April 17, 2024 18:01
@mlevesquedion
Contributor Author

Fixed the conflicts, this is ready for review.

Review comment on stablehlo/dialect/StablehloOps.cpp (outdated, resolved).
@mlevesquedion mlevesquedion merged commit 19e3142 into openxla:main Apr 17, 2024
10 checks passed
@mlevesquedion mlevesquedion deleted the conv-cs branch April 17, 2024 19:03