Skip to content

Commit

Permalink
Merge pull request #23 from oxideai/api/device
Browse files Browse the repository at this point in the history
feat(api): add device
  • Loading branch information
minghuaw authored Apr 16, 2024
2 parents 60f2214 + 9f4a216 commit aa6c09c
Show file tree
Hide file tree
Showing 5 changed files with 153 additions and 19 deletions.
46 changes: 46 additions & 0 deletions .github/workflows/validate.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
name: validate
on:
push:
branches:
- main
pull_request:
types: [opened, synchronize]
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: true
jobs:
rustfmt-check:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- uses: actions-rust-lang/setup-rust-toolchain@v1
with:
cache: false
components: rustfmt
- name: Run cargo fmt
run: cargo fmt -- --check
# - name: Run cargo clippy
# run: cargo clippy -- -D warnings

tests:
runs-on: blaze/macos-14
strategy:
matrix:
rust: [ stable, 1.75.0 ]
steps:
- name: Checkout
uses: actions/checkout@v4
with:
submodules: true
- name: Install Rust
uses: actions-rust-lang/setup-rust-toolchain@v1
with:
cache: false
toolchain: ${{ matrix.rust }}
rustflags: "" # Disable when we're ready
- name: Setup cache
uses: Swatinem/rust-cache@v2
with:
key: ${{ runner.os }}-${{ matrix.cache }}-${{ matrix.backend }}-${{ hashFiles('**/Cargo.toml') }}
- name: Run tests
run: cargo test --all
13 changes: 8 additions & 5 deletions mlx-sys/build.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
extern crate cmake;

use cmake::Config;
use std::env;
use std::path::PathBuf;
use cmake::Config;

fn main() {
let mut config = Config::new("src/mlx-c");
Expand All @@ -11,11 +11,13 @@ fn main() {
config.define("MLX_BUILD_METAL", "OFF");
config.define("MLX_BUILD_ACCELERATE", "OFF");

#[cfg(feature = "metal")] {
#[cfg(feature = "metal")]
{
config.define("MLX_BUILD_METAL", "ON");
}

#[cfg(feature = "accelerate")] {
#[cfg(feature = "accelerate")]
{
config.define("MLX_BUILD_ACCELERATE", "ON");
}

Expand All @@ -30,7 +32,8 @@ fn main() {
println!("cargo:rustc-link-lib=dylib=objc");
println!("cargo:rustc-link-lib=framework=Foundation");

#[cfg(feature = "metal")] {
#[cfg(feature = "metal")]
{
println!("cargo:rustc-link-lib=framework=Metal");
}

Expand All @@ -47,4 +50,4 @@ fn main() {
bindings
.write_to_file(out_path.join("bindings.rs"))
.expect("Couldn't write bindings!");
}
}
78 changes: 78 additions & 0 deletions src/device.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
use crate::utils::mlx_describe;

///Type of device.
pub enum DeviceType {
Cpu,
Gpu,
}

/// Representation of a Device in MLX.
#[derive(Debug)]
pub struct Device {
ctx: mlx_sys::mlx_device,
}

impl Device {
pub fn new_default() -> Device {
let ctx = unsafe { mlx_sys::mlx_default_device() };
Device { ctx }
}

pub fn new(device_type: DeviceType, index: i32) -> Device {
let c_device_type: u32 = match device_type {
DeviceType::Cpu => mlx_sys::mlx_device_type__MLX_CPU,
DeviceType::Gpu => mlx_sys::mlx_device_type__MLX_GPU,
};

let ctx = unsafe { mlx_sys::mlx_device_new(c_device_type, index) };
Device { ctx }
}

pub fn cpu() -> Device {
Device::new(DeviceType::Cpu, 0)
}

pub fn gpu() -> Device {
Device::new(DeviceType::Gpu, 0)
}

/// Set the default device.
///
/// Example:
/// ```rust
/// use mlx::device::{Device, DeviceType};
/// Device::set_default(&Device::new(DeviceType::Cpu, 1));
/// ```
///
/// By default, this is `gpu()`.
pub fn set_default(&self) {
unsafe { mlx_sys::mlx_set_default_device(self.ctx) };
}
}

impl Drop for Device {
fn drop(&mut self) {
unsafe { mlx_sys::mlx_free(self.ctx as *mut std::ffi::c_void) };
}
}

impl std::fmt::Display for Device {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
let description = unsafe { mlx_describe(self.ctx as *mut std::os::raw::c_void) };
let description = description.unwrap_or_else(|| "Device".to_string());

write!(f, "{}", description)
}
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn test_fmt() {
let device = Device::new_default();
let description = format!("{}", device);
assert_eq!(description, "Device(gpu, 0)");
}
}
16 changes: 2 additions & 14 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,2 @@
pub fn add(left: usize, right: usize) -> usize {
left + right
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn it_works() {
let result = add(2, 2);
assert_eq!(result, 4);
}
}
pub mod device;
mod utils;
19 changes: 19 additions & 0 deletions src/utils.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
/// Helper method to get a string representation of an mlx object.
pub(crate) fn mlx_describe(ptr: *mut ::std::os::raw::c_void) -> Option<String> {
let mlx_description = unsafe { mlx_sys::mlx_tostring(ptr) };
let c_str = unsafe { mlx_sys::mlx_string_data(mlx_description) };

let description = if c_str.is_null() {
None
} else {
Some(unsafe {
std::ffi::CStr::from_ptr(c_str)
.to_string_lossy()
.into_owned()
})
};

unsafe { mlx_sys::mlx_free(mlx_description as *mut std::ffi::c_void) };

description
}

0 comments on commit aa6c09c

Please sign in to comment.