diff --git a/README.md b/README.md
index d530fa0..b2b6f02 100644
--- a/README.md
+++ b/README.md
@@ -58,7 +58,7 @@ And the following ruby:
 
 ```ruby
 require 'candle'
-model = Candle::Model.new
+model = Candle::Model.new("dmis-lab/biobert-base-cased-v1.1")
 embedding = model.embedding("Hi there!")
 ```
 
diff --git a/ext/candle/Cargo.toml b/ext/candle/Cargo.toml
index 7cdda12..bcc9a28 100644
--- a/ext/candle/Cargo.toml
+++ b/ext/candle/Cargo.toml
@@ -7,6 +7,8 @@ edition = "2021"
 crate-type = ["cdylib"]
 
 [dependencies]
+serde = { version = "1.0", features = ["derive"] }
+serde_json = "1.0"
 candle-core = "0.4.1"
 candle-nn = "0.4.1"
 candle-transformers = "0.4.1"
diff --git a/ext/candle/src/lib.rs b/ext/candle/src/lib.rs
index 49e171b..c643c2c 100644
--- a/ext/candle/src/lib.rs
+++ b/ext/candle/src/lib.rs
@@ -95,6 +95,8 @@ fn init(ruby: &Ruby) -> RbResult<()> {
 
     let rb_model = rb_candle.define_class("Model", Ruby::class_object(ruby))?;
     rb_model.define_singleton_method("new", function!(RbModel::new, 0))?;
+    rb_model.define_singleton_method("new1", function!(RbModel::new1, 1))?;
+    rb_model.define_singleton_method("new2", function!(RbModel::new2, 2))?;
     rb_model.define_method("embedding", method!(RbModel::embedding, 1))?;
     rb_model.define_method("to_s", method!(RbModel::__str__, 0))?;
     rb_model.define_method("inspect", method!(RbModel::__repr__, 0))?;
diff --git a/ext/candle/src/model/rb_model.rs b/ext/candle/src/model/rb_model.rs
index b6ff1c0..34c68a2 100644
--- a/ext/candle/src/model/rb_model.rs
+++ b/ext/candle/src/model/rb_model.rs
@@ -8,11 +8,14 @@ use crate::model::{
     errors::{wrap_candle_err, wrap_hf_err, wrap_std_err},
     rb_tensor::RbTensor,
 };
-use candle_core::{DType, Device, Module, Tensor};
+use candle_core::{DType, Device, Tensor};
 use candle_nn::VarBuilder;
-use candle_transformers::models::jina_bert::{BertModel, Config};
+use candle_transformers::models::bert::{BertModel, Config};
 use magnus::Error;
 use crate::model::RbResult;
+use serde_json;
+use std::fs;
+use std::path::PathBuf;
 use tokenizers::Tokenizer;
 
 #[magnus::wrap(class = "Candle::Model", free_immediately, size)]
@@ -28,10 +31,18 @@ pub struct RbModelInner {
 }
 impl RbModel {
     pub fn new() -> RbResult<Self> {
-        Self::new2(Some("jinaai/jina-embeddings-v2-base-en".to_string()), Some("sentence-transformers/all-MiniLM-L6-v2".to_string()), Some(Device::Cpu))
+        Self::new3(Some("jinaai/jina-embeddings-v2-base-en".to_string()), Some("sentence-transformers/all-MiniLM-L6-v2".to_string()), Some(Device::Cpu))
     }
 
-    pub fn new2(model_path: Option<String>, tokenizer_path: Option<String>, device: Option<Device>) -> RbResult<Self> {
+    pub fn new1(model_path: Option<String>) -> RbResult<Self> {
+        Self::new3(model_path, Some("sentence-transformers/all-MiniLM-L6-v2".to_string()), Some(Device::Cpu))
+    }
+
+    pub fn new2(model_path: Option<String>, tokenizer_path: Option<String>) -> RbResult<Self> {
+        Self::new3(model_path, tokenizer_path, Some(Device::Cpu))
+    }
+
+    pub fn new3(model_path: Option<String>, tokenizer_path: Option<String>, device: Option<Device>) -> RbResult<Self> {
         let device = device.unwrap_or(Device::Cpu);
         Ok(RbModel(RbModelInner {
             device: device.clone(),
@@ -60,7 +71,6 @@
             }
             None => Err(magnus::Error::new(magnus::exception::runtime_error(), "Tokenizer or Model not found"))
         }
-
     }
 
     fn build_model(model_path: String, device: Device) -> RbResult<BertModel> {
@@ -68,17 +78,31 @@
         let model_path = Api::new()
             .map_err(wrap_hf_err)?
             .repo(Repo::new(
-                model_path,
+                model_path.clone(),
                 RepoType::Model,
             ))
             .get("model.safetensors")
             .map_err(wrap_hf_err)?;
-        let config = Config::v2_base();
+        println!("Model path: {:?}", model_path);
+        let config_path = model_path.parent().unwrap().join("config.json");
+        println!("Config path: {:?}", config_path);
+
+        // let config_path = Api::new()
+        //     .map_err(wrap_hf_err)?
+        //     .repo(Repo::new(
+        //         model_path.to_str().unwrap().to_string(),
+        //         RepoType::Model,
+        //     ))
+        //     .get("config.json")
+        //     .map_err(wrap_hf_err)?;
+
+        let config: Config = read_config(config_path)?;
+
         let vb = unsafe {
             VarBuilder::from_mmaped_safetensors(&[model_path], DType::F32, &device)
                 .map_err(wrap_candle_err)?
         };
-        let model = BertModel::new(vb, &config).map_err(wrap_candle_err)?;
+        let model = BertModel::load(vb, &config).map_err(wrap_candle_err)?;
         Ok(model)
     }
 
@@ -119,8 +143,10 @@
             .unsqueeze(0)
             .map_err(wrap_candle_err)?;
 
-        // let start: std::time::Instant = std::time::Instant::now();
-        let result = model.forward(&token_ids).map_err(wrap_candle_err)?;
+        let token_type_ids = Tensor::zeros(&*token_ids.shape(), DType::I64, &self.0.device)
+            .map_err(wrap_candle_err)?;
+
+        let result = model.forward(&token_ids, &token_type_ids).map_err(wrap_candle_err)?;
 
         // Apply some avg-pooling by taking the mean embedding value for all tokens (including padding)
         let (_n_sentence, n_tokens, _hidden_size) = result.dims3()
@@ -129,7 +155,6 @@
             .map_err(wrap_candle_err)?;
         let embeddings = (sum / (n_tokens as f64))
             .map_err(wrap_candle_err)?;
-        // let embeddings = Self::normalize_l2(&embeddings).map_err(wrap_candle_err)?;
         Ok(embeddings)
     }
 
@@ -148,6 +173,13 @@
     }
 }
 
+fn read_config(config_path: PathBuf) -> Result<Config, Error> {
+    let config_str = fs::read_to_string(config_path).map_err(|e| wrap_std_err(Box::new(e)))?;
+    println!("Config string: {}", config_str);
+    let config_json: Config = serde_json::from_str(&config_str).map_err(|e| wrap_std_err(Box::new(e)))?;
+    Ok(config_json)
+}
+
 // #[cfg(test)]
 // mod tests {
 //     #[test]