From 094e676b371763498a5d04a28477a7ef0fb6f6a8 Mon Sep 17 00:00:00 2001
From: "Hz, Ji"
Date: Mon, 20 Nov 2023 17:23:44 +0800
Subject: [PATCH] Adding support for Ascend NPU (#372)

* Adding support for Ascend NPU

* remove the unnecessary hack code

* test more dtype

* npu doesn't support calling torch.allclose with bf16 for now
---
 bindings/python/src/lib.rs                  | 14 ++++++++++++++
 bindings/python/tests/test_pt_comparison.py | 21 +++++++++++++++++++++
 2 files changed, 35 insertions(+)

diff --git a/bindings/python/src/lib.rs b/bindings/python/src/lib.rs
index f44915cc..e9cd82b7 100644
--- a/bindings/python/src/lib.rs
+++ b/bindings/python/src/lib.rs
@@ -220,6 +220,7 @@ enum Device {
     Cpu,
     Cuda(usize),
     Mps,
+    Npu(usize),
 }
 
 impl<'source> FromPyObject<'source> for Device {
@@ -229,6 +230,7 @@ impl<'source> FromPyObject<'source> for Device {
             "cpu" => Ok(Device::Cpu),
             "cuda" => Ok(Device::Cuda(0)),
             "mps" => Ok(Device::Mps),
+            "npu" => Ok(Device::Npu(0)),
             name if name.starts_with("cuda:") => {
                 let tokens: Vec<_> = name.split(':').collect();
                 if tokens.len() == 2 {
@@ -240,6 +242,17 @@ impl<'source> FromPyObject<'source> for Device {
                     )))
                 }
             }
+            name if name.starts_with("npu:") => {
+                let tokens: Vec<_> = name.split(':').collect();
+                if tokens.len() == 2 {
+                    let device: usize = tokens[1].parse()?;
+                    Ok(Device::Npu(device))
+                } else {
+                    Err(SafetensorError::new_err(format!(
+                        "device {name} is invalid"
+                    )))
+                }
+            }
             name => Err(SafetensorError::new_err(format!(
                 "device {name} is invalid"
             ))),
@@ -258,6 +271,7 @@ impl IntoPy<PyObject> for Device {
             Device::Cpu => "cpu".into_py(py),
             Device::Cuda(n) => format!("cuda:{n}").into_py(py),
             Device::Mps => "mps".into_py(py),
+            Device::Npu(n) => format!("npu:{n}").into_py(py),
         }
     }
 }
diff --git a/bindings/python/tests/test_pt_comparison.py b/bindings/python/tests/test_pt_comparison.py
index cb1a9f57..d4fa0ec6 100644
--- a/bindings/python/tests/test_pt_comparison.py
+++ b/bindings/python/tests/test_pt_comparison.py
@@ -7,6 +7,14 @@
 from safetensors.torch import load, load_file, save, save_file
 
 
+try:
+    import torch_npu  # noqa
+
+    npu_present = True
+except Exception:
+    npu_present = False
+
+
 class TorchTestCase(unittest.TestCase):
     def test_serialization(self):
         data = torch.zeros((2, 2), dtype=torch.int32)
@@ -119,6 +127,19 @@ def test_gpu(self):
         reloaded = load_file(local)
         self.assertTrue(torch.equal(torch.arange(4).view((2, 2)), reloaded["test"]))
 
+    @unittest.skipIf(not npu_present, "Npu is not available")
+    def test_npu(self):
+        data = {
+            "test1": torch.zeros((2, 2), dtype=torch.float32).to("npu:0"),
+            "test2": torch.zeros((2, 2), dtype=torch.float16).to("npu:0"),
+        }
+        local = "./tests/data/out_safe_pt_mmap_small_npu.safetensors"
+        save_file(data, local)
+
+        reloaded = load_file(local, device="npu:0")
+        for k, v in reloaded.items():
+            self.assertTrue(torch.allclose(data[k], reloaded[k]))
+
     def test_sparse(self):
         data = {"test": torch.sparse_coo_tensor(size=(2, 3))}
         local = "./tests/data/out_safe_pt_sparse.safetensors"
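
For context, a minimal sketch of how the new "npu" device strings can be used
from Python once this patch lands. It assumes the torch_npu package is
installed and an Ascend device is visible; the "model.safetensors" path is
illustrative, not part of this patch.

    import torch
    import torch_npu  # noqa: F401  (registers the "npu" device with PyTorch)

    from safetensors.torch import load_file, save_file

    # Save on CPU as usual; safetensors serializes the raw tensor data.
    tensors = {"weight": torch.ones((2, 2), dtype=torch.float16)}
    save_file(tensors, "model.safetensors")

    # After this patch, both "npu" (device 0) and "npu:<index>" are accepted.
    reloaded = load_file("model.safetensors", device="npu:0")
    print(reloaded["weight"].device)  # npu:0

As the new test_npu case shows, tensors already resident on the NPU can also
be passed straight to save_file; only torch.allclose on bf16 is avoided, per
the commit notes.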