Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support postpretrain in client #258

Merged
merged 7 commits into from
Feb 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions python/qianfan/common/client/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from prompt_toolkit import prompt
from prompt_toolkit.completion import WordCompleter
from rich import print as rprint
from rich.console import Console, Group, RenderableType
from rich.console import Group, RenderableType
from rich.live import Live
from rich.markdown import Markdown
from rich.spinner import Spinner
Expand All @@ -36,6 +36,7 @@
list_model_option,
print_warn_msg,
render_response_debug_info,
replace_logger_handler,
)
from qianfan.consts import DefaultLLMModel
from qianfan.errors import InternalError
Expand Down Expand Up @@ -83,7 +84,7 @@ def __init__(
QfMessages() for _ in range(len(self.clients))
]
self.multi_line = multi_line
self.console = Console()
self.console = replace_logger_handler()
self.thread_pool = ThreadPoolExecutor(max_workers=len(self.clients))
self.inference_args = kwargs
if len(self.clients) != 1 and len(self.inference_args) != 0:
Expand Down Expand Up @@ -388,7 +389,6 @@ def chat_entry(
"""
Chat with the LLM in the terminal.
"""
qianfan.disable_log()
if model is None and endpoint is None:
model = DefaultLLMModel.ChatCompletion

Expand Down
4 changes: 2 additions & 2 deletions python/qianfan/common/client/completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@

import prompt_toolkit
import typer
from rich.console import Console
from rich.markdown import Markdown

import qianfan
Expand All @@ -28,6 +27,7 @@
print_error_msg,
print_info_msg,
render_response_debug_info,
replace_logger_handler,
)
from qianfan.consts import DefaultLLMModel

Expand All @@ -51,7 +51,7 @@ def __init__(
self.model = model
self.endpoint = endpoint
self.plain = plain
self.console = Console(no_color=self.plain)
self.console = replace_logger_handler(no_color=self.plain)
self.debug = debug
self.inference_args = kwargs

Expand Down
9 changes: 5 additions & 4 deletions python/qianfan/common/client/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from typing import Any, List, Optional

import typer
from rich.console import Console, Group
from rich.console import Group
from rich.pretty import Pretty
from rich.rule import Rule
from rich.table import Table
Expand All @@ -31,6 +31,7 @@
print_error_msg,
print_info_msg,
print_success_msg,
replace_logger_handler,
timestamp,
)
from qianfan.consts import DefaultLLMModel
Expand Down Expand Up @@ -124,7 +125,7 @@ def save(
),
) -> None:
"""Save dataset to platform or local file."""
console = Console()
console = replace_logger_handler()
with console.status("Loading dataset..."):
src_dataset = load_dataset(src)

Expand Down Expand Up @@ -293,7 +294,7 @@ def view(
"""
View the content of the dataset.
"""
console = Console()
console = replace_logger_handler()
if extract_id_from_path(dataset) is not None:
check_credential()
with console.status("Loading dataset..."):
Expand Down Expand Up @@ -407,7 +408,7 @@ def predict(
) -> None:
"""Predict the dataset using a model and save to local file."""
input_column_list = input_columns.split(",")
console = Console()
console = replace_logger_handler()
with console.status("Loading dataset..."):
ds = load_dataset(
dataset, input_columns=input_column_list, reference_column=reference_column
Expand Down
7 changes: 4 additions & 3 deletions python/qianfan/common/client/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from typing import List, Optional, Set

import typer
from rich.console import Console, RenderableType
from rich.console import RenderableType
from rich.pretty import Pretty
from rich.table import Table

Expand All @@ -26,6 +26,7 @@
print_error_msg,
print_info_msg,
print_warn_msg,
replace_logger_handler,
)
from qianfan.errors import InternalError
from qianfan.evaluation import EvaluationManager
Expand Down Expand Up @@ -64,7 +65,7 @@ def list_evaluable_models(
"""
if value:
model_list = ModelResource.evaluable_model_list()["result"]
console = Console()
console = replace_logger_handler()
table = Table(show_lines=True)
col_list = ["Model Name", "Platform Preset", "Train Type", "Model Version List"]
for col in col_list:
Expand Down Expand Up @@ -165,7 +166,7 @@ def run(
"""
ds = load_dataset(dataset_id, is_download_to_local=False)
model_list = [Model(version_id=m) for m in models]
console = Console()
console = replace_logger_handler()
evaluators: List[QianfanEvaluator] = []
if enable_rule_evaluator:
evaluators.append(
Expand Down
7 changes: 7 additions & 0 deletions python/qianfan/common/client/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from typing import Optional

import click
import typer
from typer.completion import completion_init, install_callback, show_callback

Expand Down Expand Up @@ -159,12 +160,18 @@ def entry(
" customize the installation."
),
),
log_level: str = typer.Option(
"WARN",
help="Set log level.",
click_type=click.Choice(["DEBUG", "INFO", "WARN", "ERROR"]),
),
) -> None:
"""
Qianfan CLI which provides access to various Qianfan services.
"""
global _enable_traceback
_enable_traceback = enable_traceback
qianfan.enable_log(log_level)


if __name__ == "__main__":
Expand Down
6 changes: 3 additions & 3 deletions python/qianfan/common/client/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from prompt_toolkit.document import Document
from prompt_toolkit.validation import ValidationError
from rich import print as rprint
from rich.console import Console, Group, RenderableType
from rich.console import Group, RenderableType
from rich.live import Live
from rich.markdown import Markdown
from rich.spinner import Spinner
Expand All @@ -36,6 +36,7 @@
credential_required,
print_error_msg,
render_response_debug_info,
replace_logger_handler,
)
from qianfan.resources.typing import QfMessages
from qianfan.utils.bos_uploader import BosHelper, parse_bos_path
Expand Down Expand Up @@ -105,7 +106,7 @@ def __init__(
self.client = qianfan.Plugin(endpoint=endpoint)
self.msg_history = QfMessages()
self.multi_line = multi_line
self.console = Console()
self.console = replace_logger_handler()
self.inference_args = kwargs
self.bos_path = bos_path
self.plugins = plugins
Expand Down Expand Up @@ -331,7 +332,6 @@ def plugin_entry(
"""
Chat with the LLM with plugins in the terminal.
"""
qianfan.disable_log()
model = None

extra_args = {}
Expand Down
Loading
Loading