From 1be26527085dc7d66fea8b2851a64dfe72207392 Mon Sep 17 00:00:00 2001
From: Siva
Date: Thu, 30 Jan 2025 15:49:38 +0530
Subject: [PATCH 1/2] [CPP_CLI] MLC Cli App over JSONEngine interface

A comprehensive C++ CLI built over the existing JSONFFIEngine interface.
Intended for use in environments with only CLI access, such as Android ADB
shells.
---
 CMakeLists.txt                    |   3 +
 apps/mlc_cli_chat/CMakeLists.txt  |  31 +++
 apps/mlc_cli_chat/README.md       |   3 +
 apps/mlc_cli_chat/base.h          |  18 ++
 apps/mlc_cli_chat/chat_state.cc   | 110 +++++++++
 apps/mlc_cli_chat/chat_state.h    |  30 +++
 apps/mlc_cli_chat/engine.cc       | 368 ++++++++++++++++++++++++++++++
 apps/mlc_cli_chat/engine.h        | 100 ++++++++
 apps/mlc_cli_chat/mlc_cli_chat.cc | 114 +++++++++
 cpp/serve/config.h                |   3 +-
 10 files changed, 779 insertions(+), 1 deletion(-)
 create mode 100644 apps/mlc_cli_chat/CMakeLists.txt
 create mode 100644 apps/mlc_cli_chat/README.md
 create mode 100644 apps/mlc_cli_chat/base.h
 create mode 100644 apps/mlc_cli_chat/chat_state.cc
 create mode 100644 apps/mlc_cli_chat/chat_state.h
 create mode 100644 apps/mlc_cli_chat/engine.cc
 create mode 100644 apps/mlc_cli_chat/engine.h
 create mode 100644 apps/mlc_cli_chat/mlc_cli_chat.cc

diff --git a/CMakeLists.txt b/CMakeLists.txt
index a010a05192..9d6cde7821 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -158,6 +158,8 @@ if(NOT CARGO_EXECUTABLE)
   message(FATAL_ERROR "Cargo is not found! Please install cargo.")
 endif()
 
+add_subdirectory(apps/mlc_cli_chat)
+
 # when this option is on, we install all static lib deps into lib
 if(MLC_LLM_INSTALL_STATIC_LIB)
   install(TARGETS mlc_llm_static tokenizers_cpp sentencepiece-static tvm_runtime
@@ -178,6 +180,7 @@ else()
           mlc_llm_static
           tokenizers_cpp
           sentencepiece-static
+          mlc_cli_chat
           RUNTIME_DEPENDENCY_SET
           tokenizers_c
           RUNTIME DESTINATION bin
diff --git a/apps/mlc_cli_chat/CMakeLists.txt b/apps/mlc_cli_chat/CMakeLists.txt
new file mode 100644
index 0000000000..f0b659fc32
--- /dev/null
+++ b/apps/mlc_cli_chat/CMakeLists.txt
@@ -0,0 +1,31 @@
+cmake_policy(SET CMP0069 NEW)  # suppress cmake warning about IPO
+
+set(MLC_CLI_SOURCES
+  mlc_cli_chat.cc
+  chat_state.cc
+  engine.cc
+)
+set(MLC_CLI_LINKER_LIBS "")
+
+set(
+  MLC_CLI_CHAT_INCLUDES
+  ../../3rdparty/tvm/include
+  ../../3rdparty/tvm/3rdparty/dlpack/include
+  ../../3rdparty/tvm/3rdparty/dmlc-core/include
+  ../../3rdparty/tvm/3rdparty/picojson
+  ../../3rdparty/tokenizers-cpp/include
+  ../../3rdparty/xgrammar/include
+)
+
+add_executable(mlc_cli_chat ${MLC_CLI_SOURCES})
+target_include_directories(mlc_cli_chat PUBLIC ${CMAKE_CURRENT_SOURCE_DIR} ${MLC_CLI_CHAT_INCLUDES} ${PROJECT_SOURCE_DIR}/cpp)
+target_link_libraries(mlc_cli_chat PUBLIC mlc_llm ${TVM_RUNTIME_LINKER_LIBS})
+
+if(USE_CUDA)
+  include(../../3rdparty/tvm/cmake/utils/Utils.cmake)
+  include(../../3rdparty/tvm/cmake/utils/FindCUDA.cmake)
+  find_cuda(${USE_CUDA} ${USE_CUDNN})
+  target_link_libraries(mlc_cli_chat PUBLIC ${CUDA_NVRTC_LIBRARY})
+  target_link_libraries(mlc_cli_chat PUBLIC ${CUDA_CUDART_LIBRARY})
+  target_link_libraries(mlc_cli_chat PUBLIC ${CUDA_CUDA_LIBRARY})
+endif()
diff --git a/apps/mlc_cli_chat/README.md b/apps/mlc_cli_chat/README.md
new file mode 100644
index 0000000000..e613a1770e
--- /dev/null
+++ b/apps/mlc_cli_chat/README.md
@@ -0,0 +1,3 @@
+# MLC Chat CLI Application
+
+A native application that can load and run MLC models from the command line.
diff --git a/apps/mlc_cli_chat/base.h b/apps/mlc_cli_chat/base.h
new file mode 100644
index 0000000000..bc71e492b9
--- /dev/null
+++ b/apps/mlc_cli_chat/base.h
@@ -0,0 +1,18 @@
+/*!
+ * Copyright (c) 2023-2025 by Contributors
+ * \file base.h
+ */
+
+#ifndef MLC_CLI_CHAT_BASE_H
+#define MLC_CLI_CHAT_BASE_H
+
+#include <string>
+
+#include <unordered_map>
+#include <vector>
+
+struct Message {
+  std::unordered_map<std::string, std::string> content;
+};
+
+#endif
diff --git a/apps/mlc_cli_chat/chat_state.cc b/apps/mlc_cli_chat/chat_state.cc
new file mode 100644
index 0000000000..97923139fb
--- /dev/null
+++ b/apps/mlc_cli_chat/chat_state.cc
@@ -0,0 +1,110 @@
+/*!
+ * Copyright (c) 2023 by Contributors
+ * \file chat_state.cc
+ */
+
+#include "chat_state.h"
+
+#include <iostream>
+
+#include "base.h"
+#include "engine.h"
+
+void print_help_str() {
+  std::string help_string = R"(You can use the following special commands:
+  /help    print the special commands
+  /exit    quit the cli
+  /stats   print out stats of the last request (token/sec)
+  Multi-line input: Use escape+enter to start a new line.
+)";
+
+  std::cout << help_string << std::endl;
+}
+
+ChatState::ChatState(std::string model_path, std::string model_lib_path, std::string mode,
+                     std::string device, int device_id) {
+  history_window_begin = 0;
+  __json_wrapper =
+      std::make_shared<JSONFFIEngineWrapper>(model_path, model_lib_path, mode, device, device_id);
+}
+
+void ChatState::slide_history() {
+  size_t history_window_size = history.size() - history_window_begin;
+  history_window_begin += ((history_window_size + 3) / 4) * 2;
+}
+
+std::vector<Message> ChatState::get_current_history_window() {
+  return std::vector<Message>(history.begin() + history_window_begin, history.end());
+}
+
+int ChatState::generate(const std::string& prompt) {
+  // Tracks whether this request stopped because of the context-length limit
+  bool finish_reason_length = false;
+
+  // User message
+  Message new_message;
+  new_message.content["role"] = "user";
+  new_message.content["content"] = prompt;
+  history.push_back(new_message);
+
+  auto curr_window = get_current_history_window();
+
+  std::string output_text{""};
+
+  output_text = (*__json_wrapper).chat.completions.create(curr_window);
+
+  if (__json_wrapper->engine_state->finish_reason == "length") {
+    finish_reason_length = true;
+  }
+
+  if (finish_reason_length) {
+    std::cout << "\n[output truncated due to context length limit...]";
+  }
+
+  Message assistant_response;
+  assistant_response.content["role"] = "assistant";
+
+  picojson::value val(output_text);
+
+  std::string output_json_str = val.serialize();
+
+  assistant_response.content["content"] = output_json_str;
+  history.push_back(assistant_response);
+
+  if (finish_reason_length) {
+    slide_history();
+  }
+  return 0;
+}
+
+void ChatState::reset() {
+  history.clear();
+  history_window_begin = 0;
+}
+
+int ChatState::chat(std::string prompt) {
+  print_help_str();
+  // If a prompt was given on the command line, run a single session
+  if (!prompt.empty()) {
+    int ret = generate(prompt);
+    __json_wrapper->background_loops->terminate();
+    this->__json_wrapper->engine_state->getStats();
+    return ret;
+  }
+  std::string cin_prompt;
+  while (true) {
+    std::cout << ">>> ";
+    std::getline(std::cin, cin_prompt);
+    if (std::cin.eof() || cin_prompt == "/exit") {
+      __json_wrapper->background_loops->terminate();
+      break;
+    } else if (cin_prompt == "/help") {
+      print_help_str();
+    } else if (cin_prompt == "/stats") {
+      this->__json_wrapper->engine_state->getStats();
+    } else {
+      generate(cin_prompt);
+    }
+  }
+  return 0;
+}
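The trimming arithmetic in ChatState::slide_history above is terse. The following standalone sketch is not part of the patch (the ten-message history is made up for illustration); it only shows the effect of the same expression when a truncated reply forces the window to slide:

#include <cstddef>
#include <iostream>
#include <vector>

int main() {
  // Stand-in history of 10 messages (5 user/assistant pairs).
  std::vector<int> history(10);
  std::size_t history_window_begin = 0;

  // Same arithmetic as ChatState::slide_history: advance the window start by
  // about half of the current window, rounded up to a whole message pair.
  std::size_t window = history.size() - history_window_begin;  // 10
  history_window_begin += ((window + 3) / 4) * 2;              // advances by 6
  std::cout << "window keeps the last "
            << history.size() - history_window_begin
            << " messages" << std::endl;                       // prints 4
  return 0;
}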
diff --git a/apps/mlc_cli_chat/chat_state.h b/apps/mlc_cli_chat/chat_state.h
new file mode 100644
index 0000000000..370226ba40
--- /dev/null
+++ b/apps/mlc_cli_chat/chat_state.h
@@ -0,0 +1,30 @@
+/*!
+ * Copyright (c) 2023-2025 by Contributors
+ * \file chat_state.h
+ */
+
+#ifndef MLC_CLI_CHAT_CHAT_STATE_H
+#define MLC_CLI_CHAT_CHAT_STATE_H
+
+#include "base.h"
+#include "engine.h"
+
+void print_help_str();
+
+class ChatState {
+ public:
+  std::vector<Message> history;
+  size_t history_window_begin;
+  std::shared_ptr<JSONFFIEngineWrapper> __json_wrapper;
+
+  ChatState(std::string model_path, std::string model_lib_path, std::string mode,
+            std::string device, int device_id = 0);
+
+  void slide_history();
+  std::vector<Message> get_current_history_window();
+  int generate(const std::string& prompt);
+  void reset();
+  int chat(std::string prompt = "");
+};
+
+#endif
diff --git a/apps/mlc_cli_chat/engine.cc b/apps/mlc_cli_chat/engine.cc
new file mode 100644
index 0000000000..6f093d06ff
--- /dev/null
+++ b/apps/mlc_cli_chat/engine.cc
@@ -0,0 +1,368 @@
+/*!
+ * Copyright (c) 2023-2025 by Contributors
+ * \file engine.cc
+ */
+
+#include "engine.h"
+
+#include <algorithm>
+#include <cstdlib>
+#include <fstream>
+#include <utility>
+
+#include "base.h"
+
+/// Helper function to get the JSON form of the message list
+std::string messagesToString(const std::vector<Message>& messages) {
+  std::string result{""};
+  for (size_t i = 0; i < messages.size(); ++i) {
+    const auto& msg = messages[i];
+    result += "{";
+
+    bool firstItem = true;
+    for (const auto& [key, value] : msg.content) {
+      if (!firstItem) {
+        result += ",";
+      }
+      result += "\"" + key + "\":";
+
+      if (key == "role") {
+        if (value == "user") {
+          result += "\"" + value + "\"";
+        } else if (value == "assistant") {
+          result += "\"" + value + "\"";
+        }
+      } else if (key == "content") {
+        if (i % 2 == 0) {
+          // User turns are raw text and need quoting
+          result += "\"" + value + "\"";
+        } else {
+          // Assistant turns are already serialized JSON strings
+          result += value;
+        }
+      }
+      firstItem = false;
+    }
+    result += "}";
+
+    if (i != messages.size() - 1) {
+      result += ",";
+    }
+  }
+  return result;
+}
+
+// Helper function to print the history
+void printHistory(std::vector<Message> history) {
+  for (size_t i = 0; i < history.size(); i++) {
+    auto msg = history[i];
+    std::cout << " content " << msg.content["content"];
+  }
+  std::cout << "\n";
+}
+
+EngineStateCli::EngineStateCli()
+    : queue_cv(std::make_shared<std::condition_variable>()),
+      queue_mutex(std::make_shared<std::mutex>()) {}
+
+std::function<void(const std::string&)> EngineStateCli::get_request_stream_callback() {
+  return [this](const std::string& response) -> void {
+    {
+      std::lock_guard<std::mutex> lock(*queue_mutex);
+      this->sync_queue.push(response);
+      queue_cv->notify_one();
+    }
+  };
+}
+
+std::string EngineStateCli::handle_chat_completion(tvm::runtime::Module mod,
+                                                   const std::string& request_json, bool include_usage,
+                                                   const std::string& request_id) {
+  // Clear the queue, making sure that it is empty before a new request
+  // Not strictly required since the previous request should have drained it
+  {
+    std::lock_guard<std::mutex> lock(*queue_mutex);
+    std::queue<std::string> empty;
+    std::swap(sync_queue, empty);
+  }
+
+  // TVM global function which generates the responses
+  bool success = mod.GetFunction("chat_completion")(request_json, request_id);
+
+  if (!success) {
+    std::cerr << "Failed to start chat completion" << std::endl;
+  }
+
+  try {
+    last_chunk_arrived = false;
+
+    // Clear the output before every chat completion
+    output = "";
+
+    while (!last_chunk_arrived) {
+      std::unique_lock<std::mutex> lock(*queue_mutex);
+
+      // Wait until the queue is not empty
+      queue_cv->wait(lock, [this] { return !sync_queue.empty(); });
+      std::string response = sync_queue.front();
+      sync_queue.pop();
+
+      picojson::value v;
+
+      // Parse the JSON
+      std::string err = picojson::parse(v, response);
+
+      // Check for errors
+      if (!err.empty()) {
+        std::cerr << "JSON parsing error: " << err << std::endl;
+      }
+
+      // Parsing successful, navigate through the array
+      picojson::array& arr = v.get<picojson::array>();
+      for (auto& item : arr) {
+        picojson::object& obj = item.get<picojson::object>();
+
+        // Extract 'delta' content if available
+        if (obj.find("choices") != obj.end() && !obj["choices"].get<picojson::array>().empty()) {
+          picojson::object& choices =
+              obj["choices"].get<picojson::array>()[0].get<picojson::object>();
+
+          if (!(choices["finish_reason"].is<picojson::null>())) {
+            // Record the finish reason (e.g. "stop" or "length") for the caller
+            this->finish_reason = choices["finish_reason"].get<std::string>();
+          }
+          if (choices.find("delta") != choices.end()) {
+            picojson::object& delta = choices["delta"].get<picojson::object>();
+            if (delta.find("content") != delta.end()) {
+              std::string content = delta["content"].get<std::string>();
+
+              std::cout << content << std::flush;
+              output += content;
+            }
+          }
+        }
+
+        // Extract 'usage' details if available
+        if (obj.find("usage") != obj.end()) {
+          last_chunk_arrived = true;
+          std::cout << std::endl;
+          picojson::object& usage = obj["usage"].get<picojson::object>();
+
+          // Access the 'usage' details
+          double prompt_tokens = usage["prompt_tokens"].get<double>();
+          double completion_tokens = usage["completion_tokens"].get<double>();
+          double total_tokens = usage["total_tokens"].get<double>();
+
+          // Access the 'extra' details
+          picojson::object& extra = usage["extra"].get<picojson::object>();
+          double prefill_tokens_per_s = extra["prefill_tokens_per_s"].get<double>();
+          double decode_tokens_per_s = extra["decode_tokens_per_s"].get<double>();
+          double end_to_end_latency_s = extra["end_to_end_latency_s"].get<double>();
+
+          // Fill in the stats details for /stats
+          this->decode_tokens_per_s = decode_tokens_per_s;
+          this->prefill_tokens_per_s = prefill_tokens_per_s;
+          this->prompt_tokens = prompt_tokens;
+          this->completion_tokens = completion_tokens;
+        }
+      }
+    }
+  } catch (const std::exception& exception) {
+    mod.GetFunction("abort")(request_id);
+    throw;
+  }
+  return output;
+}
+
+void EngineStateCli::getStats() {
+  std::cout << " decode : " << this->decode_tokens_per_s << " tok/sec (" << this->completion_tokens
+            << " tokens in " << this->completion_tokens / this->decode_tokens_per_s << " sec)"
+            << ", prefill : " << this->prefill_tokens_per_s << " tok/sec (" << this->prompt_tokens
+            << " tokens in " << this->prompt_tokens / this->prefill_tokens_per_s << " sec)"
+            << std::endl;
+}
+
+// Default constructor
+BackgroundLoops::BackgroundLoops() { terminated = false; }
+
+// Parametrized constructor
+BackgroundLoops::BackgroundLoops(tvm::runtime::Module mod) {
+  this->__mod = mod;
+  auto background_loop = mod.GetFunction("run_background_loop");
+  auto background_stream_back_loop = mod.GetFunction("run_background_stream_back_loop");
+
+  background_loop_thread = std::thread(background_loop);
+  background_stream_back_loop_thread = std::thread(background_stream_back_loop);
+
+  terminated = false;
+}
+
+BackgroundLoops::~BackgroundLoops() { terminate(); }
+
+void BackgroundLoops::terminate() {
+  if (!terminated) {
+    terminated = true;
+
+    try {
+      __mod.GetFunction("exit_background_loop")();
+    } catch (const std::exception& e) {
+      std::cerr << "Error calling exit_background_loop: " << e.what() << std::endl;
+    }
+
+    if (background_loop_thread.joinable()) {
+      background_loop_thread.join();
+    }
+    if (background_stream_back_loop_thread.joinable()) {
+      background_stream_back_loop_thread.join();
+    }
+  }
+}
+
+// Default constructor
+Completions::Completions() {}
+
+Completions::Completions(std::shared_ptr<EngineStateCli> engine_state, tvm::runtime::Module mod) {
+  this->engine_state = engine_state;
+  this->__mod = mod;
+}
+
+// Method to generate a unique string for each request
+inline std::string Completions::GenerateUUID(size_t length) {
+  auto randchar = []() -> char {
+    const char charset[] =
+        "0123456789"
+        "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
+        "abcdefghijklmnopqrstuvwxyz";
+    const size_t max_index = (sizeof(charset) - 1);
+    return charset[rand() % max_index];
+  };
+  std::string str(length, 0);
+  std::generate_n(str.begin(), length, randchar);
+  return str;
+}
+
+std::string Completions::create(std::vector<Message>& messages) {
+  std::string request_id{""};
+  // Random suffix for this request
+  std::string generate_random_string{GenerateUUID(16)};
+
+  // Unique ID for each chat completion request
+  request_id = "chatcmpl-" + generate_random_string;
+
+  std::string prompt = messagesToString(messages);
+  std::string jsonStart = R"({"messages":[)";
+  std::string jsonEnd = "]}";
+
+  std::string request_str = jsonStart + prompt + jsonEnd;
+
+  std::string output_res =
+      engine_state->handle_chat_completion(__mod, request_str, true, request_id);
+  return output_res;
+}
+
+Chat::Chat() {}
+
+Chat::Chat(std::shared_ptr<EngineStateCli> engine_state, tvm::runtime::Module mod) {
+  this->completions = Completions(engine_state, mod);
+}
+
+// Map a device string to the corresponding DLDeviceType
+DLDeviceType GetDevice(std::string device) {
+  if ("cuda" == device) {
+    return kDLCUDA;
+  } else if ("cpu" == device || "llvm" == device) {
+    return kDLCPU;
+  } else if ("opencl" == device) {
+    return kDLOpenCL;
+  } else if ("vulkan" == device) {
+    return kDLVulkan;
+  } else if ("metal" == device) {
+    return kDLMetal;
+  } else {
+    LOG(FATAL) << "Unsupported device: " << device;
+  }
+}
+
+// Default constructor with no arguments
+JSONFFIEngineWrapper::JSONFFIEngineWrapper() {}
+
+JSONFFIEngineWrapper::JSONFFIEngineWrapper(std::string model_path, std::string model_lib_path,
+                                           std::string mode, std::string device,
+                                           int device_id = 0) {
+  // Create an instance of EngineStateCli
+  this->engine_state = std::make_shared<EngineStateCli>();
+
+  auto engine = tvm::runtime::Registry::Get("mlc.json_ffi.CreateJSONFFIEngine");
+  if (engine == nullptr) {
+    std::cout << "\nError: Unable to access TVM global registry mlc.json_ffi.CreateJSONFFIEngine"
+              << std::endl;
+  }
+
+  tvm::runtime::Module module_tvm = (*engine)();
+
+  this->mod = module_tvm;
+
+  // The background loops run over the same module
+  background_loops = std::make_shared<BackgroundLoops>(mod);
+
+  this->engine_config = std::make_shared<EngineConfig>(make_object<EngineConfigNode>());
+  (*engine_config)->model = model_path;
+  (*engine_config)->model_lib = model_lib_path;
+  (*engine_config)->verbose = false;
+
+  if (mode == "interactive") {
+    (*engine_config)->mode = EngineMode::kInteractive;
+  } else if (mode == "local") {
+    (*engine_config)->mode = EngineMode::kLocal;
+  } else if (mode == "server") {
+    (*engine_config)->mode = EngineMode::kServer;
+  }
+
+  const std::string file_path = model_path + "/mlc-chat-config.json";
+  std::ifstream file(file_path);
+  if (!file.is_open()) {
+    std::cerr << "Error: Unable to open " << file_path << std::endl;
+  }
+
+  std::string config_content((std::istreambuf_iterator<char>(file)),
+                             std::istreambuf_iterator<char>());
+
+  // Parse the JSON object
+  picojson::value config_object;
+  std::string err;
+  picojson::parse(config_object, config_content.begin(), config_content.end(), &err);
+  if (!err.empty()) {
+    std::cerr << "Error: Unable to parse the JSON object: " << err << std::endl;
+  }
+
+  // Access the parsed data
+  if (config_object.is<picojson::object>()) {
+    const picojson::object& model_config = config_object.get<picojson::object>();
+    if (model_config.find("prefill_chunk_size") != model_config.end()) {
+      double prefill_chunk_size = model_config.at("prefill_chunk_size").get<double>();
+      (*engine_config)->prefill_chunk_size = prefill_chunk_size;
+    } else {
+      std::cerr << "Error: 'prefill_chunk_size' not found in the JSON object" << std::endl;
+    }
+  } else {
+    std::cerr << "Error: Invalid JSON format" << std::endl;
+  }
+
+  auto call_back = engine_state->get_request_stream_callback();
+
+  // Wrap the callback as a TVM typed packed function
+  auto tvm_callback = tvm::runtime::TypedPackedFunc<void(const std::string&)>(call_back);
+
+  // Initialise the background engine
+  mod.GetFunction("init_background_engine")(static_cast<int>(GetDevice(device)), device_id,
+                                            tvm_callback);
+
+  std::string engine_config_json_str{(*engine_config)->AsJSONString()};
+
+  // Reload the engine with the serialized EngineConfig (JSONFFIEngineImpl)
+  mod.GetFunction("reload")(engine_config_json_str);
+
+  chat = Chat(engine_state, mod);
+}
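The heart of EngineStateCli is the hand-off between the engine's callback thread and the caller blocked inside handle_chat_completion. The standalone sketch below is not part of the patch; it assumes no TVM engine, uses plain strings as chunks, and uses a made-up "[DONE]" marker in place of the final usage chunk, only to show the same queue plus condition-variable pattern in isolation:

#include <condition_variable>
#include <iostream>
#include <mutex>
#include <queue>
#include <string>
#include <thread>

int main() {
  std::queue<std::string> sync_queue;
  std::mutex queue_mutex;
  std::condition_variable queue_cv;

  // Producer: stands in for the request_stream_callback that the engine
  // invokes from its background thread with each streamed response chunk.
  std::thread producer([&] {
    for (const char* chunk : {"Hello", ", ", "world", "[DONE]"}) {
      {
        std::lock_guard<std::mutex> lock(queue_mutex);
        sync_queue.push(chunk);
      }
      queue_cv.notify_one();
    }
  });

  // Consumer: stands in for the loop in handle_chat_completion, which blocks
  // until a chunk arrives and stops once the final chunk is seen.
  while (true) {
    std::string chunk;
    {
      std::unique_lock<std::mutex> lock(queue_mutex);
      queue_cv.wait(lock, [&] { return !sync_queue.empty(); });
      chunk = sync_queue.front();
      sync_queue.pop();
    }
    if (chunk == "[DONE]") break;
    std::cout << chunk << std::flush;
  }
  std::cout << std::endl;

  producer.join();
  return 0;
}

Note that both sides take the mutex before touching the queue; the patch's callback does the same, which is why the lock around sync_queue.push matters.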
diff --git a/apps/mlc_cli_chat/engine.h b/apps/mlc_cli_chat/engine.h
new file mode 100644
index 0000000000..79db32e7f0
--- /dev/null
+++ b/apps/mlc_cli_chat/engine.h
@@ -0,0 +1,100 @@
+/*!
+ * Copyright (c) 2023-2025 by Contributors
+ * \file engine.h
+ */
+
+#ifndef MLC_CLI_CHAT_ENGINE_H
+#define MLC_CLI_CHAT_ENGINE_H
+
+#include <picojson.h>
+#include <serve/config.h>
+#include <tvm/runtime/module.h>
+#include <tvm/runtime/registry.h>
+
+#include <condition_variable>
+#include <functional>
+#include <iostream>
+#include <memory>
+#include <mutex>
+#include <queue>
+#include <thread>
+
+#include "base.h"
+
+using namespace mlc::llm::serve;
+
+class EngineStateCli {
+ public:
+  std::queue<std::string> sync_queue;
+  bool last_chunk_arrived = false;
+  std::string output;
+  std::string finish_reason;
+  std::shared_ptr<std::mutex> queue_mutex;
+  std::shared_ptr<std::condition_variable> queue_cv;
+  double decode_tokens_per_s;
+  double prefill_tokens_per_s;
+  double prompt_tokens;
+  double completion_tokens;
+
+  EngineStateCli();
+  std::function<void(const std::string&)> get_request_stream_callback();
+  std::string handle_chat_completion(tvm::runtime::Module mod, const std::string& request_json,
+                                     bool include_usage, const std::string& request_id);
+  void getStats();
+};
+
+class Completions {
+ public:
+  std::shared_ptr<EngineStateCli> engine_state;
+  tvm::runtime::Module __mod;
+
+  Completions();
+
+  Completions(std::shared_ptr<EngineStateCli> engine_state, tvm::runtime::Module mod);
+
+  inline std::string GenerateUUID(size_t length);
+
+  std::string create(std::vector<Message>& messages);
+};
+
+class Chat {
+ public:
+  Completions completions;
+
+  Chat();
+
+  Chat(std::shared_ptr<EngineStateCli> engine_state, tvm::runtime::Module mod);
+};
+
+class BackgroundLoops {
+ private:
+  // Background worker threads
+  std::thread background_loop_thread;
+  std::thread background_stream_back_loop_thread;
+  bool terminated = false;
+  tvm::runtime::Module __mod;
+
+ public:
+  // Default constructor
+  BackgroundLoops();
+
+  // Parametrized constructor
+  BackgroundLoops(tvm::runtime::Module mod);
+  ~BackgroundLoops();
+
+  void terminate();
+};
+
+class JSONFFIEngineWrapper {
+ public:
+  Chat chat;
+  std::shared_ptr<EngineConfig> engine_config;
+  tvm::runtime::Module mod;
+  std::shared_ptr<EngineStateCli> engine_state;
+  std::shared_ptr<BackgroundLoops> background_loops;
+
+  JSONFFIEngineWrapper();
+
+  JSONFFIEngineWrapper(std::string model_path, std::string model_lib_path, std::string mode,
+                       std::string device, int device_id);
+};
+
+#endif
diff --git a/apps/mlc_cli_chat/mlc_cli_chat.cc b/apps/mlc_cli_chat/mlc_cli_chat.cc
new file mode 100644
index 0000000000..6864136425
--- /dev/null
+++ b/apps/mlc_cli_chat/mlc_cli_chat.cc
@@ -0,0 +1,114 @@
+/*!
+ * Copyright (c) 2023-2025 by Contributors
+ * \file mlc_cli_chat.cc
+ */
+
+#include <cstdlib>
+#include <iostream>
+#include <stdexcept>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "chat_state.h"
+#include "engine.h"
+
+struct Args {
+  std::string model;
+  std::string model_lib_path;
+  std::string device = "auto";
+  bool evaluate = false;
+  int eval_prompt_len = 128;
+  int eval_gen_len = 1024;
+  std::string prompt;
+};
+
+// Help prompt
+void printHelp() {
+  std::cout
+      << "MLCChat CLI is the command line tool to run MLC-compiled LLMs out of the box.\n"
+      << "Note: the --model argument is required. It can either be the model name with its "
+      << "quantization scheme or a full path to the model folder. In the former case, the "
+      << "provided name will be used to search for the model folder over possible paths. "
+      << "The --model-lib argument is optional. If unspecified, the --model argument will be "
+      << "used to search for the library file over possible paths.\n\n"
+      << "Usage: mlc_cli_chat [options]\n"
+      << "Options:\n"
+      << "  --model        [required] the model to use\n"
+      << "  --model-lib    [optional] the full path to the model library file to use\n"
+      << "  --device       (default: auto)\n"
+      << "  --with-prompt  [optional] run one session with the given prompt\n"
+      << "  --help         [optional] Tool usage information\n"
+      ;
+}
+
+// Method to parse the args
+Args parseArgs(int argc, char* argv[]) {
+  Args args;
+
+  // Take the arguments after the executable name
+  std::vector<std::string> arguments(argv + 1, argv + argc);
+
+  for (size_t i = 0; i < arguments.size(); ++i) {
+    if (arguments[i] == "--model" && i + 1 < arguments.size()) {
+      args.model = arguments[++i];
+    } else if (arguments[i] == "--model-lib" && i + 1 < arguments.size()) {
+      args.model_lib_path = arguments[++i];
+    } else if (arguments[i] == "--device" && i + 1 < arguments.size()) {
+      args.device = arguments[++i];
+    } else if (arguments[i] == "--evaluate") {
+      args.evaluate = true;
+    } else if (arguments[i] == "--eval-prompt-len" && i + 1 < arguments.size()) {
+      args.eval_prompt_len = std::stoi(arguments[++i]);
+    } else if (arguments[i] == "--eval-gen-len" && i + 1 < arguments.size()) {
+      args.eval_gen_len = std::stoi(arguments[++i]);
+    } else if (arguments[i] == "--with-prompt" && i + 1 < arguments.size()) {
+      args.prompt = arguments[++i];
+    } else if (arguments[i] == "--help") {
+      printHelp();
+      exit(0);
+    } else {
+      printHelp();
+      throw std::runtime_error("Unknown or incomplete argument: " + arguments[i]);
+    }
+  }
+
+  if (args.model.empty()) {
+    printHelp();
+    throw std::runtime_error("Invalid arguments");
+  }
+
+  return args;
+}
+
+// Method to detect the device
+std::pair<std::string, int> DetectDevice(std::string device) {
+  std::string device_name;
+  int device_id;
+  size_t delimiter_pos = device.find(":");
+
+  // "cuda:0" means the device name is cuda and the device id is 0
+  if (delimiter_pos == std::string::npos) {
+    device_name = device;
+    device_id = 0;
+  } else {
+    device_name = device.substr(0, delimiter_pos);
+    device_id = std::stoi(device.substr(delimiter_pos + 1, device.length()));
+  }
+  return {device_name, device_id};
+}
+
+int main(int argc, char* argv[]) {
+  Args args = parseArgs(argc, argv);
+
+  // model path
+  std::string model_path = args.model;
+
+  // model-lib path
+  std::string model_lib_path = args.model_lib_path;
+
+  // Get the device name and device id
+  auto [device_name, device_id] = DetectDevice(args.device);
+
+  // mode of interaction
+  std::string mode{"interactive"};
+
+  ChatState chat_state(model_path, model_lib_path, mode, device_name, device_id);
+
+  return chat_state.chat(args.prompt);
+}
diff --git a/cpp/serve/config.h b/cpp/serve/config.h
index 9da3ba2517..d60ed99843 100644
--- a/cpp/serve/config.h
+++ b/cpp/serve/config.h
@@ -12,6 +12,7 @@
 
 #include <picojson.h>
 
+#include "../base.h"
 #include "../metadata/model.h"
 #include "../support/result.h"
 
@@ -219,7 +220,7 @@ enum class PrefillMode : int {
 
 class InferrableEngineConfig;
 
 /*! \brief The configuration of engine execution config. */
-class EngineConfigNode : public Object {
+class MLC_LLM_DLL EngineConfigNode : public Object {
  public:
   /*************** Models ***************/

From 11da9ed2e63a07a329a084a96159a5347206777c Mon Sep 17 00:00:00 2001
From: Siva
Date: Thu, 30 Jan 2025 20:14:58 +0530
Subject: [PATCH 2/2] Lint

---
 apps/mlc_cli_chat/README.md       | 2 +-
 apps/mlc_cli_chat/engine.cc       | 5 +++--
 apps/mlc_cli_chat/mlc_cli_chat.cc | 3 +--
 3 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/apps/mlc_cli_chat/README.md b/apps/mlc_cli_chat/README.md
index e613a1770e..93110b9547 100644
--- a/apps/mlc_cli_chat/README.md
+++ b/apps/mlc_cli_chat/README.md
@@ -1,3 +1,3 @@
-# MLC Chat CLI Application
+#MLC Chat CLI Application
 
 A native application that can load and run MLC models from the command line.
diff --git a/apps/mlc_cli_chat/engine.cc b/apps/mlc_cli_chat/engine.cc
index 6f093d06ff..a162f78d38 100644
--- a/apps/mlc_cli_chat/engine.cc
+++ b/apps/mlc_cli_chat/engine.cc
@@ -70,8 +70,9 @@ std::function<void(const std::string&)> EngineStateCli::get_request_stream_callb
 }
 
 std::string EngineStateCli::handle_chat_completion(tvm::runtime::Module mod,
-                                                   const std::string& request_json, bool include_usage,
-                                                   const std::string& request_id) {
+                                                   const std::string& request_json,
+                                                   bool include_usage,
+                                                   const std::string& request_id) {
   // Clear the queue, making sure that it is empty before a new request
   // Not strictly required since the previous request should have drained it
   {
diff --git a/apps/mlc_cli_chat/mlc_cli_chat.cc b/apps/mlc_cli_chat/mlc_cli_chat.cc
index 6864136425..bfdfc0b5c6 100644
--- a/apps/mlc_cli_chat/mlc_cli_chat.cc
+++ b/apps/mlc_cli_chat/mlc_cli_chat.cc
@@ -33,8 +33,7 @@ void printHelp() {
       << "  --model-lib    [optional] the full path to the model library file to use\n"
       << "  --device       (default: auto)\n"
       << "  --with-prompt  [optional] run one session with the given prompt\n"
-      << "  --help         [optional] Tool usage information\n"
-      ;
+      << "  --help         [optional] Tool usage information\n";
 }
 
 // Method to parse the args
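A side note on the wire format: Completions::create assembles the request body by string concatenation, so unescaped quotes or newlines in a prompt would produce invalid JSON. The sketch below is not part of the patch; it builds the same {"messages":[...]} payload with picojson (already a dependency of the app) as an alternative that gets quoting and escaping for free:

#include <iostream>
#include <picojson.h>

int main() {
  // One user turn, equivalent to what messagesToString emits for it.
  picojson::object message;
  message["role"] = picojson::value("user");
  message["content"] = picojson::value("What is the meaning of life?");

  picojson::array messages;
  messages.push_back(picojson::value(message));

  // Top-level request object expected by the chat_completion function.
  picojson::object request;
  request["messages"] = picojson::value(messages);

  std::cout << picojson::value(request).serialize() << std::endl;
  // {"messages":[{"content":"What is the meaning of life?","role":"user"}]}
  return 0;
}

The resulting string could be passed to handle_chat_completion unchanged; the trade-off is one extra copy of the message data versus robustness against special characters in user input.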