Skip to content

Commit

Permalink
Create Store Validation (#80)
Browse files Browse the repository at this point in the history
* Validate selected models against supported models on store create, and set max_token_size for models

* update aiproxy runner in python tests

* use value delimited to accept supported models

* WIP: create model enum for ai

* swap out aimodelmanager trait with model enum

* Create separate methods for max_token and max_image_dim

* Add subcommand to display supported models and info on supported models

* add termcolor for output

* use strum to iter over supportedmodel variants
  • Loading branch information
Iamdavidonuh authored Aug 18, 2024
1 parent 5b6c50f commit 71140ee
Show file tree
Hide file tree
Showing 12 changed files with 388 additions and 96 deletions.
4 changes: 3 additions & 1 deletion ahnlich/ai/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,9 @@ ahnlich_similarity = { path = "../similarity", version = "*", features = ["serde
cap.workspace = true
deadpool.workspace = true
nonzero_ext = "0.3.0"

serde_json.workspace = true
termcolor = "1.4.1"
strum = { version = "0.26", features = ["derive"] }

[dev-dependencies]
db = { path = "../db", version = "*" }
Expand Down
93 changes: 92 additions & 1 deletion ahnlich/ai/src/cli/server.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,16 @@
use clap::{ArgAction, Args, Parser, Subcommand};
use ahnlich_types::ai::AIModel;
use clap::{ArgAction, Args, Parser, Subcommand, ValueEnum};
use strum::VariantArray;

use crate::engine::ai::models::{Model, ModelInfo};
use std::io::Write;
use termcolor::{Color, ColorChoice, ColorSpec, StandardStream, WriteColor};

/// Models the AI proxy can be configured to serve. Selectable on the command
/// line via clap's `ValueEnum` and iterable via strum's `VariantArray`.
///
/// Added `Copy` and `Hash`: a fieldless enum is trivially copyable, and `Hash`
/// lets the variants be used as set/map keys without cloning.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, ValueEnum, VariantArray)]
pub enum SupportedModels {
    /// Llama3 text model.
    Llama3,
    /// DALL.E 3 image model.
    Dalle3,
}

#[derive(Parser)]
#[command(version, about, long_about = None)]
Expand All @@ -11,6 +23,9 @@ pub struct Cli {
/// Top-level subcommands for the AI proxy binary.
pub enum Commands {
    /// Starts Ahnlich AI Proxy
    Start(AIProxyConfig),

    /// Outputs all supported models by aiproxy
    SupportedModels(SupportedModelArgs),
}

#[derive(Args, Debug, Clone)]
Expand Down Expand Up @@ -73,6 +88,10 @@ pub struct AIProxyConfig {
/// Defaults to 1000
#[arg(long, default_value_t = 1000)]
pub(crate) maximum_clients: usize,

/// List of ai models to support in your aiproxy stores
#[arg(long, required(true), value_delimiter = ',')]
pub(crate) supported_models: Vec<SupportedModels>,
}

impl Default for AIProxyConfig {
Expand All @@ -95,6 +114,7 @@ impl Default for AIProxyConfig {
otel_endpoint: None,
log_level: String::from("info"),
maximum_clients: 1000,
supported_models: vec![SupportedModels::Llama3, SupportedModels::Dalle3],
}
}
}
Expand All @@ -121,4 +141,75 @@ impl AIProxyConfig {
self.maximum_clients = maximum_clients;
self
}

#[cfg(test)]
pub fn set_supported_models(mut self, models: Vec<SupportedModels>) -> Self {
self.supported_models = models;
self
}
}

impl From<&AIModel> for SupportedModels {
    /// Maps a core [`AIModel`] onto its CLI-facing counterpart.
    fn from(value: &AIModel) -> Self {
        // Both enums have only unit variants, so we can match through the
        // reference by value.
        match *value {
            AIModel::Llama3 => Self::Llama3,
            AIModel::DALLE3 => Self::Dalle3,
        }
    }
}

impl From<&SupportedModels> for AIModel {
    /// Maps a CLI-selected model back onto the core [`AIModel`] type.
    fn from(value: &SupportedModels) -> Self {
        match *value {
            SupportedModels::Llama3 => Self::Llama3,
            SupportedModels::Dalle3 => Self::DALLE3,
        }
    }
}

// Arguments for the `supported-models` subcommand. When `names` is empty the
// command lists every supported model; otherwise it prints verbose JSON info
// for the named models (see `SupportedModelArgs::output`).
#[derive(Args, Debug, Clone)]
pub struct SupportedModelArgs {
    /// Models to display information about
    #[arg(long, value_delimiter = ',')]
    pub names: Vec<SupportedModels>,
}

impl SupportedModelArgs {
    /// Returns the names of every supported model as a comma-separated list.
    pub fn list_supported_models(&self) -> String {
        SupportedModels::VARIANTS
            .iter()
            .map(|supported_model| {
                let aimodel: AIModel = supported_model.into();
                let model: Model = (&aimodel).into();
                model.model_name()
            })
            .collect::<Vec<_>>()
            // `join` avoids the dangling ", " that the previous push_str
            // loop appended after the last name.
            .join(", ")
    }

    /// Returns pretty-printed JSON describing each model selected in
    /// `self.names`.
    pub fn list_supported_models_verbose(&self) -> String {
        let mut output = vec![];

        for supported_model in self.names.iter() {
            let aimodel: AIModel = supported_model.into();
            let model: Model = (&aimodel).into();
            output.push(ModelInfo::build(&model));
        }
        serde_json::to_string_pretty(&output)
            .expect("Failed to generate supported models verbose text")
    }

    /// Prints supported-model information to stdout in green, then restores
    /// the terminal's default color.
    pub fn output(&self) {
        let mut stdout = StandardStream::stdout(ColorChoice::Always);
        stdout
            .set_color(ColorSpec::new().set_fg(Some(Color::Green)))
            .expect("Failed to set output Color");

        let mut text = "\n\nDisplaying Supported Models \n\n".to_string();
        if !self.names.is_empty() {
            text.push_str(&self.list_supported_models_verbose());
        } else {
            text.push_str(&self.list_supported_models());
        }

        writeln!(&mut stdout, "{}", text).expect("Failed to write output");
        // Bug fix: reset the color, otherwise all subsequent terminal output
        // stays green after the process exits.
        stdout.reset().expect("Failed to reset terminal color");
    }
}
11 changes: 1 addition & 10 deletions ahnlich/ai/src/engine/ai/mod.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1 @@
mod models;
use ahnlich_types::keyval::{StoreInput, StoreKey};
use models::ModelInfo;
use std::num::NonZeroUsize;

pub trait AIModelManager {
fn embedding_size(&self) -> NonZeroUsize;
fn model_ndarray(&self, storeinput: &StoreInput) -> StoreKey;
fn model_info(&self) -> ModelInfo;
}
pub mod models;
121 changes: 96 additions & 25 deletions ahnlich/ai/src/engine/ai/models.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
use crate::engine::ai::AIModelManager;
use ahnlich_types::{
ai::{AIModel, AIStoreInputType},
keyval::{StoreInput, StoreKey},
Expand All @@ -8,39 +7,111 @@ use nonzero_ext::nonzero;
use serde::{Deserialize, Serialize};
use std::num::NonZeroUsize;

#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct ModelInfo {
pub name: String,
pub embedding_size: NonZeroUsize,
pub input_type: AIStoreInputType,
/// A supported model together with its metadata and size limits.
/// Variants are split by the kind of input the model accepts.
pub enum Model {
    /// A text-embedding model.
    Text {
        name: String,
        description: String,
        /// Dimensionality of the embeddings the model produces.
        embedding_size: NonZeroUsize,
        /// Maximum number of input tokens the model accepts.
        max_input_tokens: NonZeroUsize,
    },
    /// An image-embedding model.
    Image {
        name: String,
        description: String,
        /// Maximum accepted image dimension.
        max_image_dimensions: NonZeroUsize,
        /// Dimensionality of the embeddings the model produces.
        embedding_size: NonZeroUsize,
    },
}

impl AIModelManager for AIModel {
fn embedding_size(&self) -> NonZeroUsize {
self.model_info().embedding_size
impl From<&AIModel> for Model {
fn from(value: &AIModel) -> Self {
match value {
AIModel::Llama3 => Self::Text {
name: String::from("Llama3"),
description: String::from("Llama3, a text model"),
embedding_size: nonzero!(100usize),
max_input_tokens: nonzero!(100usize),
},
AIModel::DALLE3 => Self::Image {
name: String::from("DALL.E 3"),
description: String::from("Dalle3, an image model"),
embedding_size: nonzero!(300usize),
max_image_dimensions: nonzero!(300usize),
},
}
}
}

impl Model {
pub fn embedding_size(&self) -> NonZeroUsize {
match self {
Model::Text { embedding_size, .. } => *embedding_size,
Model::Image { embedding_size, .. } => *embedding_size,
}
}
pub fn input_type(&self) -> String {
match self {
Model::Text { .. } => AIStoreInputType::RawString.to_string(),
Model::Image { .. } => AIStoreInputType::Image.to_string(),
}
}

// TODO: model ndarray values is based on length of string or vec, so for now make sure strings
// or vecs have different lengths
fn model_ndarray(&self, storeinput: &StoreInput) -> StoreKey {
pub fn model_ndarray(&self, storeinput: &StoreInput) -> StoreKey {
let length = storeinput.len() as f32;
StoreKey(
Array1::from_iter(0..self.model_info().embedding_size.into())
.mapv(|v| v as f32 * length),
)
StoreKey(Array1::from_iter(0..self.embedding_size().into()).mapv(|v| v as f32 * length))
}

fn model_info(&self) -> ModelInfo {
pub fn max_input_token(&self) -> Option<NonZeroUsize> {
match self {
AIModel::Llama3 => ModelInfo {
name: String::from("Llama3"),
embedding_size: nonzero!(100usize),
input_type: AIStoreInputType::RawString,
},
AIModel::DALLE3 => ModelInfo {
name: String::from("DALL.E 3"),
embedding_size: nonzero!(300usize),
input_type: AIStoreInputType::Image,
},
Model::Text {
max_input_tokens, ..
} => Some(*max_input_tokens),
Model::Image { .. } => None,
}
}
pub fn max_image_dimensions(&self) -> Option<NonZeroUsize> {
match self {
Model::Text { .. } => None,
Model::Image {
max_image_dimensions,
..
} => Some(*max_image_dimensions),
}
}
pub fn model_name(&self) -> String {
match self {
Model::Text { name, .. } => name.clone(),
Model::Image { name, .. } => name.clone(),
}
}
pub fn model_description(&self) -> String {
match self {
Model::Text { description, .. } => description.clone(),
Model::Image { description, .. } => description.clone(),
}
}
}

/// Serializable summary of a [`Model`], emitted as JSON by the
/// `supported-models` CLI subcommand.
#[derive(Debug, Serialize, Deserialize)]
pub(crate) struct ModelInfo {
    /// Display name of the model.
    name: String,
    /// Accepted input kind rendered as a string (raw string or image).
    input_type: String,
    /// Dimensionality of the embeddings the model produces.
    embedding_size: NonZeroUsize,
    /// Token limit for text models; `None` for image models.
    max_input_tokens: Option<NonZeroUsize>,
    /// Maximum image dimension for image models; `None` for text models.
    max_image_dimensions: Option<NonZeroUsize>,
    /// Short human-readable description of the model.
    description: String,
}

impl ModelInfo {
pub(crate) fn build(model: &Model) -> Self {
Self {
name: model.model_name(),
input_type: model.input_type(),
embedding_size: model.embedding_size(),
max_input_tokens: model.max_input_token(),
max_image_dimensions: model.max_image_dimensions(),
description: model.model_description(),
}
}
}
Loading

0 comments on commit 71140ee

Please sign in to comment.