-
Notifications
You must be signed in to change notification settings - Fork 9
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #23 from oxideai/api/device
feat(api): add device
- Loading branch information
Showing
5 changed files
with
153 additions
and
19 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
name: validate | ||
on: | ||
push: | ||
branches: | ||
- main | ||
pull_request: | ||
types: [opened, synchronize] | ||
concurrency: | ||
group: ${{ github.workflow }}-${{ github.ref }} | ||
cancel-in-progress: true | ||
jobs: | ||
rustfmt-check: | ||
runs-on: ubuntu-latest | ||
steps: | ||
- uses: actions/checkout@v2 | ||
- uses: actions-rust-lang/setup-rust-toolchain@v1 | ||
with: | ||
cache: false | ||
components: rustfmt | ||
- name: Run cargo fmt | ||
run: cargo fmt -- --check | ||
# - name: Run cargo clippy | ||
# run: cargo clippy -- -D warnings | ||
|
||
tests: | ||
runs-on: blaze/macos-14 | ||
strategy: | ||
matrix: | ||
rust: [ stable, 1.75.0 ] | ||
steps: | ||
- name: Checkout | ||
uses: actions/checkout@v4 | ||
with: | ||
submodules: true | ||
- name: Install Rust | ||
uses: actions-rust-lang/setup-rust-toolchain@v1 | ||
with: | ||
cache: false | ||
toolchain: ${{ matrix.rust }} | ||
rustflags: "" # Disable when we're ready | ||
- name: Setup cache | ||
uses: Swatinem/rust-cache@v2 | ||
with: | ||
key: ${{ runner.os }}-${{ matrix.cache }}-${{ matrix.backend }}-${{ hashFiles('**/Cargo.toml') }} | ||
- name: Run tests | ||
run: cargo test --all |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
use crate::utils::mlx_describe; | ||
|
||
///Type of device. | ||
pub enum DeviceType { | ||
Cpu, | ||
Gpu, | ||
} | ||
|
||
/// Representation of a Device in MLX. | ||
#[derive(Debug)] | ||
pub struct Device { | ||
ctx: mlx_sys::mlx_device, | ||
} | ||
|
||
impl Device { | ||
pub fn new_default() -> Device { | ||
let ctx = unsafe { mlx_sys::mlx_default_device() }; | ||
Device { ctx } | ||
} | ||
|
||
pub fn new(device_type: DeviceType, index: i32) -> Device { | ||
let c_device_type: u32 = match device_type { | ||
DeviceType::Cpu => mlx_sys::mlx_device_type__MLX_CPU, | ||
DeviceType::Gpu => mlx_sys::mlx_device_type__MLX_GPU, | ||
}; | ||
|
||
let ctx = unsafe { mlx_sys::mlx_device_new(c_device_type, index) }; | ||
Device { ctx } | ||
} | ||
|
||
pub fn cpu() -> Device { | ||
Device::new(DeviceType::Cpu, 0) | ||
} | ||
|
||
pub fn gpu() -> Device { | ||
Device::new(DeviceType::Gpu, 0) | ||
} | ||
|
||
/// Set the default device. | ||
/// | ||
/// Example: | ||
/// ```rust | ||
/// use mlx::device::{Device, DeviceType}; | ||
/// Device::set_default(&Device::new(DeviceType::Cpu, 1)); | ||
/// ``` | ||
/// | ||
/// By default, this is `gpu()`. | ||
pub fn set_default(&self) { | ||
unsafe { mlx_sys::mlx_set_default_device(self.ctx) }; | ||
} | ||
} | ||
|
||
impl Drop for Device { | ||
fn drop(&mut self) { | ||
unsafe { mlx_sys::mlx_free(self.ctx as *mut std::ffi::c_void) }; | ||
} | ||
} | ||
|
||
impl std::fmt::Display for Device { | ||
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { | ||
let description = unsafe { mlx_describe(self.ctx as *mut std::os::raw::c_void) }; | ||
let description = description.unwrap_or_else(|| "Device".to_string()); | ||
|
||
write!(f, "{}", description) | ||
} | ||
} | ||
|
||
#[cfg(test)] | ||
mod tests { | ||
use super::*; | ||
|
||
#[test] | ||
fn test_fmt() { | ||
let device = Device::new_default(); | ||
let description = format!("{}", device); | ||
assert_eq!(description, "Device(gpu, 0)"); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,14 +1,2 @@ | ||
pub fn add(left: usize, right: usize) -> usize { | ||
left + right | ||
} | ||
|
||
#[cfg(test)] | ||
mod tests { | ||
use super::*; | ||
|
||
#[test] | ||
fn it_works() { | ||
let result = add(2, 2); | ||
assert_eq!(result, 4); | ||
} | ||
} | ||
pub mod device; | ||
mod utils; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
/// Helper method to get a string representation of an mlx object. | ||
pub(crate) fn mlx_describe(ptr: *mut ::std::os::raw::c_void) -> Option<String> { | ||
let mlx_description = unsafe { mlx_sys::mlx_tostring(ptr) }; | ||
let c_str = unsafe { mlx_sys::mlx_string_data(mlx_description) }; | ||
|
||
let description = if c_str.is_null() { | ||
None | ||
} else { | ||
Some(unsafe { | ||
std::ffi::CStr::from_ptr(c_str) | ||
.to_string_lossy() | ||
.into_owned() | ||
}) | ||
}; | ||
|
||
unsafe { mlx_sys::mlx_free(mlx_description as *mut std::ffi::c_void) }; | ||
|
||
description | ||
} |