From f665edbbe458faf30c70436bad9ccc37008e6cef Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Wed, 31 Jul 2024 11:31:16 +0200 Subject: [PATCH] Respects torch.device(0) new behavior without breaking backward compatibilty. --- bindings/python/src/lib.rs | 67 +++++++++++++------------------------- 1 file changed, 22 insertions(+), 45 deletions(-) diff --git a/bindings/python/src/lib.rs b/bindings/python/src/lib.rs index 4df69ebe..50bb1e5d 100644 --- a/bindings/python/src/lib.rs +++ b/bindings/python/src/lib.rs @@ -267,6 +267,22 @@ enum Device { Npu(usize), Xpu(usize), Xla(usize), + /// User didn't specify acceletor, torch + /// is responsible for choosing. + Anonymous(usize), +} + +/// Parsing the device index. +fn parse_device(name: &str) -> PyResult { + let tokens: Vec<_> = name.split(':').collect(); + if tokens.len() == 2 { + let device: usize = tokens[1].parse()?; + Ok(device) + } else { + Err(SafetensorError::new_err(format!( + "device {name} is invalid" + ))) + } } impl<'source> FromPyObject<'source> for Device { @@ -279,56 +295,16 @@ impl<'source> FromPyObject<'source> for Device { "npu" => Ok(Device::Npu(0)), "xpu" => Ok(Device::Xpu(0)), "xla" => Ok(Device::Xla(0)), - name if name.starts_with("cuda:") => { - let tokens: Vec<_> = name.split(':').collect(); - if tokens.len() == 2 { - let device: usize = tokens[1].parse()?; - Ok(Device::Cuda(device)) - } else { - Err(SafetensorError::new_err(format!( - "device {name} is invalid" - ))) - } - } - name if name.starts_with("npu:") => { - let tokens: Vec<_> = name.split(':').collect(); - if tokens.len() == 2 { - let device: usize = tokens[1].parse()?; - Ok(Device::Npu(device)) - } else { - Err(SafetensorError::new_err(format!( - "device {name} is invalid" - ))) - } - } - name if name.starts_with("xpu:") => { - let tokens: Vec<_> = name.split(':').collect(); - if tokens.len() == 2 { - let device: usize = tokens[1].parse()?; - Ok(Device::Xpu(device)) - } else { - Err(SafetensorError::new_err(format!( - "device {name} is invalid" - ))) - } - } - name if name.starts_with("xla:") => { - let tokens: Vec<_> = name.split(':').collect(); - if tokens.len() == 2 { - let device: usize = tokens[1].parse()?; - Ok(Device::Xla(device)) - } else { - Err(SafetensorError::new_err(format!( - "device {name} is invalid" - ))) - } - } + name if name.starts_with("cuda:") => parse_device(name).map(Device::Cuda), + name if name.starts_with("npu:") => parse_device(name).map(Device::Npu), + name if name.starts_with("xpu:") => parse_device(name).map(Device::Xpu), + name if name.starts_with("xla:") => parse_device(name).map(Device::Xla), name => Err(SafetensorError::new_err(format!( "device {name} is invalid" ))), } } else if let Ok(number) = ob.extract::() { - Ok(Device::Cuda(number)) + Ok(Device::Anonymous(number)) } else { Err(SafetensorError::new_err(format!("device {ob} is invalid"))) } @@ -344,6 +320,7 @@ impl IntoPy for Device { Device::Npu(n) => format!("npu:{n}").into_py(py), Device::Xpu(n) => format!("xpu:{n}").into_py(py), Device::Xla(n) => format!("xla:{n}").into_py(py), + Device::Anonymous(n) => format!("{n}").into_py(py), } } }