Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove some allow(unsafe_op_in_unsafe_fn)s and use target_feature 1.1 in examples #1727

Merged
merged 1 commit into from
Feb 25, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion crates/std_detect/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
#![feature(staged_api, doc_cfg, allow_internal_unstable)]
#![deny(rust_2018_idioms)]
#![allow(clippy::shadow_reuse)]
#![allow(unsafe_op_in_unsafe_fn)]
#![cfg_attr(test, allow(unused_imports))]
#![no_std]
#![allow(internal_features)]
Expand Down
94 changes: 54 additions & 40 deletions examples/connect5.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
//! each move.

#![allow(internal_features)]
#![allow(unsafe_op_in_unsafe_fn)]
#![feature(avx512_target_feature)]
#![cfg_attr(target_arch = "x86", feature(stdarch_x86_avx512, stdarch_internal))]
#![cfg_attr(target_arch = "x86_64", feature(stdarch_x86_avx512, stdarch_internal))]
Expand Down Expand Up @@ -419,12 +418,12 @@ fn pos_is_draw(pos: &Pos) -> bool {
found && !pos_is_winner(pos)
}

#[target_feature(enable = "avx512f,avx512bw")]
#[target_feature(enable = "avx512f,avx512bw,popcnt")]
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
unsafe fn pos_is_draw_avx512(pos: &Pos) -> bool {
fn pos_is_draw_avx512(pos: &Pos) -> bool {
let empty = Color::Empty as usize;

let board0org = _mm512_loadu_epi32(&pos.bitboard[empty][0][0]);
let board0org = unsafe { _mm512_loadu_epi32(&pos.bitboard[empty][0][0]) };

let answer = _mm512_set1_epi32(0);

Expand Down Expand Up @@ -481,7 +480,7 @@ fn search(pos: &Pos, alpha: i32, beta: i32, depth: i32, _ply: i32) -> i32 {

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
{
if is_x86_feature_detected!("avx512bw") {
if check_x86_avx512_features() {
unsafe {
if pos_is_winner_avx512(pos) {
return -EVAL_INF + _ply;
Expand Down Expand Up @@ -571,7 +570,7 @@ fn eval(pos: &Pos, _ply: i32) -> i32 {
// check if opp has live4 which will win playing next move
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
{
if is_x86_feature_detected!("avx512bw") {
if check_x86_avx512_features() {
unsafe {
if check_patternlive4_avx512(pos, def) {
return -4096;
Expand All @@ -594,7 +593,7 @@ fn eval(pos: &Pos, _ply: i32) -> i32 {
// check if self has live4 which will win playing next move
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
{
if is_x86_feature_detected!("avx512bw") {
if check_x86_avx512_features() {
unsafe {
if check_patternlive4_avx512(pos, atk) {
return 2560;
Expand All @@ -617,7 +616,7 @@ fn eval(pos: &Pos, _ply: i32) -> i32 {
// check if self has dead4 which will win playing next move
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
{
if is_x86_feature_detected!("avx512bw") {
if check_x86_avx512_features() {
unsafe {
if check_patterndead4_avx512(pos, atk) > 0 {
return 2560;
Expand All @@ -639,7 +638,7 @@ fn eval(pos: &Pos, _ply: i32) -> i32 {

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
{
if is_x86_feature_detected!("avx512bw") {
if check_x86_avx512_features() {
unsafe {
let n_c4: i32 = check_patterndead4_avx512(pos, def);
let n_c3: i32 = check_patternlive3_avx512(pos, def);
Expand Down Expand Up @@ -854,16 +853,18 @@ fn check_patternlive3(pos: &Pos, sd: Side) -> i32 {
n
}

#[target_feature(enable = "avx512f,avx512bw")]
#[target_feature(enable = "avx512f,avx512bw,popcnt")]
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
unsafe fn pos_is_winner_avx512(pos: &Pos) -> bool {
fn pos_is_winner_avx512(pos: &Pos) -> bool {
let current_side = side_opp(pos.p_turn);
let coloridx = current_side as usize;

let board0org: [__m512i; 2] = [
_mm512_loadu_epi32(&pos.bitboard[coloridx][0][0]),
_mm512_loadu_epi32(&pos.bitboard[coloridx][1][0]),
]; // load states from bitboard
let board0org: [__m512i; 2] = unsafe {
[
_mm512_loadu_epi32(&pos.bitboard[coloridx][0][0]),
_mm512_loadu_epi32(&pos.bitboard[coloridx][1][0]),
]
}; // load states from bitboard

#[rustfmt::skip]
let answer = _mm512_set1_epi16((1<<15)|(1<<14)|(1<<13)|(1<<12)|(1<<11)); // an unbroken chain of five moves
Expand Down Expand Up @@ -928,9 +929,9 @@ unsafe fn pos_is_winner_avx512(pos: &Pos) -> bool {
count_match > 0
}

#[target_feature(enable = "avx512f,avx512bw")]
#[target_feature(enable = "avx512f,avx512bw,popcnt")]
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
unsafe fn check_patternlive4_avx512(pos: &Pos, sd: Side) -> bool {
fn check_patternlive4_avx512(pos: &Pos, sd: Side) -> bool {
let coloridx = sd as usize;
let emptyidx = Color::Empty as usize;

Expand All @@ -952,14 +953,18 @@ unsafe fn check_patternlive4_avx512(pos: &Pos, sd: Side) -> bool {
0b00_10_10_11_11_11_11_11_10_10_10_10_10_11_11_10,
0b00_10_10_10_11_11_11_10_10_10_10_10_11_11_11_10,
0b00_10_10_10_10_11_10_10_10_10_10_11_11_11_11_10];
let board0org: [__m512i; 2] = [
_mm512_loadu_epi32(&pos.bitboard[coloridx][0][0]),
_mm512_loadu_epi32(&pos.bitboard[coloridx][1][0]),
];
let board1org: [__m512i; 2] = [
_mm512_loadu_epi32(&pos.bitboard[emptyidx][0][0]),
_mm512_loadu_epi32(&pos.bitboard[emptyidx][1][0]),
];
let board0org: [__m512i; 2] = unsafe {
[
_mm512_loadu_epi32(&pos.bitboard[coloridx][0][0]),
_mm512_loadu_epi32(&pos.bitboard[coloridx][1][0]),
]
};
let board1org: [__m512i; 2] = unsafe {
[
_mm512_loadu_epi32(&pos.bitboard[emptyidx][0][0]),
_mm512_loadu_epi32(&pos.bitboard[emptyidx][1][0]),
]
};

let mut count_match: i32 = 0;

Expand Down Expand Up @@ -990,9 +995,9 @@ unsafe fn check_patternlive4_avx512(pos: &Pos, sd: Side) -> bool {
count_match > 0
}

#[target_feature(enable = "avx512f,avx512bw")]
#[target_feature(enable = "avx512f,avx512bw,popcnt")]
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
unsafe fn check_patterndead4_avx512(pos: &Pos, sd: Side) -> i32 {
fn check_patterndead4_avx512(pos: &Pos, sd: Side) -> i32 {
let coloridx = sd as usize;
let emptyidx = Color::Empty as usize;

Expand Down Expand Up @@ -1023,14 +1028,18 @@ unsafe fn check_patterndead4_avx512(pos: &Pos, sd: Side) -> i32 {
0b00_10_10_11_11_11_11_11_10_10_10_10_11_11_11_10,
0b00_10_10_10_11_11_11_10_10_10_10_11_11_11_11_10,
0b00_10_10_10_10_11_10_10_10_10_11_11_11_11_11_10];
let board0org: [__m512i; 2] = [
_mm512_loadu_epi32(&pos.bitboard[coloridx][0][0]),
_mm512_loadu_epi32(&pos.bitboard[coloridx][1][0]),
];
let board1org: [__m512i; 2] = [
_mm512_loadu_epi32(&pos.bitboard[emptyidx][0][0]),
_mm512_loadu_epi32(&pos.bitboard[emptyidx][1][0]),
];
let board0org: [__m512i; 2] = unsafe {
[
_mm512_loadu_epi32(&pos.bitboard[coloridx][0][0]),
_mm512_loadu_epi32(&pos.bitboard[coloridx][1][0]),
]
};
let board1org: [__m512i; 2] = unsafe {
[
_mm512_loadu_epi32(&pos.bitboard[emptyidx][0][0]),
_mm512_loadu_epi32(&pos.bitboard[emptyidx][1][0]),
]
};

let mut count_match: i32 = 0;

Expand Down Expand Up @@ -1063,16 +1072,16 @@ unsafe fn check_patterndead4_avx512(pos: &Pos, sd: Side) -> i32 {
count_match
}

#[target_feature(enable = "avx512f,avx512bw")]
#[target_feature(enable = "avx512f,avx512bw,popcnt")]
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
unsafe fn check_patternlive3_avx512(pos: &Pos, sd: Side) -> i32 {
fn check_patternlive3_avx512(pos: &Pos, sd: Side) -> i32 {
let coloridx = sd as usize;
let emptyidx = Color::Empty as usize;

#[rustfmt::skip]
let board0org: [__m512i; 2] = [_mm512_loadu_epi32(&pos.bitboard[coloridx][0][0]), _mm512_loadu_epi32(&pos.bitboard[coloridx][1][0])];
let board0org: [__m512i; 2] = unsafe { [_mm512_loadu_epi32(&pos.bitboard[coloridx][0][0]), _mm512_loadu_epi32(&pos.bitboard[coloridx][1][0])] };
#[rustfmt::skip]
let board1org: [__m512i; 2] = [_mm512_loadu_epi32(&pos.bitboard[emptyidx][0][0]), _mm512_loadu_epi32(&pos.bitboard[emptyidx][1][0])];
let board1org: [__m512i; 2] = unsafe { [_mm512_loadu_epi32(&pos.bitboard[emptyidx][0][0]), _mm512_loadu_epi32(&pos.bitboard[emptyidx][1][0])] };

#[rustfmt::skip]
let answer_color: [__m512i; 1] = [_mm512_set1_epi16( (1<<14)|(1<<13)|(1<<12) )];
Expand Down Expand Up @@ -1170,10 +1179,15 @@ unsafe fn check_patternlive3_avx512(pos: &Pos, sd: Side) -> i32 {
count_match
}

/// Reports whether the running CPU supports every feature that the
/// `*_avx512` functions in this example are compiled with.
///
/// Those functions are declared with
/// `#[target_feature(enable = "avx512f,avx512bw,popcnt")]`, so all three
/// features are probed here. Callers use the result to decide between the
/// AVX-512 code path and the scalar fallback.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
fn check_x86_avx512_features() -> bool {
    // Probe `avx512f` explicitly so this check mirrors the `target_feature`
    // list one-to-one instead of relying on `avx512bw` implying it.
    is_x86_feature_detected!("avx512f")
        && is_x86_feature_detected!("avx512bw")
        && is_x86_feature_detected!("popcnt")
}

fn main() {
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
{
if is_x86_feature_detected!("avx512bw") {
if check_x86_avx512_features() {
println!("\n\nThe program is running with avx512f and avx512bw intrinsics\n\n");
} else {
println!("\n\nThe program is running with NO intrinsics.\n\n");
Expand Down
68 changes: 46 additions & 22 deletions examples/hex.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
clippy::cast_sign_loss,
clippy::missing_docs_in_private_items
)]
#![allow(unsafe_op_in_unsafe_fn)]

use std::{
io::{self, Read},
Expand Down Expand Up @@ -67,7 +66,7 @@ fn hex_encode<'a>(src: &[u8], dst: &'a mut [u8]) -> Result<&'a str, usize> {
#[cfg(target_arch = "wasm32")]
{
if true {
return unsafe { hex_encode_simd128(src, dst) };
return hex_encode_simd128(src, dst);
}
}

Expand All @@ -76,15 +75,18 @@ fn hex_encode<'a>(src: &[u8], dst: &'a mut [u8]) -> Result<&'a str, usize> {

#[target_feature(enable = "avx2")]
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
unsafe fn hex_encode_avx2<'a>(mut src: &[u8], dst: &'a mut [u8]) -> Result<&'a str, usize> {
fn hex_encode_avx2<'a>(mut src: &[u8], dst: &'a mut [u8]) -> Result<&'a str, usize> {
assert!(dst.len() >= src.len().checked_mul(2).unwrap());

let ascii_zero = _mm256_set1_epi8(b'0' as i8);
let nines = _mm256_set1_epi8(9);
let ascii_a = _mm256_set1_epi8((b'a' - 9 - 1) as i8);
let and4bits = _mm256_set1_epi8(0xf);

let mut i = 0_usize;
while src.len() >= 32 {
let invec = _mm256_loadu_si256(src.as_ptr() as *const _);
// SAFETY: the loop condition ensures that we have at least 32 bytes
let invec = unsafe { _mm256_loadu_si256(src.as_ptr() as *const _) };

let masked1 = _mm256_and_si256(invec, and4bits);
let masked2 = _mm256_and_si256(_mm256_srli_epi64(invec, 4), and4bits);
Expand All @@ -102,34 +104,43 @@ unsafe fn hex_encode_avx2<'a>(mut src: &[u8], dst: &'a mut [u8]) -> Result<&'a s
let res2 = _mm256_unpackhi_epi8(masked2, masked1);

// Store everything into the right destination now
let base = dst.as_mut_ptr().add(i * 2);
let base1 = base.add(0) as *mut _;
let base2 = base.add(16) as *mut _;
let base3 = base.add(32) as *mut _;
let base4 = base.add(48) as *mut _;
_mm256_storeu2_m128i(base3, base1, res1);
_mm256_storeu2_m128i(base4, base2, res2);
unsafe {
// SAFETY: the assertion at the beginning of the function ensures
// that `dst` is large enough.
let base = dst.as_mut_ptr().add(i * 2);
let base1 = base.add(0) as *mut _;
let base2 = base.add(16) as *mut _;
let base3 = base.add(32) as *mut _;
let base4 = base.add(48) as *mut _;
_mm256_storeu2_m128i(base3, base1, res1);
_mm256_storeu2_m128i(base4, base2, res2);
}

src = &src[32..];
i += 32;
}

let _ = hex_encode_sse41(src, &mut dst[i * 2..]);

Ok(str::from_utf8_unchecked(&dst[..src.len() * 2 + i * 2]))
// SAFETY: `dst` only contains ASCII characters
unsafe { Ok(str::from_utf8_unchecked(&dst[..src.len() * 2 + i * 2])) }
}

// copied from https://github.com/Matherunner/bin2hex-sse/blob/master/base16_sse4.cpp
#[target_feature(enable = "sse4.1")]
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
unsafe fn hex_encode_sse41<'a>(mut src: &[u8], dst: &'a mut [u8]) -> Result<&'a str, usize> {
fn hex_encode_sse41<'a>(mut src: &[u8], dst: &'a mut [u8]) -> Result<&'a str, usize> {
assert!(dst.len() >= src.len().checked_mul(2).unwrap());

let ascii_zero = _mm_set1_epi8(b'0' as i8);
let nines = _mm_set1_epi8(9);
let ascii_a = _mm_set1_epi8((b'a' - 9 - 1) as i8);
let and4bits = _mm_set1_epi8(0xf);

let mut i = 0_usize;
while src.len() >= 16 {
let invec = _mm_loadu_si128(src.as_ptr() as *const _);
// SAFETY: the loop condition ensures that we have at least 16 bytes
let invec = unsafe { _mm_loadu_si128(src.as_ptr() as *const _) };

let masked1 = _mm_and_si128(invec, and4bits);
let masked2 = _mm_and_si128(_mm_srli_epi64(invec, 4), and4bits);
Expand All @@ -146,20 +157,27 @@ unsafe fn hex_encode_sse41<'a>(mut src: &[u8], dst: &'a mut [u8]) -> Result<&'a
let res1 = _mm_unpacklo_epi8(masked2, masked1);
let res2 = _mm_unpackhi_epi8(masked2, masked1);

_mm_storeu_si128(dst.as_mut_ptr().add(i * 2) as *mut _, res1);
_mm_storeu_si128(dst.as_mut_ptr().add(i * 2 + 16) as *mut _, res2);
unsafe {
// SAFETY: the assertion at the beginning of the function ensures
// that `dst` is large enough.
_mm_storeu_si128(dst.as_mut_ptr().add(i * 2) as *mut _, res1);
_mm_storeu_si128(dst.as_mut_ptr().add(i * 2 + 16) as *mut _, res2);
}
src = &src[16..];
i += 16;
}

let _ = hex_encode_fallback(src, &mut dst[i * 2..]);

Ok(str::from_utf8_unchecked(&dst[..src.len() * 2 + i * 2]))
// SAFETY: `dst` only contains ASCII characters
unsafe { Ok(str::from_utf8_unchecked(&dst[..src.len() * 2 + i * 2])) }
}

#[cfg(target_arch = "wasm32")]
#[target_feature(enable = "simd128")]
unsafe fn hex_encode_simd128<'a>(mut src: &[u8], dst: &'a mut [u8]) -> Result<&'a str, usize> {
fn hex_encode_simd128<'a>(mut src: &[u8], dst: &'a mut [u8]) -> Result<&'a str, usize> {
assert!(dst.len() >= src.len().checked_mul(2).unwrap());

use core_arch::arch::wasm32::*;

let ascii_zero = u8x16_splat(b'0');
Expand All @@ -169,7 +187,8 @@ unsafe fn hex_encode_simd128<'a>(mut src: &[u8], dst: &'a mut [u8]) -> Result<&'

let mut i = 0_usize;
while src.len() >= 16 {
let invec = v128_load(src.as_ptr() as *const _);
// SAFETY: the loop condition ensures that we have at least 16 bytes
let invec = unsafe { v128_load(src.as_ptr() as *const _) };

let masked1 = v128_and(invec, and4bits);
let masked2 = v128_and(u8x16_shr(invec, 4), and4bits);
Expand All @@ -193,15 +212,20 @@ unsafe fn hex_encode_simd128<'a>(mut src: &[u8], dst: &'a mut [u8]) -> Result<&'
masked2, masked1,
);

v128_store(dst.as_mut_ptr().add(i * 2) as *mut _, res1);
v128_store(dst.as_mut_ptr().add(i * 2 + 16) as *mut _, res2);
unsafe {
// SAFETY: the assertion at the beginning of the function ensures
// that `dst` is large enough.
v128_store(dst.as_mut_ptr().add(i * 2) as *mut _, res1);
v128_store(dst.as_mut_ptr().add(i * 2 + 16) as *mut _, res2);
}
src = &src[16..];
i += 16;
}

let _ = hex_encode_fallback(src, &mut dst[i * 2..]);

Ok(str::from_utf8_unchecked(&dst[..src.len() * 2 + i * 2]))
// SAFETY: `dst` only contains ASCII characters
unsafe { Ok(str::from_utf8_unchecked(&dst[..src.len() * 2 + i * 2])) }
}

fn hex_encode_fallback<'a>(src: &[u8], dst: &'a mut [u8]) -> Result<&'a str, usize> {
Expand Down
Loading