From b097fff6824fbb478d2a38e974f2c92816b46739 Mon Sep 17 00:00:00 2001 From: Liu Jun Date: Thu, 1 Feb 2024 18:07:31 +0800 Subject: [PATCH 1/8] fix: render negative prompt with non-str variable (#239) * fix prompt render * revert format * enable negative prompt for non txt2image --- src/qianfan/common/prompt/prompt.py | 9 ++++----- src/qianfan/tests/prompt_class_test.py | 4 ++-- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/src/qianfan/common/prompt/prompt.py b/src/qianfan/common/prompt/prompt.py index e4374f2d..1b4c04da 100644 --- a/src/qianfan/common/prompt/prompt.py +++ b/src/qianfan/common/prompt/prompt.py @@ -290,17 +290,16 @@ def render(self, **kwargs: str) -> Tuple[str, Optional[str]]: raise InvalidArgumentError(f"variable `{v}` is not provided") prompt = prompt.replace(f"{left_id}{v}{right_id}", str(kwargs[v])) neg_prompt = None - if ( - self.scene_type == PromptSceneType.Text2Image - and self.negative_template is not None - ): + if self.negative_template is not None: if self.negative_variables is None: self.negative_variables = [] neg_prompt = self.negative_template for v in self.negative_variables: if v not in kwargs: raise InvalidArgumentError(f"variable `{v}` is not provided") - neg_prompt = neg_prompt.replace(f"{left_id}{v}{right_id}", kwargs[v]) + neg_prompt = neg_prompt.replace( + f"{left_id}{v}{right_id}", str(kwargs[v]) + ) return prompt, neg_prompt def delete(self) -> None: diff --git a/src/qianfan/tests/prompt_class_test.py b/src/qianfan/tests/prompt_class_test.py index 4da669f2..6ca41832 100644 --- a/src/qianfan/tests/prompt_class_test.py +++ b/src/qianfan/tests/prompt_class_test.py @@ -166,9 +166,9 @@ def test_render(): assert p.variables == ["v2", "v3"] assert p.render(v1="a", v2="3", v3="4") == ("{v1}3x 4", None) - p = Prompt(template="{v1} {v2}") + p = Prompt(template="{v1} {v2}", negative_template="{v1} {v3}") assert p.variables == ["v1", "v2"] - assert p.render(v1=1, v2={}) == ("1 {}", None) + assert p.render(v1=1, v2={}, v3=[]) == ("1 {}", "1 []") def test_delete(): From 155f75d85d91ca773532b182d242f07362d66533 Mon Sep 17 00:00:00 2001 From: NuODaniel Date: Thu, 1 Feb 2024 18:08:50 +0800 Subject: [PATCH 2/8] chore-sk-doc (#236) --- docs/semantic_kernel.md | 71 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 71 insertions(+) create mode 100644 docs/semantic_kernel.md diff --git a/docs/semantic_kernel.md b/docs/semantic_kernel.md new file mode 100644 index 00000000..0275e9de --- /dev/null +++ b/docs/semantic_kernel.md @@ -0,0 +1,71 @@ +# semantic_kernel + +千帆SDK支持了SK的接入适配,开发者可以方便的在SK中集成千帆平台的大模型调用能力: +- QianfanChatCompletion 文本对话 +- QianfanTextCompletion 文本续写 +- QianfanTextEmbedding 文本向量化 + +## 快速开始 +```python +from qianfan.extensions.semantic_kernel import ( + QianfanChatCompletion, + QianfanChatRequestSettings, +) +import asyncio + +TEST_MESSAGES = [{"role":"user", "content":"hi"}] + +async def run_chat(): + qf_chat = QianfanChatCompletion(model="ERNIE-Bot") + # call chat with messages + res = await qf_chat.complete_chat_async( + TEST_MESSAGES, + QianfanChatRequestSettings(temperature=0.95), + ) + print(res) + + async for r in qf_chat.complete_chat_stream_async( + TEST_MESSAGES, QianfanChatRequestSettings() + ): + print(r) + + # completion with ChatCompletion + res = await qf_chat.complete_async( + TEST_MESSAGES[-1]["content"], QianfanChatRequestSettings() + ) + print(res) + + # streaming completion with ChatCompletion + async for r in qf_chat.complete_stream_async( + TEST_MESSAGES[-1]["content"], QianfanChatRequestSettings() + ): + print(r) + + +asyncio.run(run_chat()) +``` + +## 结合Semantic Kernel框架 +除了直接调用QianfanChatCompletion类的成员函数,我们也可以结合SemanticKernel中的Kernel,Skill一起使用: + +```python +from qianfan.extensions.semantic_kernel import ( + QianfanChatCompletion, +) + +# with kernel +import semantic_kernel as sk + +kernel = sk.Kernel() +kernel.add_text_completion_service( + "qianfan_comp", QianfanChatCompletion(model="ERNIE-Bot"), +) + +prompt = """{{$input}} +生成一段关于以上主题的笑话 +""" + +joke = kernel.create_semantic_function(prompt_template=prompt, temperature=0.2, top_p=0.5) + +print(joke("尔滨")) +``` \ No newline at end of file From 27ffe958aca8aa4d49552bec9d997361c90f69fb Mon Sep 17 00:00:00 2001 From: Dobiichi-Origami <56953648+Dobiichi-Origami@users.noreply.github.com> Date: Fri, 2 Feb 2024 09:30:23 +0800 Subject: [PATCH 3/8] feat:Model list api (#245) * add new api * 1 * add doc * remove --------- Co-authored-by: root --- docs/cli.md | 1 + src/qianfan/common/client/evaluation.py | 38 ++++++++++- src/qianfan/consts.py | 1 + src/qianfan/resources/console/model.py | 28 ++++++-- src/qianfan/tests/model_test.py | 11 ++- src/qianfan/tests/utils/mock_server.py | 90 +++++++++++++++++++++++++ 6 files changed, 163 insertions(+), 6 deletions(-) diff --git a/docs/cli.md b/docs/cli.md index d4000d9c..80682c71 100644 --- a/docs/cli.md +++ b/docs/cli.md @@ -277,6 +277,7 @@ $ qianfan trainer run [OPTIONS] * `--train-type TEXT`:训练类型 [required] * `--dataset-id INTEGER`:数据集 id [required] +* `--list-evaluable-models`: 打印支持进行评估的模型列表 * `--help`:展示帮助文档 训练相关配置,参数含义与 [训练 API 文档](https://cloud.baidu.com/doc/WENXINWORKSHOP/s/mlmrgo4yx#body%E5%8F%82%E6%95%B0) 中对应参数含义一致: diff --git a/src/qianfan/common/client/evaluation.py b/src/qianfan/common/client/evaluation.py index 3e6cd626..3be3b71b 100644 --- a/src/qianfan/common/client/evaluation.py +++ b/src/qianfan/common/client/evaluation.py @@ -16,7 +16,7 @@ from typing import List, Optional, Set import typer -from rich.console import Console +from rich.console import Console, RenderableType from rich.pretty import Pretty from rich.table import Table @@ -42,6 +42,7 @@ QianfanRuleEvaluator, ) from qianfan.model import Model +from qianfan.resources.console.model import Model as ModelResource evaluation_app = typer.Typer( no_args_is_help=True, @@ -54,6 +55,34 @@ MANUAL_EVALUATOR_PANEL = "Manual Evaluator Options" +@credential_required +def list_evaluable_models( + ctx: typer.Context, param: typer.CallbackParam, value: bool +) -> None: + """ + Print models of ChatCompletion and exit. + """ + if value: + model_list = ModelResource.evaluable_model_list()["result"] + console = Console() + table = Table(show_lines=True) + col_list = ["Model Name", "Train Type", "Model Version List"] + for col in col_list: + table.add_column(col) + for model in model_list: + row_items: List[RenderableType] = [] + row_items.append(f"{model['modelName']}\n[dim]{model['modelIdStr']}[/]") + row_items.append(model["trainType"]) + version_list = [ + f"{version['version']} [dim]({version['modelVersionIdStr']})[/]" + for version in model["modelVersionList"] + ] + row_items.append("\n".join(version_list)) + table.add_row(*row_items) + console.print(table) + raise typer.Exit() + + @evaluation_app.command() @credential_required def run( @@ -111,6 +140,13 @@ def run( help="Dimensions for evaluation. Use ',' to split multiple dimensions.", rich_help_panel=MANUAL_EVALUATOR_PANEL, ), + list_evaluable_models: Optional[bool] = typer.Option( + None, + "--list-evaluable-models", + callback=list_evaluable_models, + is_eager=True, + help="Print evaluable models.", + ), ) -> None: """ Run evaluation task. diff --git a/src/qianfan/consts.py b/src/qianfan/consts.py index 1bb75398..56d3f461 100644 --- a/src/qianfan/consts.py +++ b/src/qianfan/consts.py @@ -189,6 +189,7 @@ class Consts: ModelEvalResultExportStatusAPI: str = ( "/wenxinworkshop/modelrepo/eval/result/export/info" ) + ModelEvaluableModelListAPI: str = "/wenxinworkshop/modelrepo/eval/model/list" ServiceCreateAPI: str = "/wenxinworkshop/service/apply" ServiceDetailAPI: str = "/wenxinworkshop/service/detail" ServiceListAPI: str = "/wenxinworkshop/service/list" diff --git a/src/qianfan/resources/console/model.py b/src/qianfan/resources/console/model.py index cbfc2495..16854c29 100644 --- a/src/qianfan/resources/console/model.py +++ b/src/qianfan/resources/console/model.py @@ -453,7 +453,7 @@ def preset_list( @classmethod @console_api_request def user_list( - self, + cls, name_filter: Optional[str] = None, model_type: Optional[str] = None, order_by: Optional[str] = None, @@ -502,7 +502,7 @@ def user_list( @classmethod @console_api_request def batch_delete_model( - self, + cls, model_ids: List[Any], **kwargs: Any, ) -> QfRequest: @@ -528,7 +528,7 @@ def batch_delete_model( @classmethod @console_api_request def batch_delete_model_version( - self, + cls, model_version_ids: List[Any], **kwargs: Any, ) -> QfRequest: @@ -543,9 +543,29 @@ def batch_delete_model_version( Note: The `@console_api_request` decorator is applied to this method, enabling it to send the generated QfRequest and return a QfResponse to the user. - """ req = QfRequest(method="POST", url=Consts.ModelVersionBatchDeleteAPI) req.json_body = {"modelVersionIds": model_version_ids} return req + + @classmethod + @console_api_request + def evaluable_model_list( + cls, + **kwargs: Any, + ) -> QfRequest: + """ + get all evaluable model list + + Parameters: + **kwargs (Any): + arbitrary arguments + + Note: + The `@console_api_request` decorator is applied to this method, enabling it to + send the generated QfRequest and return a QfResponse to the user. + """ + + req = QfRequest(method="POST", url=Consts.ModelEvaluableModelListAPI) + return req diff --git a/src/qianfan/tests/model_test.py b/src/qianfan/tests/model_test.py index 02a70d61..9a3e5f3c 100644 --- a/src/qianfan/tests/model_test.py +++ b/src/qianfan/tests/model_test.py @@ -16,7 +16,6 @@ Unit test for FineTune """ - from qianfan.resources import Model from qianfan.resources.console.consts import EvaluationResultExportDestinationType @@ -150,3 +149,13 @@ def test_get_evaluation_result_export_task_status(): resp = Model.get_evaluation_result_export_task_status(12) assert resp["_request"]["exportID"] == 12 + + +def test_evaluable_model_list(): + """ + test Model.evaluable_model_list + """ + + resp = Model.evaluable_model_list() + + assert len(resp["_request"]) == 0 diff --git a/src/qianfan/tests/utils/mock_server.py b/src/qianfan/tests/utils/mock_server.py index 00f75971..dcd5b46f 100644 --- a/src/qianfan/tests/utils/mock_server.py +++ b/src/qianfan/tests/utils/mock_server.py @@ -1378,6 +1378,96 @@ def gen() -> str: return flask.Response(gen()) +@app.route(Consts.ModelEvaluableModelListAPI, methods=["POST"]) +@iam_auth_checker +def evaluable_model_list(): + """mock get all evaluable model list api""" + return json_response( + { + "log_id": "2347238209", + "result": [ + { + "modelId": 8, + "modelIdStr": "am-ay2k0r83q9qr", + "modelName": "ERNIE-Bot-turbo", + "source": "PlatformPreset", + "modelType": 0, + "trainType": "ernieBotLite", + "modelVersionList": [ + { + "modelVersionId": 600, + "modelVersionIdStr": "amv-nsjesf9kasjt", + "version": "ERNIE-Bot-turbo-0922", + "sourceType": "PlatformPreset", + "framework": "paddle", + "algorithm": "ERNIE_EB-ERNIEBOT_V202_FUSE", + "modelNet": "paddlepaddle-ERNIE_EB-ERNIEBOT_V202_LORA_FUSE", + "trainType": "ernieBotLite", + "description": "通过数据和策略迭代,提升模型生成效果。", + }, + { + "modelVersionId": 492, + "modelVersionIdStr": "amv-4u0rw8juur1p", + "version": "ERNIE-Bot-turbo-0725", + "sourceType": "PlatformPreset", + "framework": "paddle", + "algorithm": "ERNIE_EB-ERNIEBOT_V201_8K", + "modelNet": "paddlepaddle-ERNIE_EB-ERNIEBOT_V201_8K", + "trainType": "ernieBotLite", + "description": ( + "支持7K输入+1K输出,支持系统设置,新增推理参数" + ), + }, + { + "modelVersionId": 244, + "modelVersionIdStr": "amv-70ahikpspjqs", + "version": "ERNIE-Bot-turbo-0704", + "sourceType": "PlatformPreset", + "framework": "paddle", + "algorithm": "ERNIE_EB-ERNIEBOT_V200", + "modelNet": "paddlepaddle-ERNIE_EB-ERNIEBOT_V200", + "trainType": "ernieBotLite", + "description": "优化推理效果,修复部分问题", + }, + ], + }, + { + "modelId": 446, + "modelIdStr": "am-44f8ji8eegp0", + "modelName": "Yi-34B", + "source": "PlatformPreset", + "modelType": 0, + "trainType": "", + "modelVersionList": [ + { + "modelVersionId": 635, + "modelVersionIdStr": "amv-mpjrtxej6hye", + "version": "Yi-34B-Chat", + "sourceType": "PlatformPreset", + "framework": "Pytorch", + "algorithm": "opensource-yi-34b", + "modelNet": "pytorch-yi-34b-chat-1.13.1", + "trainType": "", + "description": "支持对话的chat版本", + }, + { + "modelVersionId": 620, + "modelVersionIdStr": "amv-ffff66e1fm3d", + "version": "Yi-34B", + "sourceType": "PlatformPreset", + "framework": "Pytorch", + "algorithm": "opensource-yi-34b", + "modelNet": "pytorch-yi-34b-1.13.1", + "trainType": "", + "description": "初始预训练版本", + }, + ], + }, + ], + } + ) + + @app.route(Consts.ServiceCreateAPI, methods=["POST"]) @iam_auth_checker def create_service(): From b51d3004f938a86139ffed241a7e56059e92be56 Mon Sep 17 00:00:00 2001 From: Liu Jun Date: Fri, 2 Feb 2024 09:50:14 +0800 Subject: [PATCH 4/8] fix: access token not refreshed when using stream (#240) * fix stream access token expired * fix async access token --- src/qianfan/resources/requestor/base.py | 78 ------------ .../resources/requestor/openapi_requestor.py | 111 +++++++++++++++++- src/qianfan/tests/retry_test.py | 37 ++++++ 3 files changed, 147 insertions(+), 79 deletions(-) diff --git a/src/qianfan/resources/requestor/base.py b/src/qianfan/resources/requestor/base.py index 640c2a12..a7e8d5dc 100644 --- a/src/qianfan/resources/requestor/base.py +++ b/src/qianfan/resources/requestor/base.py @@ -42,7 +42,6 @@ ) import qianfan.errors as errors -from qianfan.consts import Consts from qianfan.resources.http_client import HTTPClient from qianfan.resources.rate_limiter import RateLimiter from qianfan.resources.typing import QfRequest, QfResponse, RetryConfig @@ -256,49 +255,6 @@ def _request( resp.request.json_body = copy.deepcopy(request.json_body) return data_postprocess(resp) - @_with_latency - def _request_stream( - self, - request: QfRequest, - data_postprocess: Callable[[QfResponse], QfResponse] = lambda x: x, - ) -> Iterator[QfResponse]: - """ - stream sync request - """ - with self._rate_limiter: - responses = self._client.request_stream(request) - event = "" - for body, resp in responses: - _check_if_status_code_is_200(resp) - body_str = body.decode("utf-8") - if body_str == "": - continue - if body_str.startswith(Consts.STREAM_RESPONSE_EVENT_PREFIX): - # event indicator for the type of data - event = body_str[len(Consts.STREAM_RESPONSE_EVENT_PREFIX) :] - continue - elif not body_str.startswith(Consts.STREAM_RESPONSE_PREFIX): - try: - # the response might be error message in json format - json_body = json.loads(body_str) - self._check_error(json_body) - except json.JSONDecodeError: - # the response is not json format, ignore and raise InternalError - pass - - raise errors.RequestError( - f"got unexpected stream response from server: {body_str}" - ) - body_str = body_str[len(Consts.STREAM_RESPONSE_PREFIX) :] - json_body = json.loads(body_str) - if event != "": - json_body["_event"] = event - event = "" - parsed = self._parse_response(json_body, resp) - parsed.request = QfRequest.from_requests(resp.request) - parsed.request.json_body = copy.deepcopy(request.json_body) - yield data_postprocess(parsed) - @_with_latency async def _async_request( self, @@ -327,40 +283,6 @@ async def _async_request( resp.request.json_body = copy.deepcopy(request.json_body) return data_postprocess(resp) - @_with_latency - async def _async_request_stream( - self, - request: QfRequest, - data_postprocess: Callable[[QfResponse], QfResponse] = lambda x: x, - ) -> AsyncIterator[QfResponse]: - """ - async stream request - """ - async with self._rate_limiter: - responses = self._client.arequest_stream(request) - async for body, resp in responses: - _async_check_if_status_code_is_200(resp) - body_str = body.decode("utf-8") - if body_str.strip() == "": - continue - if not body_str.startswith(Consts.STREAM_RESPONSE_PREFIX): - try: - # the response might be error message in json format - json_body: Dict[str, Any] = json.loads(body_str) - self._check_error(json_body) - except json.JSONDecodeError: - # the response is not json format, ignore and raise RequestError - pass - raise errors.RequestError( - f"got unexpected stream response from server: {body_str}" - ) - body_str = body_str[len(Consts.STREAM_RESPONSE_PREFIX) :] - json_body = json.loads(body_str) - parsed = self._parse_async_response(json_body, resp) - parsed.request = QfRequest.from_aiohttp(resp.request_info) - parsed.request.json_body = copy.deepcopy(request.json_body) - yield data_postprocess(parsed) - def _parse_response( self, body: Dict[str, Any], resp: requests.Response ) -> QfResponse: diff --git a/src/qianfan/resources/requestor/openapi_requestor.py b/src/qianfan/resources/requestor/openapi_requestor.py index 5b715002..f3afc87c 100644 --- a/src/qianfan/resources/requestor/openapi_requestor.py +++ b/src/qianfan/resources/requestor/openapi_requestor.py @@ -16,6 +16,8 @@ Qianfan API Requestor """ +import copy +import json from typing import ( Any, AsyncIterator, @@ -34,7 +36,12 @@ from qianfan.consts import APIErrorCode, Consts from qianfan.resources.auth.iam import iam_sign from qianfan.resources.auth.oauth import Auth -from qianfan.resources.requestor.base import BaseAPIRequestor +from qianfan.resources.requestor.base import ( + BaseAPIRequestor, + _async_check_if_status_code_is_200, + _check_if_status_code_is_200, + _with_latency, +) from qianfan.resources.typing import QfRequest, QfResponse, RetryConfig from qianfan.utils.logging import log_error, log_info @@ -74,6 +81,64 @@ def retry_wrapper(*args: Any, **kwargs: Any) -> _T: return retry_wrapper + @_with_latency + def _request_stream( + self, + request: QfRequest, + data_postprocess: Callable[[QfResponse], QfResponse] = lambda x: x, + ) -> Iterator[QfResponse]: + """ + stream sync request + """ + with self._rate_limiter: + responses = self._client.request_stream(request) + event = "" + token_refreshed = False + while True: + try: + body, resp = next(responses) + except StopIteration: + break + _check_if_status_code_is_200(resp) + body_str = body.decode("utf-8") + if body_str == "": + continue + if body_str.startswith(Consts.STREAM_RESPONSE_EVENT_PREFIX): + # event indicator for the type of data + event = body_str[len(Consts.STREAM_RESPONSE_EVENT_PREFIX) :] + continue + elif not body_str.startswith(Consts.STREAM_RESPONSE_PREFIX): + try: + # the response might be error message in json format + json_body = json.loads(body_str) + self._check_error(json_body) + except errors.AccessTokenExpiredError: + if not token_refreshed: + token_refreshed = True + self._auth.refresh_access_token() + self._add_access_token(request) + with self._rate_limiter: + responses = self._client.request_stream(request) + continue + raise + + except json.JSONDecodeError: + # the response is not json format, ignore and raise InternalError + pass + + raise errors.RequestError( + f"got unexpected stream response from server: {body_str}" + ) + body_str = body_str[len(Consts.STREAM_RESPONSE_PREFIX) :] + json_body = json.loads(body_str) + if event != "": + json_body["_event"] = event + event = "" + parsed = self._parse_response(json_body, resp) + parsed.request = QfRequest.from_requests(resp.request) + parsed.request.json_body = copy.deepcopy(request.json_body) + yield data_postprocess(parsed) + def _async_retry_if_token_expired( self, func: Callable[..., Awaitable[_T]] ) -> Callable[..., Awaitable[_T]]: @@ -97,6 +162,50 @@ async def retry_wrapper(*args: Any, **kwargs: Any) -> _T: return retry_wrapper + @_with_latency + async def _async_request_stream( + self, + request: QfRequest, + data_postprocess: Callable[[QfResponse], QfResponse] = lambda x: x, + ) -> AsyncIterator[QfResponse]: + """ + async stream request + """ + async with self._rate_limiter: + responses = self._client.arequest_stream(request) + token_refreshed = False + async for body, resp in responses: + _async_check_if_status_code_is_200(resp) + body_str = body.decode("utf-8") + if body_str.strip() == "": + continue + if not body_str.startswith(Consts.STREAM_RESPONSE_PREFIX): + try: + # the response might be error message in json format + json_body: Dict[str, Any] = json.loads(body_str) + self._check_error(json_body) + except json.JSONDecodeError: + # the response is not json format, ignore and raise RequestError + pass + except errors.AccessTokenExpiredError: + if not token_refreshed: + token_refreshed = True + await self._auth.arefresh_access_token() + await self._async_add_access_token(request) + async with self._rate_limiter: + responses = self._client.arequest_stream(request) + continue + raise + raise errors.RequestError( + f"got unexpected stream response from server: {body_str}" + ) + body_str = body_str[len(Consts.STREAM_RESPONSE_PREFIX) :] + json_body = json.loads(body_str) + parsed = self._parse_async_response(json_body, resp) + parsed.request = QfRequest.from_aiohttp(resp.request_info) + parsed.request.json_body = copy.deepcopy(request.json_body) + yield data_postprocess(parsed) + def _check_error(self, body: Dict[str, Any]) -> None: """ check whether error_code in response body diff --git a/src/qianfan/tests/retry_test.py b/src/qianfan/tests/retry_test.py index 8ec92df9..5c92964e 100644 --- a/src/qianfan/tests/retry_test.py +++ b/src/qianfan/tests/retry_test.py @@ -37,6 +37,43 @@ def test_retry_accesstoken_expired(): assert "id" in resp["body"] assert resp["object"] == "completion" assert comp.access_token() != access_token + with EnvHelper(QIANFAN_ACCESS_TOKEN=access_token): + comp = qianfan.Completion() + assert comp.access_token() == access_token + resp = comp.do(prompt="test", stream=True) + for r in resp: + assert r is not None + assert r["code"] == 200 + assert "id" in r["body"] + assert r["object"] == "completion" + assert comp.access_token() != access_token + + +@pytest.mark.asyncio +async def test_async_retry_accesstoken_expired(): + """ + Test retry access token expired + """ + access_token = "expired" + with EnvHelper(QIANFAN_ACCESS_TOKEN=access_token): + comp = qianfan.Completion() + assert comp.access_token() == access_token + resp = await comp.ado(prompt="test") + assert resp is not None + assert resp["code"] == 200 + assert "id" in resp["body"] + assert resp["object"] == "completion" + assert comp.access_token() != access_token + with EnvHelper(QIANFAN_ACCESS_TOKEN=access_token): + comp = qianfan.Completion() + assert comp.access_token() == access_token + resp = await comp.ado(prompt="test", stream=True) + async for r in resp: + assert r is not None + assert r["code"] == 200 + assert "id" in r["body"] + assert r["object"] == "completion" + assert comp.access_token() != access_token def test_retry_retry_cnt(): From b658f2656a23603af20e6a576bd26b4422702fa9 Mon Sep 17 00:00:00 2001 From: Liu Jun Date: Fri, 2 Feb 2024 15:13:33 +0800 Subject: [PATCH 5/8] feat: support qianfan plugin & reset conversation in client (#243) * support qianfan plugin in client * fix lint * support extra_parameters * add doc * fix lint * optimize input --- docs/cli.md | 32 +++ src/qianfan/common/client/chat.py | 73 ++++-- src/qianfan/common/client/main.py | 2 + src/qianfan/common/client/plugin.py | 360 ++++++++++++++++++++++++++++ src/qianfan/common/client/utils.py | 22 +- src/qianfan/resources/llm/base.py | 4 +- 6 files changed, 470 insertions(+), 23 deletions(-) create mode 100644 src/qianfan/common/client/plugin.py diff --git a/docs/cli.md b/docs/cli.md index 80682c71..5fefbf32 100644 --- a/docs/cli.md +++ b/docs/cli.md @@ -32,7 +32,9 @@ $ qianfan [OPTIONS] COMMAND [ARGS]... * `chat` 对话 * `completion` 补全 * `txt2img` 文生图 +* `plugin` 插件 * `dataset` 数据集 +* `evalutaion` 评估 ### chat 对话 @@ -53,6 +55,12 @@ $ qianfan chat [OPTIONS] * `--debug`:调试模式,会打印请求相关的原始信息。 * `--help`:展示帮助文档 +在对话进行过程中,可以通过输入命令实现如下功能: + +* `/reset`:重置对话,清空对话历史 +* `/exit`:结束对话 +* `/help`:展示帮助信息 + ### completion 补全 ![completion](./imgs/cli/completion.gif) @@ -104,6 +112,30 @@ $ qianfan txt2img [OPTIONS] PROMPT * `--debug`:调试模式,会打印请求相关的原始信息。 * `--help`:展示帮助文档 +### plugin 插件 + +**用法**: + +```console +$ qianfan plugin [OPTIONS] +``` + +**Options 选项**: + +* `--endpoint TEXT`:千帆插件的 endpoint [required] +* `--multi-line / --no-multi-line`:多行模式,提交时需要先按下 Esc 再回车,以避免与文本换行冲突 [default:no-multi-line] +* `--plugins`:启用的插件列表,通过 `,` 分隔不同的插件,例如 `uuid-zhishiku,uuid-chatocr,uuid-weatherforecast` +* `--debug`:调试模式,会打印请求相关的原始信息 +* `--bos-path`:BOS 路径,用于上传文件 +* `--help`:展示帮助文档 + +在对话进行过程中,可以通过输入命令实现如下功能: + +* `/image [file_path]`:上传图片并附加至对话中,`file_path` 可以是网络上的链接,也可以是本地文件路径。其中,本地文件会被上传至 BOS 路径,因此需要提供 `bos-path` 参数。 +* `/reset`:重置对话,清空对话历史 +* `/exit`:结束对话 +* `/help`:展示帮助信息 + ### dataset 数据集 ![](./imgs/cli/dataset.webp) diff --git a/src/qianfan/common/client/chat.py b/src/qianfan/common/client/chat.py index 2267c38c..c5eea4a8 100644 --- a/src/qianfan/common/client/chat.py +++ b/src/qianfan/common/client/chat.py @@ -13,11 +13,13 @@ # limitations under the License. +import json from concurrent.futures import ThreadPoolExecutor, wait from typing import Any, List, Optional, Tuple import typer from prompt_toolkit import prompt +from prompt_toolkit.completion import WordCompleter from rich import print as rprint from rich.console import Console, Group, RenderableType from rich.live import Live @@ -29,9 +31,9 @@ import qianfan from qianfan import QfRole from qianfan.common.client.utils import ( + InputEmptyValidator, credential_required, list_model_option, - print_error_msg, print_warn_msg, render_response_debug_info, ) @@ -40,14 +42,25 @@ from qianfan.resources.llm.chat_completion import ChatCompletion from qianfan.resources.typing import QfMessages, QfResponse -END_PROMPT = "\exit" - class ChatClient(object): """ Client object for the chat command """ + END_PROMPT = "/exit" + RESET_PROMPT = "/reset" + HELP_PROMPT = "/help" + + HELP_MESSAGES = { + END_PROMPT: "End the conversation", + RESET_PROMPT: "Reset the conversation", + HELP_PROMPT: "Print help message", + } + input_completer = WordCompleter( + list(HELP_MESSAGES.keys()), sentence=True, meta_dict=HELP_MESSAGES + ) + def __init__( self, model: Optional[str], @@ -168,18 +181,25 @@ def print_hint_msg(self) -> None: if self.multi_line: rprint( "[bold]Hint[/bold]: [green bold]Press Esc before Enter[/] to submit" - f" your message, and use '{END_PROMPT}' to end the conversation." + f" your message, and use '{self.END_PROMPT}' to end the conversation." ) else: rprint( "[bold]Hint[/bold]: Press enter to submit your message, and use" - f" '{END_PROMPT}' to end the conversation." + f" '{self.END_PROMPT}' to end the conversation." ) rprint( "[bold]Hint[/bold]: If you want to submit multiple lines, use the" " '--multi-line' option." ) + def print_help_message(self) -> None: + """ + Print command introduction + """ + for k, v in self.HELP_MESSAGES.items(): + rprint(f"[bold green]{k}[/]: {v}") + def chat_in_terminal(self) -> None: """ Chat in terminal @@ -188,20 +208,12 @@ def chat_in_terminal(self) -> None: self.print_hint_msg() # loop the conversation while True: - # loop the input and check whether the input is valid - while True: - rprint("\n[yellow bold]Enter your message[/yellow bold]:") - message = prompt(multiline=self.multi_line).strip() - # break the loop if input is valid - if len(message) != 0: - break - # if message is empty, print error message and continue to input - print_error_msg("Message cannot be empty!") - - for i in range(len(self.clients)): - msg_history = self.msg_history[i] - if msg_history is not None: - msg_history.append(message) + rprint("\n[yellow bold]Enter your message[/yellow bold]:") + message = prompt( + multiline=self.multi_line, + validator=InputEmptyValidator(), + completer=self.input_completer, + ).strip() extra_info = ( "" @@ -212,9 +224,21 @@ def chat_in_terminal(self) -> None: f"\n[blue][bold]Model response[/bold][/blue][dim] {extra_info}[/dim]:" ) - if message == END_PROMPT: + if message == self.END_PROMPT: rprint("Bye!") raise typer.Exit() + elif message == self.RESET_PROMPT: + self.msg_history = [QfMessages() for _ in range(len(self.clients))] + rprint("Chat history has been cleared.") + continue + elif message == self.HELP_PROMPT: + self.print_help_message() + continue + + for i in range(len(self.clients)): + msg_history = self.msg_history[i] + if msg_history is not None: + msg_history.append(message) # List of (received_msg, is_end, response) for each client msg_list: List[Tuple[str, bool, Optional[QfResponse]]] = [ @@ -260,7 +284,7 @@ def model_response_worker( live.update(self.render_model_response(msg_list)) except Exception as e: msg_list[i] = ( - msg_list[i][0] + "\n\n**Got Exception**: " + str(e), + msg_list[i][0] + "\n\n**Got Exception**: " + repr(e), True, None, ) @@ -355,6 +379,11 @@ def chat_entry( enable_citation: Optional[bool] = typer.Option( None, help="Enable citation", rich_help_panel=MODEL_ARGUMENTS_PANEL ), + extra_parameters: Optional[str] = typer.Option( + None, + help="Extra parameters for the model. This should be a json string.", + rich_help_panel=MODEL_ARGUMENTS_PANEL, + ), ) -> None: """ Chat with the LLM in the terminal. @@ -378,6 +407,8 @@ def add_if_not_none(key: str, value: Any) -> None: if stop is not None: extra_args["stop"] = stop.split(",") + if extra_parameters is not None: + extra_args["extra_parameters"] = json.loads(extra_parameters) client = ChatClient(model, endpoint, multi_line, debug=debug, **extra_args) client.chat_in_terminal() diff --git a/src/qianfan/common/client/main.py b/src/qianfan/common/client/main.py index 1422d74c..7a95ed4c 100644 --- a/src/qianfan/common/client/main.py +++ b/src/qianfan/common/client/main.py @@ -23,6 +23,7 @@ from qianfan.common.client.dataset import dataset_app from qianfan.common.client.embedding import embedding_entry from qianfan.common.client.evaluation import evaluation_app +from qianfan.common.client.plugin import plugin_entry from qianfan.common.client.trainer import trainer_app from qianfan.common.client.txt2img import txt2img_entry from qianfan.common.client.utils import print_error_msg @@ -37,6 +38,7 @@ app.command(name="completion")(completion_entry) app.command(name="txt2img")(txt2img_entry) app.command(name="embedding", no_args_is_help=True)(embedding_entry) +app.command(name="plugin")(plugin_entry) app.add_typer(dataset_app, name="dataset") app.add_typer(trainer_app, name="trainer") app.add_typer(evaluation_app, name="evaluation") diff --git a/src/qianfan/common/client/plugin.py b/src/qianfan/common/client/plugin.py new file mode 100644 index 00000000..b1805c66 --- /dev/null +++ b/src/qianfan/common/client/plugin.py @@ -0,0 +1,360 @@ +# Copyright (c) 2023 Baidu, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import os +from typing import Any, Dict, List, Optional, Tuple + +import typer +from prompt_toolkit import prompt +from prompt_toolkit.completion import WordCompleter +from prompt_toolkit.document import Document +from prompt_toolkit.validation import ValidationError +from rich import print as rprint +from rich.console import Console, Group, RenderableType +from rich.live import Live +from rich.markdown import Markdown +from rich.spinner import Spinner +from rich.text import Text + +import qianfan +from qianfan import QfRole +from qianfan.common.client.utils import ( + BosPathValidator, + InputEmptyValidator, + credential_required, + print_error_msg, + render_response_debug_info, +) +from qianfan.resources.typing import QfMessages +from qianfan.utils.bos_uploader import BosHelper, parse_bos_path + + +class PluginInputValidator(InputEmptyValidator): + """ + Validator for input in plugin + """ + + def validate(self, document: Document) -> None: + """ + validate input: + - input must not be empty + - if input is /image, file path must be provided + """ + super().validate(document) + text = document.text.strip() + if text.startswith(PluginClient.IMAGE_PROMPT): + path = text[len(PluginClient.IMAGE_PROMPT) :].strip() + if len(path) == 0: + raise ValidationError( + message="Image file path must be provided (e.g. /image car.jpg)." + ) + + +class PluginClient(object): + """ + Client object for the chat command + """ + + END_PROMPT = "/exit" + RESET_PROMPT = "/reset" + IMAGE_PROMPT = "/image" + HELP_PROMPT = "/help" + + HELP_MESSAGES = { + END_PROMPT: "End the conversation", + RESET_PROMPT: "Reset the conversation", + IMAGE_PROMPT: "Attach a local image to the conversation (e.g. /image car.jpg)", + HELP_PROMPT: "Print help message", + } + + input_completer = WordCompleter( + list(HELP_MESSAGES.keys()), sentence=True, meta_dict=HELP_MESSAGES + ) + + def __init__( + self, + model: Optional[str], + endpoint: Optional[str], + multi_line: bool, + debug: bool, + plugins: List[str], + bos_path: Optional[str], + **kwargs: Any, + ) -> None: + """ + Init the chat client + """ + if model is not None: + print_error_msg("ERNIE Bot pulgin is currently not available in sdk.") + raise typer.Exit(1) + if endpoint is None: + print_error_msg("Endpoint must be provided for qianfan plugin.") + raise typer.Exit(1) + self.client = qianfan.Plugin(endpoint=endpoint) + self.msg_history = QfMessages() + self.multi_line = multi_line + self.console = Console() + self.inference_args = kwargs + self.bos_path = bos_path + self.plugins = plugins + self.debug = debug + + def print_hint_msg(self) -> None: + """ + Print hint message when startup + """ + if self.multi_line: + rprint( + "[bold]Hint[/bold]: [green bold]Press Esc before Enter[/] to submit" + f" your message, or use '{self.HELP_PROMPT}' to acquire more commands." + ) + else: + rprint( + "[bold]Hint[/bold]: Press enter to submit your message, or use" + f" '{self.HELP_PROMPT}' to acquire more commands.." + ) + rprint( + "[bold]Hint[/bold]: If you want to submit multiple lines, use the" + " '--multi-line' option." + ) + rprint(f"[dim]Using qianfan plugin with {self.plugins}...[/]") + + def get_bos_path(self) -> str: + """ + Get bos path. If bos_path is not provided, prompt user to input + """ + if self.bos_path is None: + rprint("Please input bos bucket path [dim](bos://)[/]: ") + self.bos_path = prompt("bos:/", validator=BosPathValidator()) + self.bos_path = "bos:/" + self.bos_path + if not self.bos_path.endswith("/"): + self.bos_path = self.bos_path + "/" + return self.bos_path + + def upload_file_to_bos(self, filepath: str) -> Tuple[str, str]: + """ + Upload file to bos and get bos_url and http_url + """ + bos_helper = BosHelper() + bucket, bos_path = parse_bos_path(self.get_bos_path()) + + bos_path = bos_path + os.path.basename(filepath) + with self.console.status("Uploading file to bos..."): + bos_helper.upload_file_to_bos(filepath, bos_path, bucket) + url = bos_helper.get_bos_file_shared_url(bos_path, bucket) + bos_url = f"bos:/{bucket}{bos_path}" + return bos_url, url + + def print_help_message(self) -> None: + """ + Print command introduction + """ + for k, v in self.HELP_MESSAGES.items(): + rprint(f"[bold green]{k}[/]: {v}") + + def chat_in_terminal(self) -> None: + """ + Chat in terminal + """ + + self.print_hint_msg() + # loop the conversation + extra_field: Dict[str, Any] = {} + while True: + rprint("\n[yellow bold]Enter your message[/yellow bold]:") + while True: + message = prompt( + multiline=self.multi_line, + validator=PluginInputValidator(), + completer=self.input_completer, + ).strip() + if message.startswith(self.IMAGE_PROMPT): + path = message[len(self.IMAGE_PROMPT) :].strip() + if not path.startswith("http"): + bos_path, http_path = self.upload_file_to_bos(path) + rprint(f"File has been uploaded to: {bos_path}") + rprint(f"File share url: {http_path}\n") + path = http_path + rprint( + "[yellow bold]Please continue to input your prompt[/yellow" + " bold]:" + ) + + extra_field["fileurl"] = path + continue + break + + rprint("\n[blue][bold]Model response[/bold][/blue]:") + + if message == self.END_PROMPT: + rprint("Bye!") + raise typer.Exit() + elif message == self.RESET_PROMPT: + self.msg_history = QfMessages() + extra_field = {} + rprint("Chat history has been cleared.") + continue + elif message == self.HELP_PROMPT: + self.print_help_message() + continue + else: + self.msg_history.append(message) + + with Live( + Spinner("dots", text="Thinking...", style="status.spinner"), + auto_refresh=True, + refresh_per_second=24, + console=self.console, + ) as live: + response = self.client.do( + message, + plugins=self.plugins, + llm=self.inference_args, + stream=True, + history=self.msg_history._to_list()[:-1], + **extra_field, + ) + + m = "" + for r in response: + render_list: List[RenderableType] = [] + m += r["result"] + render_list.append(Markdown(m)) + if not r["is_end"]: + render_list.append( + Spinner( + "dots", text="Generating...", style="status.spinner" + ) + ) + stat = r.statistic + render_list.append( + Text.from_markup( + "\n[dim]First token latentcy:" + f" {stat['first_token_latency']:.2f}s, Total latency:" + f" {stat['total_latency']:.2f}s.[/]" + ) + ) + if r["is_end"]: + if "usage" in r: + token_usage = r["usage"] + render_list.append( + Text.from_markup( + f"[dim]Input token: {token_usage['prompt_tokens']}," + " Output token:" + f" {token_usage['completion_tokens']}, Total token:" + f" {token_usage['total_tokens']}.[/]" + ) + ) + if self.debug: + render_list.append(render_response_debug_info(response=r)) + + live.update(Group(*render_list)) + + self.msg_history.append(m, role=QfRole.Assistant) + + +MODEL_ARGUMENTS_PANEL = ( + "Model Arguments (Some arguments are not supported by every model)" +) + + +@credential_required +def plugin_entry( + endpoint: Optional[str] = typer.Option( + ..., + help="Endpoint of the plugin.", + ), + # tui: bool = typer.Option(False, help="Using Terminal UI"), + multi_line: bool = typer.Option( + False, + "--multi-line", + help="Multi-line mode which needs to press Esc before enter to submit message.", + ), + debug: bool = typer.Option( + False, + "--debug", + help="Enable debug mode. The request infomation will be printed.", + ), + plugins: str = typer.Option( + ..., + help=( + "Plugins enabled. Use comma(,) to split. (e.g." + " uuid-zhishiku,uuid-chatocr,uuid-weatherforecast)" + ), + ), + bos_path: Optional[str] = typer.Option(None, help="Bos path used for upload file."), + temperature: Optional[float] = typer.Option( + None, + help=( + "Controls the randomness of the generated text. A higher temperature makes" + " the model more creative and produces more diverse, but potentially less" + " coherent." + ), + rich_help_panel=MODEL_ARGUMENTS_PANEL, + ), + top_p: Optional[float] = typer.Option( + None, + help=( + "Lower top_p value allows the model to focus on a narrowed set of likely" + " next tokens, making the response more conherent but less random." + ), + rich_help_panel=MODEL_ARGUMENTS_PANEL, + ), + penalty_score: Optional[float] = typer.Option( + None, + help="Penalty scores can be applied to discourage repetition.", + rich_help_panel=MODEL_ARGUMENTS_PANEL, + ), + system: Optional[str] = typer.Option( + None, + help="Persona setting for the model.", + rich_help_panel=MODEL_ARGUMENTS_PANEL, + ), + stop: Optional[str] = typer.Option( + None, + help="Stop words. Use comma to split multiple stop words.", + rich_help_panel=MODEL_ARGUMENTS_PANEL, + ), +) -> None: + """ + Chat with the LLM with plugins in the terminal. + """ + qianfan.disable_log() + model = None + + extra_args = {} + + def add_if_not_none(key: str, value: Any) -> None: + if value is not None: + extra_args[key] = value + + add_if_not_none("temperature", temperature) + add_if_not_none("top_p", top_p) + add_if_not_none("penalty_score", penalty_score) + add_if_not_none("system", system) + + if stop is not None: + extra_args["stop"] = stop.split(",") + + client = PluginClient( + model, + endpoint, + multi_line, + debug=debug, + plugins=plugins.split(","), + bos_path=bos_path, + **extra_args, + ) + client.chat_in_terminal() diff --git a/src/qianfan/common/client/utils.py b/src/qianfan/common/client/utils.py index db5d6d93..ae6c6058 100644 --- a/src/qianfan/common/client/utils.py +++ b/src/qianfan/common/client/utils.py @@ -22,6 +22,8 @@ import click import typer from prompt_toolkit import prompt +from prompt_toolkit.document import Document +from prompt_toolkit.validation import ValidationError, Validator from rich import print as rprint from rich.console import Console, Group, RenderableType from rich.highlighter import JSONHighlighter @@ -34,7 +36,7 @@ from qianfan import QfResponse from qianfan.resources.llm.base import BaseResource from qianfan.resources.typing import QfRequest -from qianfan.utils.bos_uploader import BosHelper +from qianfan.utils.bos_uploader import BosHelper, parse_bos_path from qianfan.utils.utils import camel_to_snake, snake_to_camel BaseResourceType = TypeVar("BaseResourceType", bound=BaseResource) @@ -287,3 +289,21 @@ def render_response_debug_info(response: QfResponse) -> Group: ) return Group(*render_list) + + +class InputEmptyValidator(Validator): + def validate(self, document: Document) -> None: + text = document.text + if len(text.strip()) == 0: + raise ValidationError(message="Input cannot be empty") + + +class BosPathValidator(Validator): + def validate(self, document: Document) -> None: + text = document.text.strip() + if len(text) == 0: + raise ValidationError(message="Input cannot be empty") + try: + parse_bos_path("bos:/" + text) + except ValueError: + raise ValidationError(message="Invalid BOS path") diff --git a/src/qianfan/resources/llm/base.py b/src/qianfan/resources/llm/base.py index d124202d..982e9726 100644 --- a/src/qianfan/resources/llm/base.py +++ b/src/qianfan/resources/llm/base.py @@ -505,7 +505,9 @@ def _generate_body( f"The required key `{key}` is not provided." ) kwargs["stream"] = stream - kwargs["extra_parameters"] = {"request_source": f"qianfan_py_sdk_v{VERSION}"} + if "extra_parameters" not in kwargs: + kwargs["extra_parameters"] = {} + kwargs["extra_parameters"]["request_source"] = f"qianfan_py_sdk_v{VERSION}" return kwargs def _data_postprocess(self, data: QfResponse) -> QfResponse: From 62049e3135dbe05e925b8dea9d00cebbebed6c2a Mon Sep 17 00:00:00 2001 From: Liu Jun Date: Fri, 2 Feb 2024 15:36:49 +0800 Subject: [PATCH 6/8] feat: golang version sdk (#224) * feat: mv src to python * init go * optimize * support stream * optimize code * optimize code * add ut * add ut * add ci * update go version * update ci * fix mock server * update script * update script * fix makefile * update script * update script * fix * optimize setting * fix * compat * fix lint * add python version * fix ci * fix script * fix script * update script * fix script * fix script * fix script * fix script * fix script * fix script * update version * add ut * add lint * fix ci * fix folder * update readme && move doc template * update readme * refactor * refactor option * test * test * test * test * fix lint * rename * refactor * reset * fix script * update readme * fix lint * add comment * add api error * change go version * add close in body * update workflow * Go (#4) * feat: mv src to python * init go * optimize * support stream * optimize code * optimize code * add ut * add ut * add ci * update go version * update ci * fix mock server * update script * update script * fix makefile * update script * update script * fix * optimize setting * fix * compat * fix lint * add python version * fix ci * fix script * fix script * update script * fix script * fix script * fix script * fix script * fix script * fix script * update version * add ut * add lint * fix ci * fix folder * update readme && move doc template * update readme * refactor * refactor option * test * test * test * test * fix lint * rename * Multi language go (#3) * fix: dataset, trainer, evaluaton bug fix (#237) * fix * add more import on __init__.py in dataset package * cookbook update * refactor * reset * fix script * update readme * fix lint * add comment * add api error * change go version * add close in body * update workflow --------- Co-authored-by: Dobiichi-Origami <56953648+Dobiichi-Origami@users.noreply.github.com> --------- Co-authored-by: shikuan Co-authored-by: Dobiichi-Origami <56953648+Dobiichi-Origami@users.noreply.github.com> * fix workflow * add cache to actions * fix workflow * update workflow * update workflow * remove changelog * fix rtd * add go release action * add copyright --------- Co-authored-by: shikuan Co-authored-by: Dobiichi-Origami <56953648+Dobiichi-Origami@users.noreply.github.com> --- .github/workflows/doc_build.yml | 60 +++ .github/workflows/doc_release.yml | 52 +++ .github/workflows/go_ci.yml | 65 ++++ .github/workflows/go_release.yml | 78 ++++ .github/workflows/{pyci.yml => py_ci.yml} | 28 +- .../workflows/{release.yml => py_release.yml} | 57 +-- .gitignore | 3 +- .readthedocs.yaml | 4 +- CHANGELOG.md | 79 ---- Makefile | 35 +- {src/qianfan/docs => docs/template}/conf.py_t | 0 .../docs => docs/template}/root_doc.rst_t | 0 go/Makefile | 2 + go/README.md | 186 +++++++++ go/qianfan/base_model.go | 101 +++++ go/qianfan/chat_completion.go | 223 +++++++++++ go/qianfan/chat_completion_test.go | 115 ++++++ go/qianfan/completion.go | 194 ++++++++++ go/qianfan/completion_test.go | 93 +++++ go/qianfan/config.go | 76 ++++ go/qianfan/consts.go | 25 ++ go/qianfan/embdding.go | 130 +++++++ go/qianfan/embedding_test.go | 38 ++ go/qianfan/go.mod | 34 ++ go/qianfan/go.sum | 69 ++++ go/qianfan/logger.go | 21 ++ go/qianfan/options.go | 44 +++ go/qianfan/requestor.go | 354 ++++++++++++++++++ go/qianfan/utils.go | 27 ++ go/qianfan/version.go | 25 ++ .../__init__.py => javascript/README.md | 0 .coveragerc => python/.coveragerc | 4 +- python/Makefile | 29 ++ README.pypi.md => python/README.pypi.md | 0 pyproject.toml => python/pyproject.toml | 4 +- {src => python}/qianfan/__init__.py | 0 {src => python}/qianfan/common/__init__.py | 0 .../qianfan/common/client}/__init__.py | 0 {src => python}/qianfan/common/client/chat.py | 0 .../qianfan/common/client/completion.py | 0 .../qianfan/common/client/dataset.py | 0 .../qianfan/common/client/embedding.py | 0 .../qianfan/common/client/evaluation.py | 0 {src => python}/qianfan/common/client/main.py | 0 .../qianfan/common/client/plugin.py | 0 .../qianfan/common/client/trainer.py | 0 .../qianfan/common/client/txt2img.py | 0 .../qianfan/common/client/utils.py | 0 .../qianfan/common/hub/__init__.py | 0 {src => python}/qianfan/common/hub/hub.py | 0 .../qianfan/common/hub/interface.py | 0 .../qianfan/common/prompt}/__init__.py | 0 .../qianfan/common/prompt/prompt.py | 0 .../qianfan/common/runnable}/__init__.py | 0 .../qianfan/common/runnable/base.py | 0 .../qianfan/common/tool/baidu_search_tool.py | 0 .../qianfan/common/tool/base_tool.py | 0 .../common/tool/duckduckgo_search_tool.py | 0 .../qianfan/common/tool/wikipedia_tool.py | 0 {src => python}/qianfan/config.py | 0 {src => python}/qianfan/consts.py | 0 {src => python}/qianfan/dataset/__init__.py | 0 {src => python}/qianfan/dataset/consts.py | 0 .../qianfan/dataset/data_operator.py | 0 .../qianfan/dataset/data_source/__init__.py | 0 .../dataset/data_source/baidu_qianfan.py | 0 .../qianfan/dataset/data_source/base.py | 0 .../qianfan/dataset/data_source/bos.py | 0 .../qianfan/dataset/data_source/file.py | 0 .../qianfan/dataset/data_source/utils.py | 0 {src => python}/qianfan/dataset/dataset.py | 0 .../qianfan/dataset/dataset_utils.py | 0 .../dataset/local_data_operators/__init__.py | 0 .../dataset/local_data_operators/base.py | 0 .../check_character_repetition_filter.py | 0 .../check_flagged_words.py | 0 .../check_sentence_length_filter.py | 0 .../check_special_characters.py | 0 .../local_data_operators/check_stopwords.py | 0 .../local_data_operators/check_word_number.py | 0 .../check_word_repetition_filter.py | 0 .../dataset/local_data_operators/consts.py | 0 .../dataset/local_data_operators/utils.py | 0 .../dataset/local_data_operators/word_list.py | 0 .../qianfan/dataset/process_interface.py | 0 .../qianfan/dataset/qianfan_data_operators.py | 161 ++++++++ {src => python}/qianfan/dataset/schema.py | 0 {src => python}/qianfan/dataset/table.py | 0 .../qianfan/dataset/table_utils.py | 0 {src => python}/qianfan/errors.py | 0 .../qianfan/evaluation/__init__.py | 0 {src => python}/qianfan/evaluation/consts.py | 0 .../qianfan/evaluation/evaluation_manager.py | 0 .../qianfan/evaluation/evaluation_result.py | 0 .../qianfan/evaluation/evaluator.py | 0 .../evaluation/opencompass_evaluator.py | 0 {src => python}/qianfan/extensions/README.md | 0 .../qianfan/extensions}/__init__.py | 0 .../qianfan/extensions/langchain/__init__.py | 0 .../extensions/langchain/agents/__init__.py | 0 .../agents/baidu_qianfan_endpoint.py | 0 .../extensions/semantic_kernel/__init__.py | 0 .../semantic_kernel/connectors}/__init__.py | 0 .../connectors/qianfan_chat_completion.py | 0 .../connectors/qianfan_settings.py | 0 .../connectors/qianfan_text_completion.py | 0 .../connectors/qianfan_text_embedding.py | 0 {src => python}/qianfan/model/__init__.py | 0 {src => python}/qianfan/model/configs.py | 0 {src => python}/qianfan/model/consts.py | 0 {src => python}/qianfan/model/model.py | 0 {src => python}/qianfan/py.typed | 0 {src => python}/qianfan/resources/__init__.py | 0 .../qianfan/resources/auth}/__init__.py | 0 {src => python}/qianfan/resources/auth/iam.py | 0 .../qianfan/resources/auth/oauth.py | 0 .../qianfan/resources/console}/__init__.py | 0 .../qianfan/resources/console/consts.py | 0 .../qianfan/resources/console/data.py | 0 .../qianfan/resources/console/finetune.py | 0 .../qianfan/resources/console/model.py | 0 .../qianfan/resources/console/prompt.py | 0 .../qianfan/resources/console/service.py | 0 .../qianfan/resources/console/utils.py | 0 .../qianfan/resources/http_client.py | 0 .../qianfan/resources/images}/__init__.py | 0 .../qianfan/resources/images/image2text.py | 0 .../qianfan/resources/images/text2image.py | 0 .../qianfan/resources/llm}/__init__.py | 0 {src => python}/qianfan/resources/llm/base.py | 0 .../qianfan/resources/llm/chat_completion.py | 0 .../qianfan/resources/llm/completion.py | 0 .../qianfan/resources/llm/embedding.py | 0 .../qianfan/resources/llm/plugin.py | 0 .../qianfan/resources/rate_limiter.py | 0 .../qianfan/resources/requestor}/__init__.py | 0 .../qianfan/resources/requestor/base.py | 0 .../resources/requestor/console_requestor.py | 0 .../resources/requestor/openapi_requestor.py | 0 .../qianfan/resources/tools}/__init__.py | 0 .../qianfan/resources/tools/tokenizer.py | 0 .../qianfan/resources/tools/utils.py | 0 {src => python}/qianfan/resources/typing.py | 0 .../qianfan/tests}/__init__.py | 0 {src => python}/qianfan/tests/auth_test.py | 0 .../qianfan/tests/chat_completion_test.py | 0 .../qianfan/tests/completion_test.py | 0 {src => python}/qianfan/tests/config_test.py | 0 {src => python}/qianfan/tests/conftest.py | 0 .../qianfan/tests/data_api_test.py | 0 .../qianfan/tests/dataset}/__init__.py | 0 .../qianfan/tests/dataset/data_source_test.py | 0 .../qianfan/tests/dataset/dataset_test.py | 0 .../qianfan/tests/dataset/table_test.py | 0 .../qianfan/tests/embedding_test.py | 0 .../qianfan/tests/finetune_test.py | 0 {src => python}/qianfan/tests/hub_test.py | 0 .../qianfan/tests/image2text_test.py | 0 python/qianfan/tests/langchain/__init__.py | 0 .../qianfan/tests/langchain/agent_test.py | 0 {src => python}/qianfan/tests/latency_test.py | 0 {src => python}/qianfan/tests/model_test.py | 0 {src => python}/qianfan/tests/plugin_test.py | 0 .../qianfan/tests/prompt_class_test.py | 0 .../qianfan/tests/prompt_resource_test.py | 0 .../qianfan/tests/rate_limiter_test.py | 0 {src => python}/qianfan/tests/retry_test.py | 0 {src => python}/qianfan/tests/service_test.py | 0 .../qianfan/tests/text2image_test.py | 0 .../qianfan/tests/tokenizer_test.py | 0 {src => python}/qianfan/tests/tool_test.py | 0 {src => python}/qianfan/tests/trainer_test.py | 0 .../qianfan/tests/utils/__init__.py | 0 .../qianfan/tests/utils/mock_server.py | 2 +- {src => python}/qianfan/tests/utils/utils.py | 0 {src => python}/qianfan/trainer/__init__.py | 0 {src => python}/qianfan/trainer/actions.py | 0 {src => python}/qianfan/trainer/base.py | 0 {src => python}/qianfan/trainer/configs.py | 0 {src => python}/qianfan/trainer/consts.py | 0 {src => python}/qianfan/trainer/event.py | 0 {src => python}/qianfan/trainer/finetune.py | 0 {src => python}/qianfan/utils/__init__.py | 0 {src => python}/qianfan/utils/bos_uploader.py | 0 {src => python}/qianfan/utils/helper.py | 0 {src => python}/qianfan/utils/logging.py | 0 .../qianfan/utils/pydantic/__init__.py | 0 {src => python}/qianfan/utils/utils.py | 0 {src => python}/qianfan/version.py | 5 +- python/scripts/build.sh | 14 + python/scripts/build_doc.sh | 16 + {src => python}/scripts/release_github.sh | 0 python/scripts/run_mock_server.sh | 22 ++ {src => python}/scripts/run_test.sh | 0 src/scripts/build.sh | 14 - src/scripts/build_doc.sh | 9 - 196 files changed, 2322 insertions(+), 176 deletions(-) create mode 100644 .github/workflows/doc_build.yml create mode 100644 .github/workflows/doc_release.yml create mode 100644 .github/workflows/go_ci.yml create mode 100644 .github/workflows/go_release.yml rename .github/workflows/{pyci.yml => py_ci.yml} (63%) rename .github/workflows/{release.yml => py_release.yml} (55%) delete mode 100644 CHANGELOG.md rename {src/qianfan/docs => docs/template}/conf.py_t (100%) rename {src/qianfan/docs => docs/template}/root_doc.rst_t (100%) create mode 100644 go/Makefile create mode 100644 go/README.md create mode 100644 go/qianfan/base_model.go create mode 100644 go/qianfan/chat_completion.go create mode 100644 go/qianfan/chat_completion_test.go create mode 100644 go/qianfan/completion.go create mode 100644 go/qianfan/completion_test.go create mode 100644 go/qianfan/config.go create mode 100644 go/qianfan/consts.go create mode 100644 go/qianfan/embdding.go create mode 100644 go/qianfan/embedding_test.go create mode 100644 go/qianfan/go.mod create mode 100644 go/qianfan/go.sum create mode 100644 go/qianfan/logger.go create mode 100644 go/qianfan/options.go create mode 100644 go/qianfan/requestor.go create mode 100644 go/qianfan/utils.go create mode 100644 go/qianfan/version.go rename src/qianfan/common/client/__init__.py => javascript/README.md (100%) rename .coveragerc => python/.coveragerc (79%) create mode 100644 python/Makefile rename README.pypi.md => python/README.pypi.md (100%) rename pyproject.toml => python/pyproject.toml (98%) rename {src => python}/qianfan/__init__.py (100%) rename {src => python}/qianfan/common/__init__.py (100%) rename {src/qianfan/common/prompt => python/qianfan/common/client}/__init__.py (100%) rename {src => python}/qianfan/common/client/chat.py (100%) rename {src => python}/qianfan/common/client/completion.py (100%) rename {src => python}/qianfan/common/client/dataset.py (100%) rename {src => python}/qianfan/common/client/embedding.py (100%) rename {src => python}/qianfan/common/client/evaluation.py (100%) rename {src => python}/qianfan/common/client/main.py (100%) rename {src => python}/qianfan/common/client/plugin.py (100%) rename {src => python}/qianfan/common/client/trainer.py (100%) rename {src => python}/qianfan/common/client/txt2img.py (100%) rename {src => python}/qianfan/common/client/utils.py (100%) rename {src => python}/qianfan/common/hub/__init__.py (100%) rename {src => python}/qianfan/common/hub/hub.py (100%) rename {src => python}/qianfan/common/hub/interface.py (100%) rename {src/qianfan/common/runnable => python/qianfan/common/prompt}/__init__.py (100%) rename {src => python}/qianfan/common/prompt/prompt.py (100%) rename {src/qianfan/extensions => python/qianfan/common/runnable}/__init__.py (100%) rename {src => python}/qianfan/common/runnable/base.py (100%) rename {src => python}/qianfan/common/tool/baidu_search_tool.py (100%) rename {src => python}/qianfan/common/tool/base_tool.py (100%) rename {src => python}/qianfan/common/tool/duckduckgo_search_tool.py (100%) rename {src => python}/qianfan/common/tool/wikipedia_tool.py (100%) rename {src => python}/qianfan/config.py (100%) rename {src => python}/qianfan/consts.py (100%) rename {src => python}/qianfan/dataset/__init__.py (100%) rename {src => python}/qianfan/dataset/consts.py (100%) rename src/qianfan/dataset/qianfan_data_operators.py => python/qianfan/dataset/data_operator.py (100%) rename {src => python}/qianfan/dataset/data_source/__init__.py (100%) rename {src => python}/qianfan/dataset/data_source/baidu_qianfan.py (100%) rename {src => python}/qianfan/dataset/data_source/base.py (100%) rename {src => python}/qianfan/dataset/data_source/bos.py (100%) rename {src => python}/qianfan/dataset/data_source/file.py (100%) rename {src => python}/qianfan/dataset/data_source/utils.py (100%) rename {src => python}/qianfan/dataset/dataset.py (100%) rename {src => python}/qianfan/dataset/dataset_utils.py (100%) rename {src => python}/qianfan/dataset/local_data_operators/__init__.py (100%) rename {src => python}/qianfan/dataset/local_data_operators/base.py (100%) rename {src => python}/qianfan/dataset/local_data_operators/check_character_repetition_filter.py (100%) rename {src => python}/qianfan/dataset/local_data_operators/check_flagged_words.py (100%) rename {src => python}/qianfan/dataset/local_data_operators/check_sentence_length_filter.py (100%) rename {src => python}/qianfan/dataset/local_data_operators/check_special_characters.py (100%) rename {src => python}/qianfan/dataset/local_data_operators/check_stopwords.py (100%) rename {src => python}/qianfan/dataset/local_data_operators/check_word_number.py (100%) rename {src => python}/qianfan/dataset/local_data_operators/check_word_repetition_filter.py (100%) rename {src => python}/qianfan/dataset/local_data_operators/consts.py (100%) rename {src => python}/qianfan/dataset/local_data_operators/utils.py (100%) rename {src => python}/qianfan/dataset/local_data_operators/word_list.py (100%) rename {src => python}/qianfan/dataset/process_interface.py (100%) create mode 100644 python/qianfan/dataset/qianfan_data_operators.py rename {src => python}/qianfan/dataset/schema.py (100%) rename {src => python}/qianfan/dataset/table.py (100%) rename {src => python}/qianfan/dataset/table_utils.py (100%) rename {src => python}/qianfan/errors.py (100%) rename {src => python}/qianfan/evaluation/__init__.py (100%) rename {src => python}/qianfan/evaluation/consts.py (100%) rename {src => python}/qianfan/evaluation/evaluation_manager.py (100%) rename {src => python}/qianfan/evaluation/evaluation_result.py (100%) rename {src => python}/qianfan/evaluation/evaluator.py (100%) rename {src => python}/qianfan/evaluation/opencompass_evaluator.py (100%) rename {src => python}/qianfan/extensions/README.md (100%) rename {src/qianfan/extensions/semantic_kernel/connectors => python/qianfan/extensions}/__init__.py (100%) rename {src => python}/qianfan/extensions/langchain/__init__.py (100%) rename {src => python}/qianfan/extensions/langchain/agents/__init__.py (100%) rename {src => python}/qianfan/extensions/langchain/agents/baidu_qianfan_endpoint.py (100%) rename {src => python}/qianfan/extensions/semantic_kernel/__init__.py (100%) rename {src/qianfan/resources/auth => python/qianfan/extensions/semantic_kernel/connectors}/__init__.py (100%) rename {src => python}/qianfan/extensions/semantic_kernel/connectors/qianfan_chat_completion.py (100%) rename {src => python}/qianfan/extensions/semantic_kernel/connectors/qianfan_settings.py (100%) rename {src => python}/qianfan/extensions/semantic_kernel/connectors/qianfan_text_completion.py (100%) rename {src => python}/qianfan/extensions/semantic_kernel/connectors/qianfan_text_embedding.py (100%) rename {src => python}/qianfan/model/__init__.py (100%) rename {src => python}/qianfan/model/configs.py (100%) rename {src => python}/qianfan/model/consts.py (100%) rename {src => python}/qianfan/model/model.py (100%) rename {src => python}/qianfan/py.typed (100%) rename {src => python}/qianfan/resources/__init__.py (100%) rename {src/qianfan/resources/console => python/qianfan/resources/auth}/__init__.py (100%) rename {src => python}/qianfan/resources/auth/iam.py (100%) rename {src => python}/qianfan/resources/auth/oauth.py (100%) rename {src/qianfan/resources/images => python/qianfan/resources/console}/__init__.py (100%) rename {src => python}/qianfan/resources/console/consts.py (100%) rename {src => python}/qianfan/resources/console/data.py (100%) rename {src => python}/qianfan/resources/console/finetune.py (100%) rename {src => python}/qianfan/resources/console/model.py (100%) rename {src => python}/qianfan/resources/console/prompt.py (100%) rename {src => python}/qianfan/resources/console/service.py (100%) rename {src => python}/qianfan/resources/console/utils.py (100%) rename {src => python}/qianfan/resources/http_client.py (100%) rename {src/qianfan/resources/llm => python/qianfan/resources/images}/__init__.py (100%) rename {src => python}/qianfan/resources/images/image2text.py (100%) rename {src => python}/qianfan/resources/images/text2image.py (100%) rename {src/qianfan/resources/requestor => python/qianfan/resources/llm}/__init__.py (100%) rename {src => python}/qianfan/resources/llm/base.py (100%) rename {src => python}/qianfan/resources/llm/chat_completion.py (100%) rename {src => python}/qianfan/resources/llm/completion.py (100%) rename {src => python}/qianfan/resources/llm/embedding.py (100%) rename {src => python}/qianfan/resources/llm/plugin.py (100%) rename {src => python}/qianfan/resources/rate_limiter.py (100%) rename {src/qianfan/resources/tools => python/qianfan/resources/requestor}/__init__.py (100%) rename {src => python}/qianfan/resources/requestor/base.py (100%) rename {src => python}/qianfan/resources/requestor/console_requestor.py (100%) rename {src => python}/qianfan/resources/requestor/openapi_requestor.py (100%) rename {src/qianfan/tests => python/qianfan/resources/tools}/__init__.py (100%) rename {src => python}/qianfan/resources/tools/tokenizer.py (100%) rename {src => python}/qianfan/resources/tools/utils.py (100%) rename {src => python}/qianfan/resources/typing.py (100%) rename {src/qianfan/tests/dataset => python/qianfan/tests}/__init__.py (100%) rename {src => python}/qianfan/tests/auth_test.py (100%) rename {src => python}/qianfan/tests/chat_completion_test.py (100%) rename {src => python}/qianfan/tests/completion_test.py (100%) rename {src => python}/qianfan/tests/config_test.py (100%) rename {src => python}/qianfan/tests/conftest.py (100%) rename {src => python}/qianfan/tests/data_api_test.py (100%) rename {src/qianfan/tests/langchain => python/qianfan/tests/dataset}/__init__.py (100%) rename {src => python}/qianfan/tests/dataset/data_source_test.py (100%) rename {src => python}/qianfan/tests/dataset/dataset_test.py (100%) rename {src => python}/qianfan/tests/dataset/table_test.py (100%) rename {src => python}/qianfan/tests/embedding_test.py (100%) rename {src => python}/qianfan/tests/finetune_test.py (100%) rename {src => python}/qianfan/tests/hub_test.py (100%) rename {src => python}/qianfan/tests/image2text_test.py (100%) create mode 100644 python/qianfan/tests/langchain/__init__.py rename {src => python}/qianfan/tests/langchain/agent_test.py (100%) rename {src => python}/qianfan/tests/latency_test.py (100%) rename {src => python}/qianfan/tests/model_test.py (100%) rename {src => python}/qianfan/tests/plugin_test.py (100%) rename {src => python}/qianfan/tests/prompt_class_test.py (100%) rename {src => python}/qianfan/tests/prompt_resource_test.py (100%) rename {src => python}/qianfan/tests/rate_limiter_test.py (100%) rename {src => python}/qianfan/tests/retry_test.py (100%) rename {src => python}/qianfan/tests/service_test.py (100%) rename {src => python}/qianfan/tests/text2image_test.py (100%) rename {src => python}/qianfan/tests/tokenizer_test.py (100%) rename {src => python}/qianfan/tests/tool_test.py (100%) rename {src => python}/qianfan/tests/trainer_test.py (100%) rename {src => python}/qianfan/tests/utils/__init__.py (100%) rename {src => python}/qianfan/tests/utils/mock_server.py (99%) rename {src => python}/qianfan/tests/utils/utils.py (100%) rename {src => python}/qianfan/trainer/__init__.py (100%) rename {src => python}/qianfan/trainer/actions.py (100%) rename {src => python}/qianfan/trainer/base.py (100%) rename {src => python}/qianfan/trainer/configs.py (100%) rename {src => python}/qianfan/trainer/consts.py (100%) rename {src => python}/qianfan/trainer/event.py (100%) rename {src => python}/qianfan/trainer/finetune.py (100%) rename {src => python}/qianfan/utils/__init__.py (100%) rename {src => python}/qianfan/utils/bos_uploader.py (100%) rename {src => python}/qianfan/utils/helper.py (100%) rename {src => python}/qianfan/utils/logging.py (100%) rename {src => python}/qianfan/utils/pydantic/__init__.py (100%) rename {src => python}/qianfan/utils/utils.py (100%) rename {src => python}/qianfan/version.py (85%) create mode 100644 python/scripts/build.sh create mode 100644 python/scripts/build_doc.sh rename {src => python}/scripts/release_github.sh (100%) create mode 100644 python/scripts/run_mock_server.sh rename {src => python}/scripts/run_test.sh (100%) delete mode 100644 src/scripts/build.sh delete mode 100644 src/scripts/build_doc.sh diff --git a/.github/workflows/doc_build.yml b/.github/workflows/doc_build.yml new file mode 100644 index 00000000..67a57d3d --- /dev/null +++ b/.github/workflows/doc_build.yml @@ -0,0 +1,60 @@ +on: + push: + branches: ['main'] + pull_request: + path: + - "python/**" + - "docs/**" + - "README.md" + workflow_dispatch: + +name: Docs + +defaults: + run: + shell: bash + working-directory: ./python + +env: + RUFF_OUTPUT_FORMAT: github + +jobs: + build: + name: Build + runs-on: ${{ matrix.os }} + strategy: + matrix: + os: + - 'ubuntu-latest' + python-version: + - '3.11' + steps: + - name: Checkout code + uses: actions/checkout@v4 + - name: Setup Python + uses: actions/setup-python@v5 + with: + python-version: ${{matrix.python-version}} + - name: Install Poetry + run: | + pip install poetry + echo "$HOME/.poetry/bin" >> $GITHUB_PATH + poetry lock + - name: Get Poetry version + run: poetry --version + - name: Setup Python Cache + uses: actions/setup-python@v5 + with: + python-version: ${{matrix.python-version}} + cache: "poetry" + - name: Install deps + run: | + make install + - name: Build docs + run: make doc + - name: Prepare artifact + run: tar czvf /tmp/docs.tar.gz ./output/docs/_build + - uses: actions/upload-artifact@v4 + with: + name: docs + path: /tmp/docs.tar.gz diff --git a/.github/workflows/doc_release.yml b/.github/workflows/doc_release.yml new file mode 100644 index 00000000..51f3fb4f --- /dev/null +++ b/.github/workflows/doc_release.yml @@ -0,0 +1,52 @@ +# Simple workflow for deploying static content to GitHub Pages +name: Release docs + +on: + # Allows you to run this workflow manually from the Actions tab + workflow_dispatch: + workflow_run: + workflows: [Go Release, Python Release] + types: + - completed + +# Sets permissions of the GITHUB_TOKEN to allow deployment to GitHub Pages +permissions: + contents: read + pages: write + id-token: write + +# Allow only one concurrent deployment, skipping runs queued between the run in-progress and latest queued. +# However, do NOT cancel in-progress runs as we want to allow these production deployments to complete. +concurrency: + group: "pages" + cancel-in-progress: false + +jobs: + docs: + environment: + name: github-pages + url: ${{ steps.deployment.outputs.page_url }} + runs-on: ubuntu-latest + steps: + - name: Checkout + uses: actions/checkout@v4 + - name: Setup Pages + uses: actions/configure-pages@v3 + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: "3.11" + - name: Install + run: | + pip3 install poetry + make install + - name: Build docs + run: | + make doc + - name: Upload artifact + uses: actions/upload-pages-artifact@v2 + with: + path: './output/docs/_build/html' + - name: Deploy to GitHub Pages + id: deployment + uses: actions/deploy-pages@v2 diff --git a/.github/workflows/go_ci.yml b/.github/workflows/go_ci.yml new file mode 100644 index 00000000..8bc4f498 --- /dev/null +++ b/.github/workflows/go_ci.yml @@ -0,0 +1,65 @@ +on: + push: + branches: ['main'] + pull_request: + path: + - "go/**" + workflow_dispatch: + +name: Go CI + +defaults: + run: + shell: bash + +jobs: + build: + name: Unit tests + runs-on: ${{ matrix.os }} + strategy: + matrix: + os: + - 'ubuntu-latest' + - 'macos-latest' + - 'windows-latest' + go-version: [ '1.19', '1.20', '1.21.x' ] + python-version: [ '3.11' ] + steps: + - name: Checkout code + uses: actions/checkout@v4 + - name: Setup Python + uses: actions/setup-python@v5 + with: + python-version: ${{matrix.python-version}} + - name: Install Poetry + run: | + pip install poetry + echo "$HOME/.poetry/bin" >> $GITHUB_PATH + cd python && poetry lock + - name: Setup Python Cache + uses: actions/setup-python@v5 + with: + python-version: ${{matrix.python-version}} + cache: "poetry" + - name: Setup Go ${{ matrix.go-version }} + uses: actions/setup-go@v4 + with: + go-version: ${{ matrix.go-version }} + cache-dependency-path: go/qianfan/go.mod + - name: Display version + run: | + go version + python --version + poetry --version + - name: Install deps + run: | + make install + - name: Run test + run: | + make mock + make -C go test + - name: golangci-lint + uses: golangci/golangci-lint-action@v3 + with: + version: latest + working-directory: go/qianfan diff --git a/.github/workflows/go_release.yml b/.github/workflows/go_release.yml new file mode 100644 index 00000000..aa28f8a9 --- /dev/null +++ b/.github/workflows/go_release.yml @@ -0,0 +1,78 @@ +name: Go Release + +on: + workflow_dispatch: + +defaults: + run: + shell: bash + +jobs: + build: + name: Release + runs-on: ${{ matrix.os }} + strategy: + matrix: + os: + - 'ubuntu-latest' + go-version: [ '1.21.x' ] + python-version: [ '3.11' ] + steps: + - name: Checkout code + uses: actions/checkout@v4 + - name: Setup Go ${{ matrix.go-version }} + uses: actions/setup-go@v4 + with: + go-version: ${{ matrix.go-version }} + - name: Setup Python + uses: actions/setup-python@v5 + with: + python-version: ${{matrix.python-version}} + - name: Display version + run: | + go version + python --version + - name: Install latest version of Poetry + if: steps.cache-poetry.outputs.cache-hit != 'true' + run: | + pip install poetry + - name: Add Poetry to $PATH + run: | + echo "$HOME/.poetry/bin" >> $GITHUB_PATH + - name: Get Poetry version + run: poetry --version + - name: Install deps + if: steps.cache-deps.cache-hit != 'true' + run: | + make install + - name: Run test + run: | + make mock + make -C go test + - name: golangci-lint + uses: golangci/golangci-lint-action@v3 + with: + version: latest + working-directory: go/qianfan + - name: Check version + id: check-version + run: | + VERSION=$(cat go/qianfan/version.go|grep 'const Version ='|sed -r "s/.*\"v(.*)\"/\1/g") + echo $VERSION + echo "version=$VERSION" >> "$GITHUB_OUTPUT" + - name: Create Release + uses: ncipollo/release-action@v1 + with: + name: ${{ format('Go v{0}', steps.check-version.outputs.version) }} + token: ${{ secrets.QF_GITHUB_TOKEN }} + draft: false + generateReleaseNotes: true + tag: ${{ format('go/qianfan/v{0}', steps.check-version.outputs.version) }} + prerelease: ${{ contains(steps.check-version.outputs.version, 'rc') }} + commit: main + - name: Update Go module index + env: + GOPROXY: proxy.golang.org + VERSION: ${{ steps.check-version.outputs.version }} + run: | + go list -m github.com/baidubce/bce-qianfan-sdk/go/qianfan@v$VERSION diff --git a/.github/workflows/pyci.yml b/.github/workflows/py_ci.yml similarity index 63% rename from .github/workflows/pyci.yml rename to .github/workflows/py_ci.yml index f36d4066..10ac628a 100644 --- a/.github/workflows/pyci.yml +++ b/.github/workflows/py_ci.yml @@ -2,20 +2,23 @@ on: push: branches: ['main'] pull_request: + path: + - "python/**" workflow_dispatch: -name: Build +name: Python CI defaults: run: shell: bash + working-directory: ./python env: RUFF_OUTPUT_FORMAT: github jobs: build: - name: Build tool + name: Unit tests runs-on: ${{ matrix.os }} strategy: matrix: @@ -33,23 +36,22 @@ jobs: - name: Checkout code uses: actions/checkout@v4 - name: Setup Python - uses: actions/setup-python@v2.2.2 + uses: actions/setup-python@v5 with: python-version: ${{matrix.python-version}} - # Only runs when key from caching step changes - - name: Install latest version of Poetry - if: steps.cache-poetry.outputs.cache-hit != 'true' + - name: Install Poetry run: | pip install poetry - # Poetry still needs to be re-prepended to the PATH on each run, since - # PATH does not persist between runs. - - name: Add Poetry to $PATH - run: | echo "$HOME/.poetry/bin" >> $GITHUB_PATH + poetry lock - name: Get Poetry version run: poetry --version + - name: Setup Python Cache + uses: actions/setup-python@v5 + with: + python-version: ${{matrix.python-version}} + cache: "poetry" - name: Install deps - if: steps.cache-deps.cache-hit != 'true' run: | make install - name: run lint @@ -58,3 +60,7 @@ jobs: run: make test - name: Build artifacts run: make build + - uses: actions/upload-artifact@v4 + if: ${{ matrix.python-version == '3.11' && matrix.os == 'ubuntu-latest' }} + with: + path: ${{ github.workspace }}/python/output/** diff --git a/.github/workflows/release.yml b/.github/workflows/py_release.yml similarity index 55% rename from .github/workflows/release.yml rename to .github/workflows/py_release.yml index ae0f6597..dc2cb9c2 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/py_release.yml @@ -1,30 +1,29 @@ # Simple workflow for deploying static content to GitHub Pages -name: Release docs & dist +name: Python Release on: # Allows you to run this workflow manually from the Actions tab workflow_dispatch: -# Sets permissions of the GITHUB_TOKEN to allow deployment to GitHub Pages -permissions: - contents: read - pages: write - id-token: write - # Allow only one concurrent deployment, skipping runs queued between the run in-progress and latest queued. # However, do NOT cancel in-progress runs as we want to allow these production deployments to complete. concurrency: group: "pages" cancel-in-progress: false +defaults: + run: + shell: bash + working-directory: ./python + jobs: release: runs-on: ubuntu-latest steps: - name: Checkout - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Set up Python - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: "3.11" - name: Install @@ -36,22 +35,26 @@ jobs: run: make test - name: Build artifacts run: make build + - name: Build docs + run: make doc - name: Check Version id: check-version run: | echo version=$(poetry version --short) >> $GITHUB_OUTPUT - name: cleanup before release run: | - zip -r ./output/docs.zip ./output/docs/html + zip -r ./output/docs.zip ./output/docs/_build/html rm -rf ./output/docs - name: Create Release uses: ncipollo/release-action@v1 with: + name: ${{ format('Python v{0}', steps.check-version.outputs.version) }} artifacts: "output/*" token: ${{ secrets.QF_GITHUB_TOKEN }} draft: false generateReleaseNotes: true - tag: ${{ steps.check-version.outputs.version }} + tag: ${{ format('py/v{0}', steps.check-version.outputs.version) }} + prerelease: ${{ contains(steps.check-version.outputs.version, 'rc') }} commit: main - name: release to pypi env: @@ -59,33 +62,5 @@ jobs: run: | export TAG_NAME=${{ steps.check-version.outputs.version }} echo "Triggered by tag: $TAG_NAME" - bash ./src/scripts/release_github.sh - # Single deploy job since we're just deploying - docs: - environment: - name: github-pages - url: ${{ steps.deployment.outputs.page_url }} - runs-on: ubuntu-latest - steps: - - name: Checkout - uses: actions/checkout@v3 - - name: Setup Pages - uses: actions/configure-pages@v3 - - name: Set up Python - uses: actions/setup-python@v4 - with: - python-version: "3.11" - - name: Install - run: | - pip3 install poetry - poetry install - - name: Build docs - run: | - make doc - - name: Upload artifact - uses: actions/upload-pages-artifact@v2 - with: - path: './build/docs/_build/html' - - name: Deploy to GitHub Pages - id: deployment - uses: actions/deploy-pages@v2 \ No newline at end of file + bash ./scripts/release_github.sh + diff --git a/.gitignore b/.gitignore index ef50b2f9..daf32c92 100644 --- a/.gitignore +++ b/.gitignore @@ -58,4 +58,5 @@ poetry.lock .*_cache .env* .ipynb_checkpoints -.DS_Store \ No newline at end of file +.DS_Store +coverage.out \ No newline at end of file diff --git a/.readthedocs.yaml b/.readthedocs.yaml index 342e72fa..b517da19 100644 --- a/.readthedocs.yaml +++ b/.readthedocs.yaml @@ -14,13 +14,13 @@ build: - pip install poetry - poetry config virtualenvs.create false post_install: - - poetry install --no-interaction + - make install pre_build: - make doc # Build documentation in the "docs/" directory with Sphinx sphinx: - configuration: build/docs/conf.py + configuration: output/docs/conf.py # You can configure Sphinx to use a different builder, for instance use the dirhtml builder for simpler URLs # builder: "dirhtml" # Fail on all warnings to avoid broken references diff --git a/CHANGELOG.md b/CHANGELOG.md deleted file mode 100644 index 3fc2c3d7..00000000 --- a/CHANGELOG.md +++ /dev/null @@ -1,79 +0,0 @@ -# Release Note - -## 0.1.0 - -Feature: -- 增加ChatCompletion,Completion -对于[ERNIE-Bot-4](https://cloud.baidu.com/doc/WENXINWORKSHOP/s/clntwmv7t)的支持,通过`model=ERNIE-Bot-4`进行调用 -- `qianfan` SDK正式开源 -- 增加对 `模型管理` API 支持(`qianfan.Model`) -- 增加对 `模型服务发布` API支持 (`qianfan.Service`) - -- 开放[SDK API接口文档](https://qianfan.readthedocs.io/en/stable/qianfan.html) - -Bug修复: -- 修复同时使用model,endpoint导致的模型使用错误的问题 - -优化: -- 增加 `模型SFT调优` 任务控制支持 -- 限流接口优化,支持传入QPS限制配置 - -## 0.0.7 - -Feature: -- 增加对 SFT 相关管控类 API 支持(`qianfan.FineTune`) -- 增加模型相关接口的限流功能 -Bug修复: -- 修复使用Completion时,通过参数传递 AK、SK 仍会提示未找到 AK、SK 的问题 -- 修复使用Completion时传model+endpoint导致异常的问题 -优化: -- 对 Auth 模块进行重构,同时支持同步及异步请求 -- 默认不再使用 ERNIE-Bot-SDK -- 更新错误码 - -## 0.0.6 - -Bug修复: - -- 修复 Embedding 内置模型 endpoint 错误的问题 -- 修复 Completion 同时提供 chat 模型和自定义 endpoint 时,无法正确使用 chat 模型进行模拟的问题 -- 修复 QfMessage 没有正确处理 QfRole 的问题 - -## 0.0.5 - -> commit 149b5d76...01e68dce - -Bug修复: -- 修复为同一个AK、SK第二次手动设置AccessToken时不生效的问题 -- 重构重试逻辑 - - 修复API返回错误时,非高负载的错误异常未抛出的问题 - - 修复多次重试失败后,没有异常抛出的问题 -- 修复打印 warn 和 error 级别日志时打印过多trace的问题 -- 修复无法正确识别AccessTokenExpired错误的问题 - -Feature: -- QfMessage支持function_call类型的消息 -- 增加控制SDK log级别的功能 qianfan.enable_log(logging.INFO)/ qianfan.disable_log(),默认 WARN 级别 -- 增加对 EBSDK 是否安装的检测,若未安装则回退全部使用千帆SDK实现 - -周边工具相关: -- 单元测试 - - 切换至 pytest,对测试整体架构进行重构 - - 增加 mock_server 用于测试 - - 增加对 auth 相关的测试用例 - - 增加测试相关脚本 - - make test即可进行测试,并打印覆盖率的报告 -- 构建相关 - - 采用 Poetry 进行包管理(pyproject.toml) - - setuptools 被标记 depracated,新的python库(如langchain)都采用这一工具进行管理 - - 支持区分包的依赖和开发时的依赖 - - pypi上可以设置例如主页等更为详细的元信息 - - 包版本的设置统一在 pyproject.toml 文件中设置,代码中qianfan.__version__修改成动态获取 - - 增加构建脚本 - - make build即可生成whl以及文档,输出在 output 目录中 - - make doc可以生成文档 - - 流水线上构建可以产出产物 -- 流水线相关 - - 支持流水线上进行单测(流水线python版本3.7,可测试兼容性) - - 支持发布流水线 - - 拉出 release-{VERSION}分支后可手动触发 \ No newline at end of file diff --git a/Makefile b/Makefile index 2411c5b9..8488c27e 100644 --- a/Makefile +++ b/Makefile @@ -1,29 +1,38 @@ -build: - bash src/scripts/build.sh +prepare_output: + mkdir -p output + +build: prepare_output + $(MAKE) -C python build + mv python/output/* ./output + rm -rf python/output install: - poetry install -E all + $(MAKE) -C python install uninstall: pip uninstall -y qianfan clean: - rm -rf build output dist qianfan.egg-info + rm -rf output + $(MAKE) -C python clean -doc: install - poetry run bash src/scripts/build_doc.sh +doc: install prepare_output + $(MAKE) -C python doc + rm -rf output/docs + mv python/output/* ./output + rm -rf python/output format: install - poetry run black ./src/qianfan - poetry run ruff --select I --fix ./src/qianfan + $(MAKE) -C python format lint: install - poetry run black ./src/qianfan --check - poetry run ruff check ./src/qianfan - poetry run mypy ./src/qianfan --install-types --non-interactive + $(MAKE) -C python lint test: clean install - cd src && bash scripts/run_test.sh + $(MAKE) -C python test + $(MAKE) -C go test +mock: + bash ./python/scripts/run_mock_server.sh -.PHONY: build install uninstall clean \ No newline at end of file +.PHONY: build install uninstall clean diff --git a/src/qianfan/docs/conf.py_t b/docs/template/conf.py_t similarity index 100% rename from src/qianfan/docs/conf.py_t rename to docs/template/conf.py_t diff --git a/src/qianfan/docs/root_doc.rst_t b/docs/template/root_doc.rst_t similarity index 100% rename from src/qianfan/docs/root_doc.rst_t rename to docs/template/root_doc.rst_t diff --git a/go/Makefile b/go/Makefile new file mode 100644 index 00000000..97feb83b --- /dev/null +++ b/go/Makefile @@ -0,0 +1,2 @@ +test: + cd qianfan && go test -race -timeout=120s -v -coverprofile=coverage.out ./... \ No newline at end of file diff --git a/go/README.md b/go/README.md new file mode 100644 index 00000000..0bbd7f53 --- /dev/null +++ b/go/README.md @@ -0,0 +1,186 @@ +# 百度千帆大模型平台 Go SDK + +## 如何使用 + +首先可以通过如下命令安装 SDK: + +``` +go get github.com/baidubce/bce-qianfan-sdk/go/qianfan +``` + +之后就可以在代码中通过如下方式引入 SDK: + +``` +import ( + "github.com/baidubce/bce-qianfan-sdk/go/qianfan" +) +``` + +### 鉴权 + +在使用千帆 SDK 之前,用户需要 [百度智能云控制台 - 安全认证](https://console.bce.baidu.com/iam/#/iam/accesslist) 页面获取 Access Key 与 Secret Key,并在 [千帆控制台](https://console.bce.baidu.com/qianfan/ais/console/applicationConsole/application) 中创建应用,选择需要启用的服务,具体流程参见平台 [说明文档](https://cloud.baidu.com/doc/Reference/s/9jwvz2egb)。 + +SDK 支持从当前目录的 `.env` 中读取配置,也可以通过环境变量 `QIANFAN_ACCESS_KEY` 和 `QIANFAN_SECRET_KEY` 获取配置,这一步骤会在使用 SDK 时自动完成。 + +```bash +export QIANFAN_ACCESS_KEY=your_access_key +export QIANFAN_SECRET_KEY=your_secret_key +``` + +同时,也可以在代码中手动设置 `AccessKey` 和 `SecretKey`,具体如下: + +```go +qianfan.GetConfig().AccessKey = "your_access_key" +qianfan.GetConfig().SecretKey = "your_secret_key" +``` + +### Chat 对话 + +可以使用 `ChatCompletion` 对象完成对话相关操作,可以通过如下方法获取一个 `ChatCompletion` 对象: + +```go +chat := qianfan.NewChatCompletion() // 默认使用 ERNIE-Bot-turbo 模型 + +// 可以通过 WithModel 指定模型 +chat := qianfan.NewChatCompletion( + qianfan.WithModel("ERNIE-Bot-4"), // 支持的模型可以通过 chat.ModelList() 获取 +) +// 或者通过 WithEndpoint 指定 endpoint +chat := qianfan.NewChatCompletion( + qianfan.WithEndpoint("your_custom_endpoint"), +) +``` + +之后就可以通过 `Do` 方法进行对话: + +``` +resp, err := chat.Do( + context.TODO(), + &qianfan.ChatCompletionRequest{ + Messages: []qianfan.ChatCompletionMessage{ + qianfan.ChatCompletionUserMessage("你好"), + }, + }, +) +if err != nil { + fmt.Print(err) +} +fmt.Print(resp.Result) +``` + +也可以调用 `Stream` 方法实现流式返回 + +```go +chat := client.ChatCompletion() + +resp, err := chat.Stream( // Stream 启用流式返回,参数与 Do 相同 + context.TODO(), + &qianfan.ChatCompletionRequest{ + Messages: []qianfan.ChatCompletionMessage{ + qianfan.ChatCompletionUserMessage("你好"), + }, + }, +) +if err != nil { + return err +} +for { + r, err := resp.Recv() + if err != nil { + return err + } + if resp.IsEnd { // 判断是否结束 + break + } + fmt.Print(r.Result) +} +``` + +### Completion 续写 + +对于不需要对话,仅需要根据 prompt 进行补全的场景来说,用户可以使用 `Completion` 来完成这一任务。 + +```go +completion := qianfan.NewCompletion() // 默认使用 ERNIE-Bot-turbo 模型 + +// 可以通过 WithModel 指定模型 +completion := qianfan.NewCompletion( + qianfan.WithModel("ERNIE-Bot-4"), + // 支持的模型可以通过 completion.ModelList() 获取 +) +// 或者通过 WithEndpoint 指定 endpoint +completion := qianfan.NewCompletion( + qianfan.WithEndpoint("your_custom_endpoint"), +) +``` + +与对话相同,可以调用 `Do` 方法实现续写 + +```go +resp, err := completion.Do( + context.TODO(), + &CompletionRequest{ + Prompt: prompt, + } +) +if err != nil { + return err +} +fmt.Printf(resp.Result) // 模型返回的结果 +``` + +也可以调用 `Stream` 方法实现流式返回 + +```go +resp, err := completion.Stream( // Stream 启用流式返回,参数与 Do 相同 + context.TODO(), + &CompletionRequest{ + Prompt: prompt, + } +) +if err != nil { + return err +} +for { + r, err := resp.Recv() + if err != nil { + return err + } + if resp.IsEnd { // 判断是否结束 + break + } + fmt.Print(r.Result) +} +``` + +### Embedding 向量化 + +千帆 SDK 同样支持调用千帆大模型平台中的模型,将输入文本转化为用浮点数表示的向量形式。转化得到的语义向量可应用于文本检索、信息推荐、知识挖掘等场景。 + +```go +embed := qianfan.NewEmbedding() // 默认使用 Embedding-V1 模型 + +// 可以通过 WithModel 指定模型 +embed := qianfan.NewEmbedding( + qianfan.WithModel("ERNIE-Bot-4"), // 支持的模型可以通过 embed.ModelList() 获取 +) +// 或者通过 WithEndpoint 指定 endpoint +embed := qianfan.NewEmbedding( + qianfan.WithEndpoint("your_custom_endpoint"), +) +``` + +之后使用 `Do` 方法进行调用 + +```go +resp, err := embed.Do( + context.TODO(), + &EmbeddingRequest{ + Input: []string{"hello1", "hello2"}, + } +) +if err != nil { + return err +} +embed := resp.Data[0].Embedding // 获取第一个输入的向量 +``` diff --git a/go/qianfan/base_model.go b/go/qianfan/base_model.go new file mode 100644 index 00000000..203a364f --- /dev/null +++ b/go/qianfan/base_model.go @@ -0,0 +1,101 @@ +// Copyright (c) 2024 Baidu, Inc. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package qianfan + +import "fmt" + +// 模型相关的结构体基类 +type BaseModel struct { + Model string // 使用的模型名称 + Endpoint string // 使用的模型服务地址 + *Requestor // Requstor 作为基类 +} + +// 使用量信息 +type ModelUsage struct { + PromptTokens int `json:"prompt_tokens"` // 问题tokens数 + CompletionTokens int `json:"completion_tokens"` // 回答tokens数 + TotalTokens int `json:"total_tokens"` // tokens总数 +} + +type ModelAPIResponse interface { + GetError() (int, string) +} + +// API 错误信息 +type ModelAPIError struct { + ErrorCode int `json:"error_code"` // 错误码 + ErrorMsg string `json:"error_msg"` // 错误消息 +} + +func (e *ModelAPIError) GetError() (int, string) { + return e.ErrorCode, e.ErrorMsg +} + +// 搜索结果 +type SearchResult struct { + Index int `json:"index"` // 序号 + URL string `json:"url"` // 搜索结果URL + Title string `json:"title"` // 搜索结果标题 +} + +// 搜索结果列表 +type SearchInfo struct { + SearchResults []SearchResult `json:"search_results"` // 搜索结果列表 +} + +// 模型响应的结果 +type ModelResponse struct { + Id string `json:"id"` // 本轮对话的id + Object string `json:"object"` // 回包类型 + Created int `json:"created"` // 时间戳 + SentenceId int `json:"sentence_id"` // 表示当前子句的序号。只有在流式接口模式下会返回该字段 + IsEnd bool `json:"is_end"` // 表示当前子句是否是最后一句。只有在流式接口模式下会返回该字段 + IsTruncated bool `json:"is_truncated"` // 当前生成的结果是否被截断 + Result string `json:"result"` // 对话返回结果 + NeedClearHistory bool `json:"need_clear_history"` // 表示用户输入是否存在安全风险,是否关闭当前会话,清理历史会话信息 + Usage ModelUsage `json:"usage"` // token统计信息 + FunctionCall *FunctionCall `json:"function_call"` // 由模型生成的函数调用,包含函数名称,和调用参数 + BanRound int `json:"ban_round"` // 当need_clear_history为true时,此字段会告知第几轮对话有敏感信息,如果是当前问题,ban_round=-1 + SearchInfo *SearchInfo `json:"search_info"` // 搜索数据,当请求参数enable_citation为true并且触发搜索时,会返回该字段 + ModelAPIError // API 错误信息 + baseResponse // 通用的响应信息 +} + +// 用于获取ModelResponse流式结果的结构体 +type ModelResponseStream struct { + *streamInternal +} + +// 获取ModelResponse流式结果 +func (s *ModelResponseStream) Recv() (*ModelResponse, error) { + var resp ModelResponse + err := s.streamInternal.Recv(&resp) + if err != nil { + return nil, err + } + if err = checkResponseError(&resp); err != nil { + return &resp, err + } + return &resp, nil +} + +func checkResponseError(resp ModelAPIResponse) error { + errCode, errMsg := resp.GetError() + if errCode != 0 { + return fmt.Errorf("API return error. code: %d, msg: %s", errCode, errMsg) + } + return nil +} diff --git a/go/qianfan/chat_completion.go b/go/qianfan/chat_completion.go new file mode 100644 index 00000000..3ada973c --- /dev/null +++ b/go/qianfan/chat_completion.go @@ -0,0 +1,223 @@ +// Copyright (c) 2024 Baidu, Inc. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package qianfan + +import ( + "context" + "fmt" +) + +// 表示对话内容的结构体 +type ChatCompletionMessage struct { + Role string `json:"role"` // 角色,可选 "user", "assistant", "function" + Content string `json:"content"` // 对话内容 + Name string `json:"name,omitempty"` // message 作者 + FunctionCall *FunctionCall `json:"function_call,omitempty"` // 函数调用 +} + +// 用于 chat 类型模型的结构体 +type ChatCompletion struct { + BaseModel +} + +// 函数调用的结构体 +type FunctionCall struct { + Name string `json:"name"` // 触发的function名 + Arguments string `json:"arguments"` // 请求参数 + Thoughts string `json:"thoughts,omitempty"` // 模型思考过程 +} + +// function调用的示例 +type FunctionExample struct { + Role string `json:"role"` // 角色,可选 "user", "assistant", "function" + Content string `json:"content"` // 对话内容 + Name string `json:"name,omitempty"` // message 作者 + FunctionCall *FunctionCall `json:"function_call,omitempty"` // 函数调用 +} + +// 表示函数的结构体 +type Function struct { + Name string `json:"name"` // 函数名 + Description string `json:"description"` // 函数描述 + Parameters any `json:"parameters"` // 函数请求参数 + Responses any `json:"responses,omitempty"` // 函数响应参数 + Examples [][]FunctionExample `json:"examples,omitempty"` // function调用的一些历史示例 +} + +// 可选的工具 +type ToolChoice struct { + Type string `json:"type"` // 指定工具类型 + Function *Function `json:"function"` // 指定要使用的函数 + Name string `json:"name"` // 指定要使用的函数名 +} + +// chat 模型的请求结构体 +type ChatCompletionRequest struct { + BaseRequestBody `mapstructure:"-"` + Messages []ChatCompletionMessage `mapstructure:"messages"` // 聊天上下文信息 + Temperature float64 `mapstructure:"temperature,omitempty"` // 较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定,范围 (0, 1.0],不能为0 + TopP float64 `mapstructure:"top_p,omitempty"` // 影响输出文本的多样性,取值越大,生成文本的多样性越强。取值范围 [0, 1.0] + PenaltyScore float64 `mapstructure:"penalty_score,omitempty"` // 通过对已生成的token增加惩罚,减少重复生成的现象。说明:值越大表示惩罚越大,取值范围:[1.0, 2.0] + System string `mapstructure:"system,omitempty"` // 模型人设,主要用于人设设定 + Stop []string `mapstructure:"stop,omitempty"` // 生成停止标识,当模型生成结果以stop中某个元素结尾时,停止文本生成 + DisableSearch bool `mapstructure:"disable_search,omitempty"` // 是否强制关闭实时搜索功能 + EnableCitation bool `mapstructure:"enable_citation,omitempty"` // 是否开启上角标返回 + MaxOutputTokens int `mapstructure:"max_output_tokens,omitempty"` // 指定模型最大输出token数 + ResponseFormat string `mapstructure:"response_format,omitempty"` // 指定响应内容的格式 + UserID string `mapstructure:"user_id,omitempty"` // 表示最终用户的唯一标识符 + Functions []Function `mapstructure:"functions,omitempty"` // 一个可触发函数的描述列表 + ToolChoice *ToolChoice `mapstructure:"tool_choice,omitempty"` // 在函数调用场景下,提示大模型选择指定的函数 +} + +// 内置 chat 模型的 endpoint +var ChatModelEndpoint = map[string]string{ + "ERNIE-Bot-turbo": "/chat/eb-instant", + "ERNIE-Bot": "/chat/completions", + "ERNIE-Bot-4": "/chat/completions_pro", + "ERNIE-Bot-8k": "/chat/ernie_bot_8k", + "ERNIE-Speed": "/chat/eb_speed", + "ERNIE-Bot-turbo-AI": "/chat/ai_apaas", + "EB-turbo-AppBuilder": "/chat/ai_apaas", + "BLOOMZ-7B": "/chat/bloomz_7b1", + "Llama-2-7b-chat": "/chat/llama_2_7b", + "Llama-2-13b-chat": "/chat/llama_2_13b", + "Llama-2-70b-chat": "/chat/llama_2_70b", + "Qianfan-BLOOMZ-7B-compressed": "/chat/qianfan_bloomz_7b_compressed", + "Qianfan-Chinese-Llama-2-7B": "/chat/qianfan_chinese_llama_2_7b", + "ChatGLM2-6B-32K": "/chat/chatglm2_6b_32k", + "AquilaChat-7B": "/chat/aquilachat_7b", + "XuanYuan-70B-Chat-4bit": "/chat/xuanyuan_70b_chat", + "Qianfan-Chinese-Llama-2-13B": "/chat/qianfan_chinese_llama_2_13b", + "ChatLaw": "/chat/chatlaw", + "Yi-34B-Chat": "/chat/yi_34b_chat", +} + +// 创建一个 User 的消息 +func ChatCompletionUserMessage(message string) ChatCompletionMessage { + return ChatCompletionMessage{ + Role: "user", + Content: message, + } +} + +// 创建一个 Assistant 的消息 +func ChatCompletionAssistantMessage(message string) ChatCompletionMessage { + return ChatCompletionMessage{ + Role: "assistant", + Content: message, + } +} + +// 内部根据 options 创建一个 ChatCompletion 对象 +func newChatCompletion(options *Options) *ChatCompletion { + chat := &ChatCompletion{ + BaseModel{ + Model: DefaultChatCompletionModel, + Endpoint: "", + Requestor: newRequestor(options), + }, + } + if options.Model != nil { + chat.Model = *options.Model + } + if options.Endpoint != nil { + chat.Endpoint = *options.Endpoint + } + return chat +} + +// 将 endpoint 转换成完整的 url +func (c *ChatCompletion) realEndpoint() (string, error) { + url := modelAPIPrefix + if c.Model != "" { + endpoint, ok := ChatModelEndpoint[c.Model] + if !ok { + return "", fmt.Errorf("model %s is not supported", c.Model) + } + url += endpoint + } else { + url += "/chat/" + c.Endpoint + } + return url, nil +} + +// 发送 chat 请求 +func (c *ChatCompletion) Do(ctx context.Context, request *ChatCompletionRequest) (*ModelResponse, error) { + url, err := c.realEndpoint() + if err != nil { + return nil, err + } + req, err := newModelRequest("POST", url, request) + if err != nil { + return nil, err + } + var resp ModelResponse + err = c.Requestor.request(req, &resp) + if err != nil { + return nil, err + } + if err = checkResponseError(&resp); err != nil { + return &resp, err + } + return &resp, nil +} + +// 发送流式请求 +func (c *ChatCompletion) Stream(ctx context.Context, request *ChatCompletionRequest) (*ModelResponseStream, error) { + url, err := c.realEndpoint() + if err != nil { + return nil, err + } + request.SetStream() + req, err := newModelRequest("POST", url, request) + if err != nil { + return nil, err + } + stream, err := c.Requestor.requestStream(req) + if err != nil { + return nil, err + } + return &ModelResponseStream{ + streamInternal: stream, + }, nil +} + +// chat 支持的模型列表 +func (c *ChatCompletion) ModelList() []string { + i := 0 + list := make([]string, len(ChatModelEndpoint)) + for k := range ChatModelEndpoint { + list[i] = k + i++ + } + return list +} + +// 创建一个 ChatCompletion 对象 +// +// chat := qianfan.NewChatCompletion() // 默认使用 ERNIE-Bot-turbo 模型 +// +// // 可以通过 WithModel 指定模型 +// chat := qianfan.NewChatCompletion( +// qianfan.WithModel("ERNIE-Bot-4"), // 支持的模型可以通过 chat.ModelList() 获取 +// ) +// // 或者通过 WithEndpoint 指定 endpoint +// chat := qianfan.NewChatCompletion( +// qianfan.WithEndpoint("your_custom_endpoint"), +// ) +func NewChatCompletion(optionList ...Option) *ChatCompletion { + options := makeOptions(optionList...) + return newChatCompletion(options) +} diff --git a/go/qianfan/chat_completion_test.go b/go/qianfan/chat_completion_test.go new file mode 100644 index 00000000..c95c3823 --- /dev/null +++ b/go/qianfan/chat_completion_test.go @@ -0,0 +1,115 @@ +// Copyright (c) 2024 Baidu, Inc. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package qianfan + +import ( + "context" + "encoding/json" + "io" + "net/http" + "os" + "strings" + "testing" + + "github.com/mitchellh/mapstructure" + "github.com/stretchr/testify/assert" +) + +func TestChatCompletion(t *testing.T) { + for model, endpoint := range ChatModelEndpoint { + chat := NewChatCompletion(WithModel(model)) + resp, err := chat.Do( + context.Background(), + &ChatCompletionRequest{ + Messages: []ChatCompletionMessage{ + ChatCompletionUserMessage("你好"), + }, + }, + ) + assert.NoError(t, err) + assert.Equal(t, resp.RawResponse.StatusCode, 200) + assert.NotEqual(t, resp.Id, nil) + assert.Equal(t, resp.Object, "chat.completion") + assert.True(t, strings.Contains(resp.RawResponse.Request.URL.Path, endpoint)) + assert.True(t, strings.Contains(resp.Result, "你好")) + + req, err := getRequestBody[ChatCompletionRequest](resp.RawResponse) + assert.NoError(t, err) + assert.Equal(t, req.Messages[0].Content, "你好") + } +} + +func TestChatCompletionStream(t *testing.T) { + for model, endpoint := range ChatModelEndpoint { + chat := NewChatCompletion(WithModel(model)) + resp, err := chat.Stream( + context.Background(), + &ChatCompletionRequest{ + Messages: []ChatCompletionMessage{ + ChatCompletionUserMessage("你好"), + }, + }, + ) + assert.NoError(t, err) + turn_count := 0 + for { + r, err := resp.Recv() + assert.NoError(t, err) + if resp.IsEnd { + break + } + turn_count++ + assert.Equal(t, r.RawResponse.StatusCode, 200) + assert.NotEqual(t, r.Id, nil) + assert.Equal(t, r.Object, "chat.completion") + assert.Contains(t, r.RawResponse.Request.URL.Path, endpoint) + assert.Contains(t, r.Result, "你好") + req, err := getRequestBody[ChatCompletionRequest](r.RawResponse) + assert.NoError(t, err) + assert.Equal(t, req.Messages[0].Content, "你好") + } + assert.True(t, turn_count > 1) + } +} + +func TestMain(m *testing.M) { + os.Setenv("QIANFAN_BASE_URL", "http://127.0.0.1:8866") + os.Setenv("QIANFAN_ACCESS_KEY", "test_access_key") + os.Setenv("QIANFAN_SECRET_KEY", "test_secret_key") + + os.Exit(m.Run()) +} + +func getRequestBody[T any](response *http.Response) (*T, error) { + var body T + rawBody, err := response.Request.GetBody() + if err != nil { + return nil, err + } + bodyBytes, err := io.ReadAll(rawBody) + if err != nil { + return nil, err + } + bodyMap := make(map[string]interface{}) + err = json.Unmarshal(bodyBytes, &bodyMap) + if err != nil { + return nil, err + } + err = mapstructure.Decode(bodyMap, &body) + if err != nil { + return nil, err + } + return &body, nil +} diff --git a/go/qianfan/completion.go b/go/qianfan/completion.go new file mode 100644 index 00000000..7a654b25 --- /dev/null +++ b/go/qianfan/completion.go @@ -0,0 +1,194 @@ +// Copyright (c) 2024 Baidu, Inc. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package qianfan + +import ( + "context" + "fmt" +) + +// Completion 模型请求的参数结构体,但并非每个模型都完整支持如下参数,具体是否支持以 API 文档为准 +type CompletionRequest struct { + BaseRequestBody + Prompt string `mapstructure:"prompt"` // 请求信息 + Temperature float64 `mapstructure:"temperature,omitempty"` // 较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定,范围 (0, 1.0],不能为0 + TopK int `mapstructure:"top_k,omitempty"` // Top-K 采样参数,在每轮token生成时,保留k个概率最高的token作为候选 + TopP float64 `mapstructure:"top_p,omitempty"` // 影响输出文本的多样性,取值越大,生成文本的多样性越强。取值范围 [0, 1.0] + PenaltyScore float64 `mapstructure:"penalty_score,omitempty"` // 通过对已生成的token增加惩罚,减少重复生成的现象。说明:值越大表示惩罚越大,取值范围:[1.0, 2.0] + System string `mapstructure:"system,omitempty"` // 模型人设,主要用于人设设定 + Stop []string `mapstructure:"stop,omitempty"` // 生成停止标识,当模型生成结果以stop中某个元素结尾时,停止文本生成 + DisableSearch bool `mapstructure:"disable_search,omitempty"` // 是否强制关闭实时搜索功能 + EnableCitation bool `mapstructure:"enable_citation,omitempty"` // 是否开启上角标返回 + MaxOutputTokens int `mapstructure:"max_output_tokens,omitempty"` // 指定模型最大输出token数 + ResponseFormat string `mapstructure:"response_format,omitempty"` // 指定响应内容的格式 + UserID string `mapstructure:"user_id,omitempty"` // 表示最终用户的唯一标识符 +} + +// 用于 Completion 模型请求的结构体 +type Completion struct { + BaseModel + chatWrapper *ChatCompletion +} + +// 内置 Completion 模型的 endpoint +var CompletionModelEndpoint = map[string]string{ + "SQLCoder-7B": "/completions/sqlcoder_7b", + "CodeLlama-7b-Instruct": "/completions/codellama_7b_instruct", +} + +// 内部根据 Options 创建 Completion 对象 +func newCompletion(options *Options) *Completion { + hasModel := options.Model != nil + hasEndpoint := options.Endpoint != nil + comp := Completion{ + BaseModel: BaseModel{ + Model: DefaultCompletionModel, + Endpoint: "", + Requestor: newRequestor(options), + }, + chatWrapper: nil, + } + // 如果 model 和 endpoint 都没提供,那就用 chatWrapper 默认值 + if !hasModel && !hasEndpoint { + comp.chatWrapper = newChatCompletion(options) + } + // 如果提供了 model + if hasModel { + // 那就看模型是否是 chat 模型,如果是,就使用 chatWrapper + _, ok := ChatModelEndpoint[*options.Model] + if ok { + comp.chatWrapper = newChatCompletion(options) + } else { + comp.Model = *options.Model + } + } + if hasEndpoint { + comp.Endpoint = *options.Endpoint + } + return &comp +} + +// 将 endpoint 转换成完整的 endpoint +func (c *Completion) realEndpoint() (string, error) { + url := modelAPIPrefix + if c.Model != "" { + endpoint, ok := CompletionModelEndpoint[c.Model] + if !ok { + return "", fmt.Errorf("model %s is not supported", c.Model) + } + url += endpoint + } else { + url += "/completions/" + c.Endpoint + } + return url, nil +} + +// 将 completion 的请求转换为 chat 的请求 +func convertCompletionReqToChatReq(request *CompletionRequest) *ChatCompletionRequest { + chatReq := ChatCompletionRequest{ + BaseRequestBody: request.BaseRequestBody, + Messages: []ChatCompletionMessage{ + ChatCompletionUserMessage(request.Prompt), + }, + Temperature: request.Temperature, + TopP: request.TopP, + PenaltyScore: request.PenaltyScore, + System: request.System, + Stop: request.Stop, + DisableSearch: request.DisableSearch, + EnableCitation: request.EnableCitation, + MaxOutputTokens: request.MaxOutputTokens, + ResponseFormat: request.ResponseFormat, + UserID: request.UserID, + } + return &chatReq +} + +// 发送请求 +func (c *Completion) Do(ctx context.Context, request *CompletionRequest) (*ModelResponse, error) { + if c.chatWrapper != nil { + return c.chatWrapper.Do(ctx, convertCompletionReqToChatReq(request)) + } + url, err := c.realEndpoint() + if err != nil { + return nil, err + } + req, err := newModelRequest("POST", url, request) + if err != nil { + return nil, err + } + var resp ModelResponse + err = c.Requestor.request(req, &resp) + if err != nil { + return nil, err + } + if err = checkResponseError(&resp); err != nil { + return &resp, err + } + return &resp, nil +} + +// 发送流式请求 +func (c *Completion) Stream(ctx context.Context, request *CompletionRequest) (*ModelResponseStream, error) { + if c.chatWrapper != nil { + return c.chatWrapper.Stream(ctx, convertCompletionReqToChatReq(request)) + } + url, err := c.realEndpoint() + if err != nil { + return nil, err + } + request.SetStream() + req, err := newModelRequest("POST", url, request) + if err != nil { + return nil, err + } + stream, err := c.Requestor.requestStream(req) + if err != nil { + return nil, err + } + return &ModelResponseStream{ + streamInternal: stream, + }, nil +} + +// 创建一个 Completion 实例 +// +// completion := qianfan.NewCompletion() // 默认使用 ERNIE-Bot-turbo 模型 +// +// // 可以通过 WithModel 指定模型 +// completion := qianfan.NewCompletion( +// qianfan.WithModel("ERNIE-Bot-4"), +// // 支持的模型可以通过 completion.ModelList() 获取 +// ) +// // 或者通过 WithEndpoint 指定 endpoint +// completion := qianfan.NewCompletion( +// qianfan.WithEndpoint("your_custom_endpoint"), +// ) +func NewCompletion(optionList ...Option) *Completion { + options := makeOptions(optionList...) + return newCompletion(options) +} + +// Completion 支持的模型列表 +func (c *Completion) ModelList() []string { + i := 0 + list := make([]string, len(CompletionModelEndpoint)) + for k := range CompletionModelEndpoint { + list[i] = k + i++ + } + list = append(list, (&ChatCompletion{}).ModelList()...) + return list +} diff --git a/go/qianfan/completion_test.go b/go/qianfan/completion_test.go new file mode 100644 index 00000000..d96544e0 --- /dev/null +++ b/go/qianfan/completion_test.go @@ -0,0 +1,93 @@ +// Copyright (c) 2024 Baidu, Inc. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package qianfan + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestCompletion(t *testing.T) { + prompt := "hello" + + completion := NewCompletion() + resp, err := completion.Do( + context.Background(), + &CompletionRequest{ + Prompt: prompt, + }, + ) + assert.NoError(t, err) + assert.Equal(t, resp.RawResponse.StatusCode, 200) + assert.NotEqual(t, resp.Id, nil) + assert.Equal(t, resp.Object, "chat.completion") + assert.Contains(t, + resp.RawResponse.Request.URL.Path, + ChatModelEndpoint[DefaultCompletionModel], + ) + assert.Contains(t, resp.Result, prompt) + request, err := getRequestBody[ChatCompletionRequest](resp.RawResponse) + assert.NoError(t, err) + assert.Equal(t, request.Messages[0].Content, prompt) + + completion = NewCompletion(WithModel("SQLCoder-7B")) + resp, err = completion.Do( + context.Background(), + &CompletionRequest{ + Prompt: prompt, + Temperature: 0.5, + }, + ) + assert.NoError(t, err) + assert.Equal(t, resp.Object, "completion") + assert.Contains(t, resp.RawResponse.Request.URL.Path, CompletionModelEndpoint["SQLCoder-7B"]) + assert.Contains(t, resp.Result, prompt) + reqComp, err := getRequestBody[CompletionRequest](resp.RawResponse) + assert.NoError(t, err) + assert.Equal(t, reqComp.Prompt, prompt) + assert.Equal(t, reqComp.Temperature, 0.5) +} + +func TestCompletionStream(t *testing.T) { + + modelList := []string{"ERNIE-Bot-turbo", "SQLCoder-7B"} + for _, m := range modelList { + chat := NewCompletion( + WithModel(m), + ) + resp, err := chat.Stream( + context.Background(), + &CompletionRequest{ + Prompt: "hello", + Temperature: 0.5, + }, + ) + assert.NoError(t, err) + defer resp.Close() + turnCount := 0 + for { + resp, err := resp.Recv() + assert.NoError(t, err) + if resp.IsEnd { + break + } + turnCount++ + assert.Contains(t, resp.Result, "hello") + } + assert.Greater(t, turnCount, 1) + } +} diff --git a/go/qianfan/config.go b/go/qianfan/config.go new file mode 100644 index 00000000..abbca1e5 --- /dev/null +++ b/go/qianfan/config.go @@ -0,0 +1,76 @@ +// Copyright (c) 2024 Baidu, Inc. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package qianfan + +import ( + "github.com/spf13/viper" +) + +// 默认配置 +var defaultConfig = map[string]string{ + "QIANFAN_ACCESS_KEY": "", + "QIANFAN_SECRET_KEY": "", + "QIANFAN_BASE_URL": "https://aip.baidubce.com", + "QIANFAN_IAM_SIGN_EXPIRATION_SEC": "300", + "QIANFAN_CONSOLE_BASE_URL": "https://qianfan.baidubce.com", +} + +// SDK 使用的全局配置,可以用 GetConfig() 获取 +type Config struct { + AccessKey string `mapstructure:"QIANFAN_ACCESS_KEY"` + SecretKey string `mapstructure:"QIANFAN_SECRET_KEY"` + BaseURL string `mapstructure:"QIANFAN_BASE_URL"` + IAMSignExpirationSeconds int `mapstructure:"QIANFAN_IAM_SIGN_EXPIRATION_SEC"` + ConsoleBaseURL string `mapstructure:"QIANFAN_CONSOLE_BASE_URL"` +} + +func setConfigDeafultValue(vConfig *viper.Viper) { + // 因为 viper 自动绑定无法在 unmarshal 时使用,所以这里要手动设置默认值 + for k, v := range defaultConfig { + vConfig.SetDefault(k, v) + } +} + +func loadConfigFromEnv() *Config { + vConfig := viper.New() + + vConfig.SetConfigFile(".env") + vConfig.SetConfigType("dotenv") + vConfig.AutomaticEnv() + setConfigDeafultValue(vConfig) + + // ignore error if config file not found + _ = vConfig.ReadInConfig() + + config := &Config{} + if err := vConfig.Unmarshal(&config); err != nil { + logger.Panicf("load config file failed with error `%v`, please check your config.", err) + } + return config +} + +var _config *Config = nil + +// 获取全局配置,可以通过如下方式修改配置 +// 可以在代码中手动设置 `AccessKey` 和 `SecretKey`,具体如下: +// +// qianfan.GetConfig().AccessKey = "your_access_key" +// qianfan.GetConfig().SecretKey = "your_secret_key" +func GetConfig() *Config { + if _config == nil { + _config = loadConfigFromEnv() + } + return _config +} diff --git a/go/qianfan/consts.go b/go/qianfan/consts.go new file mode 100644 index 00000000..bb202aab --- /dev/null +++ b/go/qianfan/consts.go @@ -0,0 +1,25 @@ +// Copyright (c) 2024 Baidu, Inc. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package qianfan + +// 模型请求的前缀 +const modelAPIPrefix = "/rpc/2.0/ai_custom/v1/wenxinworkshop" + +// 默认使用的模型 +const ( + DefaultChatCompletionModel = "ERNIE-Bot-turbo" + DefaultCompletionModel = "ERNIE-Bot-turbo" + DefaultEmbeddingModel = "Embedding-V1" +) diff --git a/go/qianfan/embdding.go b/go/qianfan/embdding.go new file mode 100644 index 00000000..4a4bcda7 --- /dev/null +++ b/go/qianfan/embdding.go @@ -0,0 +1,130 @@ +// Copyright (c) 2024 Baidu, Inc. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package qianfan + +import ( + "context" + "fmt" +) + +// 用于 Embedding 相关操作的结构体 +type Embedding struct { + BaseModel +} + +// Embedding 请求 +type EmbeddingRequest struct { + BaseRequestBody + Input []string `mapstructure:"input"` // 输入的文本列表 + UserID string `mapstructure:"user_id,omitempty"` // 表示最终用户的唯一标识符 +} + +// 具体的 Embedding 信息 +type EmbeddingData struct { + Object string `json:"object"` // 固定值"embedding" + Embedding []float64 `json:"embedding"` // embedding 内容 + Index int `json:"index"` // 序号 +} + +// 返回的 Embedding 数据 +type EmbeddingResponse struct { + Id string `json:"id"` // 请求的id + Object string `json:"object"` // 回包类型,固定值“embedding_list” + Created int `json:"created"` // 创建时间 + Usage ModelUsage `json:"usage"` // token统计信息 + Data []EmbeddingData `json:"data"` // embedding 数据 + ModelAPIError // API 错误信息 + baseResponse // 基础的响应字段 +} + +// 内置 Embedding 模型的 endpoint +var EmbeddingEndpoint = map[string]string{ + "Embedding-V1": "/embeddings/embedding-v1", + "bge-large-en": "/embeddings/bge_large_en", + "bge-large-zh": "/embeddings/bge_large_zh", + "tao-8k": "/embeddings/tao_8k", +} + +// 创建 Embedding 实例 +func NewEmbedding(optionList ...Option) *Embedding { + options := makeOptions(optionList...) + return newEmbedding(options) +} + +// 内部根据 options 创建 Embedding 实例 +func newEmbedding(options *Options) *Embedding { + embedding := &Embedding{ + BaseModel{ + Model: DefaultEmbeddingModel, + Endpoint: "", + Requestor: newRequestor(options), + }, + } + if options.Model != nil { + embedding.Model = *options.Model + } + if options.Endpoint != nil { + embedding.Endpoint = *options.Endpoint + } + return embedding +} + +// endpoint 转成完整 url +func (c *Embedding) realEndpoint() (string, error) { + url := modelAPIPrefix + if c.Model != "" { + endpoint, ok := EmbeddingEndpoint[c.Model] + if !ok { + return "", fmt.Errorf("model %s is not supported", c.Model) + } + url += endpoint + } else { + url += "/embeddings/" + c.Endpoint + } + return url, nil +} + +// 发送 Embedding 请求 +func (c *Embedding) Do(ctx context.Context, request *EmbeddingRequest) (*EmbeddingResponse, error) { + url, err := c.realEndpoint() + if err != nil { + return nil, err + } + req, err := newModelRequest("POST", url, request) + if err != nil { + return nil, err + } + resp := &EmbeddingResponse{} + + err = c.Requestor.request(req, resp) + if err != nil { + return nil, err + } + if err = checkResponseError(resp); err != nil { + return resp, err + } + return resp, nil +} + +// 获取 Embedding 支持的模型列表 +func (c *Embedding) ModelList() []string { + list := make([]string, len(EmbeddingEndpoint)) + i := 0 + for k := range EmbeddingEndpoint { + list[i] = k + i++ + } + return list +} diff --git a/go/qianfan/embedding_test.go b/go/qianfan/embedding_test.go new file mode 100644 index 00000000..8ed78760 --- /dev/null +++ b/go/qianfan/embedding_test.go @@ -0,0 +1,38 @@ +// Copyright (c) 2024 Baidu, Inc. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package qianfan + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestEmbedding(t *testing.T) { + embed := NewEmbedding() + resp, err := embed.Do(context.Background(), &EmbeddingRequest{ + Input: []string{"hello1", "hello2"}, + }) + assert.NoError(t, err) + assert.Equal(t, resp.RawResponse.StatusCode, 200) + assert.Equal(t, len(resp.Data), 2) + assert.NotEqual(t, len(resp.Data), 0) + assert.Contains(t, resp.RawResponse.Request.URL.Path, EmbeddingEndpoint[DefaultEmbeddingModel]) + req, err := getRequestBody[EmbeddingRequest](resp.RawResponse) + assert.NoError(t, err) + assert.Equal(t, req.Input[0], "hello1") + assert.Equal(t, req.Input[1], "hello2") +} diff --git a/go/qianfan/go.mod b/go/qianfan/go.mod new file mode 100644 index 00000000..bf129a04 --- /dev/null +++ b/go/qianfan/go.mod @@ -0,0 +1,34 @@ +module github.com/baidubce/bce-qianfan-sdk/go/qianfan + +go 1.18 + +require ( + github.com/baidubce/bce-sdk-go v0.9.164 + github.com/mitchellh/mapstructure v1.5.0 + github.com/sirupsen/logrus v1.9.3 + github.com/spf13/viper v1.18.2 + github.com/stretchr/testify v1.8.4 +) + +require ( + github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect + github.com/fsnotify/fsnotify v1.7.0 // indirect + github.com/hashicorp/hcl v1.0.0 // indirect + github.com/magiconair/properties v1.8.7 // indirect + github.com/pelletier/go-toml/v2 v2.1.0 // indirect + github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect + github.com/sagikazarmark/locafero v0.4.0 // indirect + github.com/sagikazarmark/slog-shim v0.1.0 // indirect + github.com/sourcegraph/conc v0.3.0 // indirect + github.com/spf13/afero v1.11.0 // indirect + github.com/spf13/cast v1.6.0 // indirect + github.com/spf13/pflag v1.0.5 // indirect + github.com/subosito/gotenv v1.6.0 // indirect + go.uber.org/atomic v1.9.0 // indirect + go.uber.org/multierr v1.9.0 // indirect + golang.org/x/exp v0.0.0-20230905200255-921286631fa9 // indirect + golang.org/x/sys v0.15.0 // indirect + golang.org/x/text v0.14.0 // indirect + gopkg.in/ini.v1 v1.67.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/go/qianfan/go.sum b/go/qianfan/go.sum new file mode 100644 index 00000000..33bcf151 --- /dev/null +++ b/go/qianfan/go.sum @@ -0,0 +1,69 @@ +github.com/baidubce/bce-sdk-go v0.9.164 h1:7gswLMsdQyarovMKuv3i6wxFQ3BQgvc5CmyGXb/D/xA= +github.com/baidubce/bce-sdk-go v0.9.164/go.mod h1:zbYJMQwE4IZuyrJiFO8tO8NbtYiKTFTbwh4eIsqjVdg= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= +github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8= +github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA= +github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM= +github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= +github.com/hashicorp/hcl v1.0.0 h1:0Anlzjpi4vEasTeNFn2mLJgTSwt0+6sfsiTG8qcWGx4= +github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ= +github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/magiconair/properties v1.8.7 h1:IeQXZAiQcpL9mgcAe1Nu6cX9LLw6ExEHKjN0VQdvPDY= +github.com/magiconair/properties v1.8.7/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0= +github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY= +github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= +github.com/pelletier/go-toml/v2 v2.1.0 h1:FnwAJ4oYMvbT/34k9zzHuZNrhlz48GB3/s6at6/MHO4= +github.com/pelletier/go-toml/v2 v2.1.0/go.mod h1:tJU2Z3ZkXwnxa4DPO899bsyIoywizdUvyaeZurnPPDc= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= +github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8= +github.com/sagikazarmark/locafero v0.4.0 h1:HApY1R9zGo4DBgr7dqsTH/JJxLTTsOt7u6keLGt6kNQ= +github.com/sagikazarmark/locafero v0.4.0/go.mod h1:Pe1W6UlPYUk/+wc/6KFhbORCfqzgYEpgQ3O5fPuL3H4= +github.com/sagikazarmark/slog-shim v0.1.0 h1:diDBnUNK9N/354PgrxMywXnAwEr1QZcOr6gto+ugjYE= +github.com/sagikazarmark/slog-shim v0.1.0/go.mod h1:SrcSrq8aKtyuqEI1uvTDTK1arOWRIczQRv+GVI1AkeQ= +github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= +github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= +github.com/sourcegraph/conc v0.3.0 h1:OQTbbt6P72L20UqAkXXuLOj79LfEanQ+YQFNpLA9ySo= +github.com/sourcegraph/conc v0.3.0/go.mod h1:Sdozi7LEKbFPqYX2/J+iBAM6HpqSLTASQIKqDmF7Mt0= +github.com/spf13/afero v1.11.0 h1:WJQKhtpdm3v2IzqG8VMqrr6Rf3UYpEF239Jy9wNepM8= +github.com/spf13/afero v1.11.0/go.mod h1:GH9Y3pIexgf1MTIWtNGyogA5MwRIDXGUr+hbWNoBjkY= +github.com/spf13/cast v1.6.0 h1:GEiTHELF+vaR5dhz3VqZfFSzZjYbgeKDpBxQVS4GYJ0= +github.com/spf13/cast v1.6.0/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo= +github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= +github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= +github.com/spf13/viper v1.18.2 h1:LUXCnvUvSM6FXAsj6nnfc8Q2tp1dIgUfY9Kc8GsSOiQ= +github.com/spf13/viper v1.18.2/go.mod h1:EKmWIqdnk5lOcmR72yw6hS+8OPYcwD0jteitLMVB+yk= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= +github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8= +github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU= +go.uber.org/atomic v1.9.0 h1:ECmE8Bn/WFTYwEW/bpKD3M8VtR/zQVbavAoalC1PYyE= +go.uber.org/atomic v1.9.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= +go.uber.org/multierr v1.9.0 h1:7fIwc/ZtS0q++VgcfqFDxSBZVv/Xo49/SYnDFupUwlI= +go.uber.org/multierr v1.9.0/go.mod h1:X2jQV1h+kxSjClGpnseKVIxpmcjrj7MNnI0bnlfKTVQ= +golang.org/x/exp v0.0.0-20230905200255-921286631fa9 h1:GoHiUyI/Tp2nVkLI2mCxVkOjsbSXD66ic0XW0js0R9g= +golang.org/x/exp v0.0.0-20230905200255-921286631fa9/go.mod h1:S2oDrQGGwySpoQPVqRShND87VCbxmc6bL1Yd2oYrm6k= +golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.15.0 h1:h48lPFYpsTvQJZF4EKyI4aLHaev3CxivZmv7yZig9pc= +golang.org/x/sys v0.15.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= +golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo= +gopkg.in/ini.v1 v1.67.0 h1:Dgnx+6+nfE+IfzjUEISNeydPJh9AXNNsWbGP9KzCsOA= +gopkg.in/ini.v1 v1.67.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/go/qianfan/logger.go b/go/qianfan/logger.go new file mode 100644 index 00000000..21ad0b3d --- /dev/null +++ b/go/qianfan/logger.go @@ -0,0 +1,21 @@ +// Copyright (c) 2024 Baidu, Inc. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package qianfan + +import ( + "github.com/sirupsen/logrus" +) + +var logger = logrus.New() diff --git a/go/qianfan/options.go b/go/qianfan/options.go new file mode 100644 index 00000000..0db46f6a --- /dev/null +++ b/go/qianfan/options.go @@ -0,0 +1,44 @@ +// Copyright (c) 2024 Baidu, Inc. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package qianfan + +type Option func(*Options) +type Options struct { + Model *string + Endpoint *string +} + +// 用于模型类对象设置使用的模型 +func WithModel(model string) Option { + return func(options *Options) { + options.Model = &model + } +} + +// 用于模型类对象设置使用的 endpoint +func WithEndpoint(endpoint string) Option { + return func(options *Options) { + options.Endpoint = &endpoint + } +} + +// 将多个 Option 转换成最终的 Options 对象 +func makeOptions(options ...Option) *Options { + option := Options{} + for _, opt := range options { + opt(&option) + } + return &option +} diff --git a/go/qianfan/requestor.go b/go/qianfan/requestor.go new file mode 100644 index 00000000..b4cefaf4 --- /dev/null +++ b/go/qianfan/requestor.go @@ -0,0 +1,354 @@ +// Copyright (c) 2024 Baidu, Inc. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package qianfan + +import ( + "bufio" + "bytes" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "strconv" + "strings" + + "github.com/baidubce/bce-sdk-go/auth" + bceHTTP "github.com/baidubce/bce-sdk-go/http" + "github.com/baidubce/bce-sdk-go/util" +) + +// 所有请求类型需实现的接口 +// +// 定义了提供额外参数的接口 +type RequestBody interface { + SetExtra(m map[string]interface{}) + GetExtra() map[string]interface{} +} + +// 请求体基类 +// +// 实现了允许用户传递额外参数的方法 +type BaseRequestBody struct { + Extra map[string]interface{} `mapstructure:"-"` +} + +// 设置额外参数 +func (r *BaseRequestBody) SetExtra(m map[string]interface{}) { + r.Extra = m +} + +// 获取额外参数 +func (r *BaseRequestBody) GetExtra() map[string]interface{} { + return r.Extra +} + +// 将请求设置为流式 +func (r *BaseRequestBody) SetStream() { + if r.Extra == nil { + r.Extra = map[string]interface{}{} + } + r.Extra["stream"] = true +} + +// 将请求体转换成 map,并将额外参数合并到 map 中 +func convertToMap(body RequestBody) (map[string]interface{}, error) { + m, err := dumpToMap(body) + if err != nil { + return nil, err + } + extra := body.GetExtra() + for k, v := range extra { + m[k] = v + } + return m, nil +} + +// 请求类型,用于区分是模型的请求还是管控类请求 +// 在 QfRequest.Type 处被使用 +const ( + modelRequest = "model" + consoleRequest = "console" +) + +// SDK 内部表示请求的类 +type QfRequest struct { + Type string // 请求类型,用于区分是模型的请求 `modelRequest` 还是管控类请求 `consoleRequest` + Method string // HTTP 方法 + URL string // 请求的完整地址 + Headers map[string]string // HTTP 请求头 + Params map[string]string // HTTP 请求参数 + Body map[string]interface{} // HTTP 请求体 +} + +// 创建一个用于模型类请求的 Request +func newModelRequest(method string, url string, body RequestBody) (*QfRequest, error) { + return newRequest(modelRequest, method, url, body) +} + +// 创建一个用于管控类请求的 Request +// 暂时注释避免 lint 报错 +// func newConsoleRequest(method string, url string, body RequestBody) (*QfRequest, error) { +// return newRequest(ConsoleRequest, method, url, body) +// } + +// 创建一个 Request,body 可以是任意实现了 RequestBody 接口的类型 +func newRequest(requestType string, method string, url string, body RequestBody) (*QfRequest, error) { + b, err := convertToMap(body) + if err != nil { + return nil, err + } + return newRequestFromMap(requestType, method, url, b) +} + +// 创建一个 Request,body 是一个 map +func newRequestFromMap(requestType string, method string, url string, body map[string]interface{}) (*QfRequest, error) { + return &QfRequest{ + Type: requestType, + Method: method, + URL: url, + Body: body, + Params: map[string]string{}, + Headers: map[string]string{}, + }, nil +} + +// 所有回复类型的基类 +type baseResponse struct { + Body []byte + RawResponse *http.Response +} + +// 所有回复类型需实现的接口 +type QfResponse interface { + SetResponse(Body []byte, RawResponse *http.Response) +} + +// 设置回复中通用参数的字段 +func (r *baseResponse) SetResponse(Body []byte, RawResponse *http.Response) { + r.Body = Body + r.RawResponse = RawResponse +} + +// 请求器,负责 SDK 中所有请求的发送,是所有对外暴露对象的基类 +type Requestor struct { + client *http.Client + Options *Options +} + +// 创建一个 Requestor +func newRequestor(options *Options) *Requestor { + return &Requestor{ + client: &http.Client{}, + Options: options, + } +} + +// IAM 签名 +func (r *Requestor) sign(request *QfRequest) error { + bceRequest := &bceHTTP.Request{} + bceRequest.SetMethod(request.Method) + bceRequest.SetHeaders(request.Headers) + bceRequest.SetParams(request.Params) + u, err := url.Parse(request.URL) + if err != nil { + return err + } + bceRequest.SetProtocol(u.Scheme) + + bceRequest.SetHost(u.Hostname()) + port := u.Port() + if port == "" { + if u.Scheme == "http" { + port = "80" + } else if u.Scheme == "https" { + port = "443" + } else { + return fmt.Errorf("unrecognized scheme: %s", u.Scheme) + } + } + porti, err := strconv.Atoi(port) + if err != nil { + return err + } + bceRequest.SetPort(porti) + bceRequest.SetUri(u.RequestURI()) + + credentials := &auth.BceCredentials{ + AccessKeyId: GetConfig().AccessKey, + SecretAccessKey: GetConfig().SecretKey, + } + now := util.NowUTCSeconds() + bceRequest.SetHeader("Host", u.Hostname()) + bceRequest.SetHeader("x-bce-date", util.FormatISO8601Date(now)) + headersToSign := make(map[string]struct{}) + for k := range bceRequest.Headers() { + headersToSign[strings.ToLower(k)] = struct{}{} + } + signer := auth.BceV1Signer{} + signOptions := &auth.SignOptions{ + HeadersToSign: headersToSign, + Timestamp: now, + ExpireSeconds: GetConfig().IAMSignExpirationSeconds, + } + signer.Sign(bceRequest, credentials, signOptions) + + request.Headers = bceRequest.Headers() + return nil +} + +// 对请求进行统一处理,并转换成 http.Request +func (r *Requestor) prepareRequest(request *QfRequest) (*http.Request, error) { + // 设置溯源标识 + if request.Type == modelRequest { + request.URL = GetConfig().BaseURL + request.URL + request.Body["extra_parameters"] = map[string]string{ + "request_source": versionIndicator, + } + } else if request.Type == consoleRequest { + request.URL = GetConfig().ConsoleBaseURL + request.URL + request.Headers["request-source"] = versionIndicator + } + bodyBytes, err := json.Marshal(request.Body) + if err != nil { + return nil, err + } + req, err := http.NewRequest(request.Method, request.URL, bytes.NewBuffer(bodyBytes)) + if err != nil { + return nil, err + } + request.Headers["Content-Type"] = "application/json" + // IAM 签名 + err = r.sign(request) + if err != nil { + return nil, err + } + for k, v := range request.Headers { + req.Header.Set(k, v) + } + + q := req.URL.Query() + for k, v := range request.Params { + q.Add(k, v) + } + req.URL.RawQuery = q.Encode() + + return req, nil +} + +// 进行请求,返回原始的 baseResponse,并将结果解析至 resp +func (r *Requestor) request(request *QfRequest, response QfResponse) error { + req, err := r.prepareRequest(request) + if err != nil { + return err + } + resp, err := r.client.Do(req) + if err != nil { + return err + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return err + } + response.SetResponse(body, resp) + err = json.Unmarshal(body, response) + if err != nil { + return err + } + return nil +} + +// 流的内部实现,用于接收流中的响应 +type streamInternal struct { + httpResponse *http.Response // 原始的 http.Response + scanner *bufio.Scanner // 读取流的 scanner + IsEnd bool // 流是否已经结束 +} + +// 创建一个流 +func newStreamInternal(httpResponse *http.Response) (*streamInternal, error) { + return &streamInternal{ + httpResponse: httpResponse, + scanner: bufio.NewScanner(httpResponse.Body), + IsEnd: false, + }, nil +} + +// 关闭流 +func (si *streamInternal) Close() { + _ = si.httpResponse.Body.Close() +} + +// 接受流中的响应,并将结果解析至 resp +func (si *streamInternal) Recv(resp QfResponse) error { + var eventData []byte + for len(eventData) == 0 { + for { + if !si.scanner.Scan() { + si.IsEnd = true + si.Close() + return si.scanner.Err() + } + + line := si.scanner.Bytes() + if len(line) == 0 { + break + } + var ( + // field []byte = line + value []byte + ) + if i := bytes.IndexRune(line, ':'); i != -1 { + // field = line[:i] + value = line[i+1:] + if len(value) != 0 && value[0] == ' ' { + value = value[1:] + } + } + eventData = append(eventData, value...) + } + } + response := baseResponse{ + Body: eventData, + RawResponse: si.httpResponse, + } + + resp.SetResponse(response.Body, response.RawResponse) + err := json.Unmarshal(response.Body, resp) + if err != nil { + si.IsEnd = true + return err + } + return nil +} + +// 发送请求,返回流对象 +func (r *Requestor) requestStream(request *QfRequest) (*streamInternal, error) { + req, err := r.prepareRequest(request) + if err != nil { + return nil, err + } + resp, err := r.client.Do(req) + if err != nil { + return nil, err + } + stream, err := newStreamInternal(resp) + if err != nil { + return nil, err + } + return stream, nil +} diff --git a/go/qianfan/utils.go b/go/qianfan/utils.go new file mode 100644 index 00000000..c6e9a705 --- /dev/null +++ b/go/qianfan/utils.go @@ -0,0 +1,27 @@ +// Copyright (c) 2024 Baidu, Inc. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package qianfan + +import "github.com/mitchellh/mapstructure" + +// 转换任意对象成 map +func dumpToMap(input interface{}) (map[string]interface{}, error) { + target := map[string]interface{}{} + err := mapstructure.Decode(input, &target) + if err != nil { + return nil, err + } + return target, nil +} diff --git a/go/qianfan/version.go b/go/qianfan/version.go new file mode 100644 index 00000000..2480beda --- /dev/null +++ b/go/qianfan/version.go @@ -0,0 +1,25 @@ +// Copyright (c) 2024 Baidu, Inc. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// # 百度千帆大模型平台 Go SDK +// +// 千帆SDK提供大模型工具链最佳实践,让AI工作流和AI原生应用优雅且便捷地访问千帆大模型平台。 +// 目前 SDK 提供了以下功能: +// +// * 大模型推理:实现了对一言(ERNIE-Bot)系列、开源大模型等模型推理的接口封装,支持对话、补全、Embedding等。 +package qianfan + +// SDK 版本 +const Version = "v0.0.1" +const versionIndicator = "qianfan_go_sdk_" + Version diff --git a/src/qianfan/common/client/__init__.py b/javascript/README.md similarity index 100% rename from src/qianfan/common/client/__init__.py rename to javascript/README.md diff --git a/.coveragerc b/python/.coveragerc similarity index 79% rename from .coveragerc rename to python/.coveragerc index 785277c7..01272815 100644 --- a/.coveragerc +++ b/python/.coveragerc @@ -1,11 +1,11 @@ # .coveragerc [run] branch = True -source = ./src/qianfan +source = ./qianfan omit = *test.py, test*.py, Test*.py, qianfan/tests/* [report] omit = *test.py, test*.py, Test*.py [html] directory = htmlcov [xml] -output = coverage.xml \ No newline at end of file +output = coverage.xml diff --git a/python/Makefile b/python/Makefile new file mode 100644 index 00000000..f89d10b0 --- /dev/null +++ b/python/Makefile @@ -0,0 +1,29 @@ +build: + bash ./scripts/build.sh + +install: + poetry install -E all + +uninstall: + pip uninstall -y qianfan + +clean: + rm -rf build output dist qianfan.egg-info + +doc: install + poetry run bash ./scripts/build_doc.sh + +format: install + poetry run black ./qianfan + poetry run ruff --select I --fix ./qianfan + +lint: install + poetry run black ./qianfan --check + poetry run ruff check ./qianfan + poetry run mypy ./qianfan --install-types --non-interactive + +test: clean install + cd . && bash scripts/run_test.sh + + +.PHONY: build install uninstall clean diff --git a/README.pypi.md b/python/README.pypi.md similarity index 100% rename from README.pypi.md rename to python/README.pypi.md diff --git a/pyproject.toml b/python/pyproject.toml similarity index 98% rename from pyproject.toml rename to python/pyproject.toml index 0b082746..1ce7bf05 100644 --- a/pyproject.toml +++ b/python/pyproject.toml @@ -6,8 +6,8 @@ authors = [] license = "Apache-2.0" readme = "README.pypi.md" exclude = [ - "src/qianfan/tests", - "src/qianfan/docs", + "qianfan/tests", + "qianfan/docs", ] homepage = "https://cloud.baidu.com/product/wenxinworkshop" repository = "https://github.com/baidubce/bce-qianfan-sdk" diff --git a/src/qianfan/__init__.py b/python/qianfan/__init__.py similarity index 100% rename from src/qianfan/__init__.py rename to python/qianfan/__init__.py diff --git a/src/qianfan/common/__init__.py b/python/qianfan/common/__init__.py similarity index 100% rename from src/qianfan/common/__init__.py rename to python/qianfan/common/__init__.py diff --git a/src/qianfan/common/prompt/__init__.py b/python/qianfan/common/client/__init__.py similarity index 100% rename from src/qianfan/common/prompt/__init__.py rename to python/qianfan/common/client/__init__.py diff --git a/src/qianfan/common/client/chat.py b/python/qianfan/common/client/chat.py similarity index 100% rename from src/qianfan/common/client/chat.py rename to python/qianfan/common/client/chat.py diff --git a/src/qianfan/common/client/completion.py b/python/qianfan/common/client/completion.py similarity index 100% rename from src/qianfan/common/client/completion.py rename to python/qianfan/common/client/completion.py diff --git a/src/qianfan/common/client/dataset.py b/python/qianfan/common/client/dataset.py similarity index 100% rename from src/qianfan/common/client/dataset.py rename to python/qianfan/common/client/dataset.py diff --git a/src/qianfan/common/client/embedding.py b/python/qianfan/common/client/embedding.py similarity index 100% rename from src/qianfan/common/client/embedding.py rename to python/qianfan/common/client/embedding.py diff --git a/src/qianfan/common/client/evaluation.py b/python/qianfan/common/client/evaluation.py similarity index 100% rename from src/qianfan/common/client/evaluation.py rename to python/qianfan/common/client/evaluation.py diff --git a/src/qianfan/common/client/main.py b/python/qianfan/common/client/main.py similarity index 100% rename from src/qianfan/common/client/main.py rename to python/qianfan/common/client/main.py diff --git a/src/qianfan/common/client/plugin.py b/python/qianfan/common/client/plugin.py similarity index 100% rename from src/qianfan/common/client/plugin.py rename to python/qianfan/common/client/plugin.py diff --git a/src/qianfan/common/client/trainer.py b/python/qianfan/common/client/trainer.py similarity index 100% rename from src/qianfan/common/client/trainer.py rename to python/qianfan/common/client/trainer.py diff --git a/src/qianfan/common/client/txt2img.py b/python/qianfan/common/client/txt2img.py similarity index 100% rename from src/qianfan/common/client/txt2img.py rename to python/qianfan/common/client/txt2img.py diff --git a/src/qianfan/common/client/utils.py b/python/qianfan/common/client/utils.py similarity index 100% rename from src/qianfan/common/client/utils.py rename to python/qianfan/common/client/utils.py diff --git a/src/qianfan/common/hub/__init__.py b/python/qianfan/common/hub/__init__.py similarity index 100% rename from src/qianfan/common/hub/__init__.py rename to python/qianfan/common/hub/__init__.py diff --git a/src/qianfan/common/hub/hub.py b/python/qianfan/common/hub/hub.py similarity index 100% rename from src/qianfan/common/hub/hub.py rename to python/qianfan/common/hub/hub.py diff --git a/src/qianfan/common/hub/interface.py b/python/qianfan/common/hub/interface.py similarity index 100% rename from src/qianfan/common/hub/interface.py rename to python/qianfan/common/hub/interface.py diff --git a/src/qianfan/common/runnable/__init__.py b/python/qianfan/common/prompt/__init__.py similarity index 100% rename from src/qianfan/common/runnable/__init__.py rename to python/qianfan/common/prompt/__init__.py diff --git a/src/qianfan/common/prompt/prompt.py b/python/qianfan/common/prompt/prompt.py similarity index 100% rename from src/qianfan/common/prompt/prompt.py rename to python/qianfan/common/prompt/prompt.py diff --git a/src/qianfan/extensions/__init__.py b/python/qianfan/common/runnable/__init__.py similarity index 100% rename from src/qianfan/extensions/__init__.py rename to python/qianfan/common/runnable/__init__.py diff --git a/src/qianfan/common/runnable/base.py b/python/qianfan/common/runnable/base.py similarity index 100% rename from src/qianfan/common/runnable/base.py rename to python/qianfan/common/runnable/base.py diff --git a/src/qianfan/common/tool/baidu_search_tool.py b/python/qianfan/common/tool/baidu_search_tool.py similarity index 100% rename from src/qianfan/common/tool/baidu_search_tool.py rename to python/qianfan/common/tool/baidu_search_tool.py diff --git a/src/qianfan/common/tool/base_tool.py b/python/qianfan/common/tool/base_tool.py similarity index 100% rename from src/qianfan/common/tool/base_tool.py rename to python/qianfan/common/tool/base_tool.py diff --git a/src/qianfan/common/tool/duckduckgo_search_tool.py b/python/qianfan/common/tool/duckduckgo_search_tool.py similarity index 100% rename from src/qianfan/common/tool/duckduckgo_search_tool.py rename to python/qianfan/common/tool/duckduckgo_search_tool.py diff --git a/src/qianfan/common/tool/wikipedia_tool.py b/python/qianfan/common/tool/wikipedia_tool.py similarity index 100% rename from src/qianfan/common/tool/wikipedia_tool.py rename to python/qianfan/common/tool/wikipedia_tool.py diff --git a/src/qianfan/config.py b/python/qianfan/config.py similarity index 100% rename from src/qianfan/config.py rename to python/qianfan/config.py diff --git a/src/qianfan/consts.py b/python/qianfan/consts.py similarity index 100% rename from src/qianfan/consts.py rename to python/qianfan/consts.py diff --git a/src/qianfan/dataset/__init__.py b/python/qianfan/dataset/__init__.py similarity index 100% rename from src/qianfan/dataset/__init__.py rename to python/qianfan/dataset/__init__.py diff --git a/src/qianfan/dataset/consts.py b/python/qianfan/dataset/consts.py similarity index 100% rename from src/qianfan/dataset/consts.py rename to python/qianfan/dataset/consts.py diff --git a/src/qianfan/dataset/qianfan_data_operators.py b/python/qianfan/dataset/data_operator.py similarity index 100% rename from src/qianfan/dataset/qianfan_data_operators.py rename to python/qianfan/dataset/data_operator.py diff --git a/src/qianfan/dataset/data_source/__init__.py b/python/qianfan/dataset/data_source/__init__.py similarity index 100% rename from src/qianfan/dataset/data_source/__init__.py rename to python/qianfan/dataset/data_source/__init__.py diff --git a/src/qianfan/dataset/data_source/baidu_qianfan.py b/python/qianfan/dataset/data_source/baidu_qianfan.py similarity index 100% rename from src/qianfan/dataset/data_source/baidu_qianfan.py rename to python/qianfan/dataset/data_source/baidu_qianfan.py diff --git a/src/qianfan/dataset/data_source/base.py b/python/qianfan/dataset/data_source/base.py similarity index 100% rename from src/qianfan/dataset/data_source/base.py rename to python/qianfan/dataset/data_source/base.py diff --git a/src/qianfan/dataset/data_source/bos.py b/python/qianfan/dataset/data_source/bos.py similarity index 100% rename from src/qianfan/dataset/data_source/bos.py rename to python/qianfan/dataset/data_source/bos.py diff --git a/src/qianfan/dataset/data_source/file.py b/python/qianfan/dataset/data_source/file.py similarity index 100% rename from src/qianfan/dataset/data_source/file.py rename to python/qianfan/dataset/data_source/file.py diff --git a/src/qianfan/dataset/data_source/utils.py b/python/qianfan/dataset/data_source/utils.py similarity index 100% rename from src/qianfan/dataset/data_source/utils.py rename to python/qianfan/dataset/data_source/utils.py diff --git a/src/qianfan/dataset/dataset.py b/python/qianfan/dataset/dataset.py similarity index 100% rename from src/qianfan/dataset/dataset.py rename to python/qianfan/dataset/dataset.py diff --git a/src/qianfan/dataset/dataset_utils.py b/python/qianfan/dataset/dataset_utils.py similarity index 100% rename from src/qianfan/dataset/dataset_utils.py rename to python/qianfan/dataset/dataset_utils.py diff --git a/src/qianfan/dataset/local_data_operators/__init__.py b/python/qianfan/dataset/local_data_operators/__init__.py similarity index 100% rename from src/qianfan/dataset/local_data_operators/__init__.py rename to python/qianfan/dataset/local_data_operators/__init__.py diff --git a/src/qianfan/dataset/local_data_operators/base.py b/python/qianfan/dataset/local_data_operators/base.py similarity index 100% rename from src/qianfan/dataset/local_data_operators/base.py rename to python/qianfan/dataset/local_data_operators/base.py diff --git a/src/qianfan/dataset/local_data_operators/check_character_repetition_filter.py b/python/qianfan/dataset/local_data_operators/check_character_repetition_filter.py similarity index 100% rename from src/qianfan/dataset/local_data_operators/check_character_repetition_filter.py rename to python/qianfan/dataset/local_data_operators/check_character_repetition_filter.py diff --git a/src/qianfan/dataset/local_data_operators/check_flagged_words.py b/python/qianfan/dataset/local_data_operators/check_flagged_words.py similarity index 100% rename from src/qianfan/dataset/local_data_operators/check_flagged_words.py rename to python/qianfan/dataset/local_data_operators/check_flagged_words.py diff --git a/src/qianfan/dataset/local_data_operators/check_sentence_length_filter.py b/python/qianfan/dataset/local_data_operators/check_sentence_length_filter.py similarity index 100% rename from src/qianfan/dataset/local_data_operators/check_sentence_length_filter.py rename to python/qianfan/dataset/local_data_operators/check_sentence_length_filter.py diff --git a/src/qianfan/dataset/local_data_operators/check_special_characters.py b/python/qianfan/dataset/local_data_operators/check_special_characters.py similarity index 100% rename from src/qianfan/dataset/local_data_operators/check_special_characters.py rename to python/qianfan/dataset/local_data_operators/check_special_characters.py diff --git a/src/qianfan/dataset/local_data_operators/check_stopwords.py b/python/qianfan/dataset/local_data_operators/check_stopwords.py similarity index 100% rename from src/qianfan/dataset/local_data_operators/check_stopwords.py rename to python/qianfan/dataset/local_data_operators/check_stopwords.py diff --git a/src/qianfan/dataset/local_data_operators/check_word_number.py b/python/qianfan/dataset/local_data_operators/check_word_number.py similarity index 100% rename from src/qianfan/dataset/local_data_operators/check_word_number.py rename to python/qianfan/dataset/local_data_operators/check_word_number.py diff --git a/src/qianfan/dataset/local_data_operators/check_word_repetition_filter.py b/python/qianfan/dataset/local_data_operators/check_word_repetition_filter.py similarity index 100% rename from src/qianfan/dataset/local_data_operators/check_word_repetition_filter.py rename to python/qianfan/dataset/local_data_operators/check_word_repetition_filter.py diff --git a/src/qianfan/dataset/local_data_operators/consts.py b/python/qianfan/dataset/local_data_operators/consts.py similarity index 100% rename from src/qianfan/dataset/local_data_operators/consts.py rename to python/qianfan/dataset/local_data_operators/consts.py diff --git a/src/qianfan/dataset/local_data_operators/utils.py b/python/qianfan/dataset/local_data_operators/utils.py similarity index 100% rename from src/qianfan/dataset/local_data_operators/utils.py rename to python/qianfan/dataset/local_data_operators/utils.py diff --git a/src/qianfan/dataset/local_data_operators/word_list.py b/python/qianfan/dataset/local_data_operators/word_list.py similarity index 100% rename from src/qianfan/dataset/local_data_operators/word_list.py rename to python/qianfan/dataset/local_data_operators/word_list.py diff --git a/src/qianfan/dataset/process_interface.py b/python/qianfan/dataset/process_interface.py similarity index 100% rename from src/qianfan/dataset/process_interface.py rename to python/qianfan/dataset/process_interface.py diff --git a/python/qianfan/dataset/qianfan_data_operators.py b/python/qianfan/dataset/qianfan_data_operators.py new file mode 100644 index 00000000..44e73fba --- /dev/null +++ b/python/qianfan/dataset/qianfan_data_operators.py @@ -0,0 +1,161 @@ +# Copyright (c) 2023 Baidu, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +data operator for qianfan online +""" + + +from qianfan.utils.pydantic import BaseModel, Field + + +class QianfanOperator(BaseModel): + """Basic class for online ETL operator""" + + operator_name: str + operator_type: str + + +class ExceptionRegulator(QianfanOperator): + """Exception class for online ETL operator""" + + operator_type: str = "clean" + + +class Filter(QianfanOperator): + """Filter class for online ETL operator""" + + operator_type: str = "filter" + + +class Deduplicator(QianfanOperator): + """Deduplicator class for online ETL operator""" + + operator_type: str = "deduplication" + + +class DesensitizationProcessor(QianfanOperator): + """Sensitive data processor class for online ETL operator""" + + operator_type: str = "desensitization" + + +class RemoveEmoji(ExceptionRegulator): + """Exception class to remove emoji""" + + operator_name: str = "remove_emoji" + + +class RemoveInvisibleCharacter(ExceptionRegulator): + """Exception class to remove invisible character""" + + operator_name: str = "remove_invisible_character" + + +class ReplaceUniformWhitespace(ExceptionRegulator): + """Exception class to replace uniform whitespace""" + + operator_name: str = "replace_uniform_whitespace" + + +class RemoveNonMeaningCharacters(ExceptionRegulator): + """Exception class to remove non-meaning characters""" + + operator_name: str = "remove_non_meaning_characters" + + +class ReplaceTraditionalChineseToSimplified(ExceptionRegulator): + """Exception class to replace traditional chinese to simplified""" + + operator_name: str = "replace_traditional_chinese_to_simplified" + + +class RemoveWebIdentifiers(ExceptionRegulator): + """Exception class to remove web identifiers""" + + operator_name: str = "remove_web_identifiers" + + +class FilterCheckNumberWords(Filter): + """Filter class to check number of words""" + + operator_name: str = "filter_check_number_words" + number_words_min_cutoff: int = Field(default=1) + number_words_max_cutoff: int = Field(default=10000) + + +class FilterCheckWordRepetitionRemoval(Filter): + """Filter class to check word repetition removal""" + + operator_name: str = "filter_check_word_repetition_removal" + word_repetition_max_cutoff: float + + +class FilterCheckCharacterRepetitionRemoval(Filter): + """Filter class to check character repetition removal""" + + operator_name: str = "filter_check_character_repetition_removal" + default_character_repetition_max_cutoff: float + + +class FilterCheckSpecialCharacters(Filter): + """Filter class to check special characters""" + + operator_name: str = "filter_check_special_characters" + special_characters_max_cutoff: float + + +class FilterCheckFlaggedWords(Filter): + """Filter class to check flagged words""" + + operator_name: str = "filter_check_flagged_words" + flagged_words_max_cutoff: float + + +class FilterCheckLangId(Filter): + """Filter class to check lang id""" + + operator_name: str = "filter_check_lang_id" + lang_id_min_cutoff: float + + +class FilterCheckPerplexity(Filter): + """Filter class to check perplexity""" + + operator_name: str = "filter_check_perplexity" + perplexity_max_cutoff: int + + +class DeduplicationSimhash(Deduplicator): + """Deduplicator class to deduplicate by simhash""" + + operator_name: str = "deduplication_simhash" + distance: float + + +class ReplaceEmails(DesensitizationProcessor): + """Sensitive data processor class to replace emails""" + + operator_name: str = "replace_emails" + + +class ReplaceIp(DesensitizationProcessor): + """Sensitive data processor class to replace ip""" + + operator_name: str = "replace_ip" + + +class ReplaceIdentifier(DesensitizationProcessor): + """Sensitive data processor class to replace identifier""" + + operator_name: str = "replace_identifier" diff --git a/src/qianfan/dataset/schema.py b/python/qianfan/dataset/schema.py similarity index 100% rename from src/qianfan/dataset/schema.py rename to python/qianfan/dataset/schema.py diff --git a/src/qianfan/dataset/table.py b/python/qianfan/dataset/table.py similarity index 100% rename from src/qianfan/dataset/table.py rename to python/qianfan/dataset/table.py diff --git a/src/qianfan/dataset/table_utils.py b/python/qianfan/dataset/table_utils.py similarity index 100% rename from src/qianfan/dataset/table_utils.py rename to python/qianfan/dataset/table_utils.py diff --git a/src/qianfan/errors.py b/python/qianfan/errors.py similarity index 100% rename from src/qianfan/errors.py rename to python/qianfan/errors.py diff --git a/src/qianfan/evaluation/__init__.py b/python/qianfan/evaluation/__init__.py similarity index 100% rename from src/qianfan/evaluation/__init__.py rename to python/qianfan/evaluation/__init__.py diff --git a/src/qianfan/evaluation/consts.py b/python/qianfan/evaluation/consts.py similarity index 100% rename from src/qianfan/evaluation/consts.py rename to python/qianfan/evaluation/consts.py diff --git a/src/qianfan/evaluation/evaluation_manager.py b/python/qianfan/evaluation/evaluation_manager.py similarity index 100% rename from src/qianfan/evaluation/evaluation_manager.py rename to python/qianfan/evaluation/evaluation_manager.py diff --git a/src/qianfan/evaluation/evaluation_result.py b/python/qianfan/evaluation/evaluation_result.py similarity index 100% rename from src/qianfan/evaluation/evaluation_result.py rename to python/qianfan/evaluation/evaluation_result.py diff --git a/src/qianfan/evaluation/evaluator.py b/python/qianfan/evaluation/evaluator.py similarity index 100% rename from src/qianfan/evaluation/evaluator.py rename to python/qianfan/evaluation/evaluator.py diff --git a/src/qianfan/evaluation/opencompass_evaluator.py b/python/qianfan/evaluation/opencompass_evaluator.py similarity index 100% rename from src/qianfan/evaluation/opencompass_evaluator.py rename to python/qianfan/evaluation/opencompass_evaluator.py diff --git a/src/qianfan/extensions/README.md b/python/qianfan/extensions/README.md similarity index 100% rename from src/qianfan/extensions/README.md rename to python/qianfan/extensions/README.md diff --git a/src/qianfan/extensions/semantic_kernel/connectors/__init__.py b/python/qianfan/extensions/__init__.py similarity index 100% rename from src/qianfan/extensions/semantic_kernel/connectors/__init__.py rename to python/qianfan/extensions/__init__.py diff --git a/src/qianfan/extensions/langchain/__init__.py b/python/qianfan/extensions/langchain/__init__.py similarity index 100% rename from src/qianfan/extensions/langchain/__init__.py rename to python/qianfan/extensions/langchain/__init__.py diff --git a/src/qianfan/extensions/langchain/agents/__init__.py b/python/qianfan/extensions/langchain/agents/__init__.py similarity index 100% rename from src/qianfan/extensions/langchain/agents/__init__.py rename to python/qianfan/extensions/langchain/agents/__init__.py diff --git a/src/qianfan/extensions/langchain/agents/baidu_qianfan_endpoint.py b/python/qianfan/extensions/langchain/agents/baidu_qianfan_endpoint.py similarity index 100% rename from src/qianfan/extensions/langchain/agents/baidu_qianfan_endpoint.py rename to python/qianfan/extensions/langchain/agents/baidu_qianfan_endpoint.py diff --git a/src/qianfan/extensions/semantic_kernel/__init__.py b/python/qianfan/extensions/semantic_kernel/__init__.py similarity index 100% rename from src/qianfan/extensions/semantic_kernel/__init__.py rename to python/qianfan/extensions/semantic_kernel/__init__.py diff --git a/src/qianfan/resources/auth/__init__.py b/python/qianfan/extensions/semantic_kernel/connectors/__init__.py similarity index 100% rename from src/qianfan/resources/auth/__init__.py rename to python/qianfan/extensions/semantic_kernel/connectors/__init__.py diff --git a/src/qianfan/extensions/semantic_kernel/connectors/qianfan_chat_completion.py b/python/qianfan/extensions/semantic_kernel/connectors/qianfan_chat_completion.py similarity index 100% rename from src/qianfan/extensions/semantic_kernel/connectors/qianfan_chat_completion.py rename to python/qianfan/extensions/semantic_kernel/connectors/qianfan_chat_completion.py diff --git a/src/qianfan/extensions/semantic_kernel/connectors/qianfan_settings.py b/python/qianfan/extensions/semantic_kernel/connectors/qianfan_settings.py similarity index 100% rename from src/qianfan/extensions/semantic_kernel/connectors/qianfan_settings.py rename to python/qianfan/extensions/semantic_kernel/connectors/qianfan_settings.py diff --git a/src/qianfan/extensions/semantic_kernel/connectors/qianfan_text_completion.py b/python/qianfan/extensions/semantic_kernel/connectors/qianfan_text_completion.py similarity index 100% rename from src/qianfan/extensions/semantic_kernel/connectors/qianfan_text_completion.py rename to python/qianfan/extensions/semantic_kernel/connectors/qianfan_text_completion.py diff --git a/src/qianfan/extensions/semantic_kernel/connectors/qianfan_text_embedding.py b/python/qianfan/extensions/semantic_kernel/connectors/qianfan_text_embedding.py similarity index 100% rename from src/qianfan/extensions/semantic_kernel/connectors/qianfan_text_embedding.py rename to python/qianfan/extensions/semantic_kernel/connectors/qianfan_text_embedding.py diff --git a/src/qianfan/model/__init__.py b/python/qianfan/model/__init__.py similarity index 100% rename from src/qianfan/model/__init__.py rename to python/qianfan/model/__init__.py diff --git a/src/qianfan/model/configs.py b/python/qianfan/model/configs.py similarity index 100% rename from src/qianfan/model/configs.py rename to python/qianfan/model/configs.py diff --git a/src/qianfan/model/consts.py b/python/qianfan/model/consts.py similarity index 100% rename from src/qianfan/model/consts.py rename to python/qianfan/model/consts.py diff --git a/src/qianfan/model/model.py b/python/qianfan/model/model.py similarity index 100% rename from src/qianfan/model/model.py rename to python/qianfan/model/model.py diff --git a/src/qianfan/py.typed b/python/qianfan/py.typed similarity index 100% rename from src/qianfan/py.typed rename to python/qianfan/py.typed diff --git a/src/qianfan/resources/__init__.py b/python/qianfan/resources/__init__.py similarity index 100% rename from src/qianfan/resources/__init__.py rename to python/qianfan/resources/__init__.py diff --git a/src/qianfan/resources/console/__init__.py b/python/qianfan/resources/auth/__init__.py similarity index 100% rename from src/qianfan/resources/console/__init__.py rename to python/qianfan/resources/auth/__init__.py diff --git a/src/qianfan/resources/auth/iam.py b/python/qianfan/resources/auth/iam.py similarity index 100% rename from src/qianfan/resources/auth/iam.py rename to python/qianfan/resources/auth/iam.py diff --git a/src/qianfan/resources/auth/oauth.py b/python/qianfan/resources/auth/oauth.py similarity index 100% rename from src/qianfan/resources/auth/oauth.py rename to python/qianfan/resources/auth/oauth.py diff --git a/src/qianfan/resources/images/__init__.py b/python/qianfan/resources/console/__init__.py similarity index 100% rename from src/qianfan/resources/images/__init__.py rename to python/qianfan/resources/console/__init__.py diff --git a/src/qianfan/resources/console/consts.py b/python/qianfan/resources/console/consts.py similarity index 100% rename from src/qianfan/resources/console/consts.py rename to python/qianfan/resources/console/consts.py diff --git a/src/qianfan/resources/console/data.py b/python/qianfan/resources/console/data.py similarity index 100% rename from src/qianfan/resources/console/data.py rename to python/qianfan/resources/console/data.py diff --git a/src/qianfan/resources/console/finetune.py b/python/qianfan/resources/console/finetune.py similarity index 100% rename from src/qianfan/resources/console/finetune.py rename to python/qianfan/resources/console/finetune.py diff --git a/src/qianfan/resources/console/model.py b/python/qianfan/resources/console/model.py similarity index 100% rename from src/qianfan/resources/console/model.py rename to python/qianfan/resources/console/model.py diff --git a/src/qianfan/resources/console/prompt.py b/python/qianfan/resources/console/prompt.py similarity index 100% rename from src/qianfan/resources/console/prompt.py rename to python/qianfan/resources/console/prompt.py diff --git a/src/qianfan/resources/console/service.py b/python/qianfan/resources/console/service.py similarity index 100% rename from src/qianfan/resources/console/service.py rename to python/qianfan/resources/console/service.py diff --git a/src/qianfan/resources/console/utils.py b/python/qianfan/resources/console/utils.py similarity index 100% rename from src/qianfan/resources/console/utils.py rename to python/qianfan/resources/console/utils.py diff --git a/src/qianfan/resources/http_client.py b/python/qianfan/resources/http_client.py similarity index 100% rename from src/qianfan/resources/http_client.py rename to python/qianfan/resources/http_client.py diff --git a/src/qianfan/resources/llm/__init__.py b/python/qianfan/resources/images/__init__.py similarity index 100% rename from src/qianfan/resources/llm/__init__.py rename to python/qianfan/resources/images/__init__.py diff --git a/src/qianfan/resources/images/image2text.py b/python/qianfan/resources/images/image2text.py similarity index 100% rename from src/qianfan/resources/images/image2text.py rename to python/qianfan/resources/images/image2text.py diff --git a/src/qianfan/resources/images/text2image.py b/python/qianfan/resources/images/text2image.py similarity index 100% rename from src/qianfan/resources/images/text2image.py rename to python/qianfan/resources/images/text2image.py diff --git a/src/qianfan/resources/requestor/__init__.py b/python/qianfan/resources/llm/__init__.py similarity index 100% rename from src/qianfan/resources/requestor/__init__.py rename to python/qianfan/resources/llm/__init__.py diff --git a/src/qianfan/resources/llm/base.py b/python/qianfan/resources/llm/base.py similarity index 100% rename from src/qianfan/resources/llm/base.py rename to python/qianfan/resources/llm/base.py diff --git a/src/qianfan/resources/llm/chat_completion.py b/python/qianfan/resources/llm/chat_completion.py similarity index 100% rename from src/qianfan/resources/llm/chat_completion.py rename to python/qianfan/resources/llm/chat_completion.py diff --git a/src/qianfan/resources/llm/completion.py b/python/qianfan/resources/llm/completion.py similarity index 100% rename from src/qianfan/resources/llm/completion.py rename to python/qianfan/resources/llm/completion.py diff --git a/src/qianfan/resources/llm/embedding.py b/python/qianfan/resources/llm/embedding.py similarity index 100% rename from src/qianfan/resources/llm/embedding.py rename to python/qianfan/resources/llm/embedding.py diff --git a/src/qianfan/resources/llm/plugin.py b/python/qianfan/resources/llm/plugin.py similarity index 100% rename from src/qianfan/resources/llm/plugin.py rename to python/qianfan/resources/llm/plugin.py diff --git a/src/qianfan/resources/rate_limiter.py b/python/qianfan/resources/rate_limiter.py similarity index 100% rename from src/qianfan/resources/rate_limiter.py rename to python/qianfan/resources/rate_limiter.py diff --git a/src/qianfan/resources/tools/__init__.py b/python/qianfan/resources/requestor/__init__.py similarity index 100% rename from src/qianfan/resources/tools/__init__.py rename to python/qianfan/resources/requestor/__init__.py diff --git a/src/qianfan/resources/requestor/base.py b/python/qianfan/resources/requestor/base.py similarity index 100% rename from src/qianfan/resources/requestor/base.py rename to python/qianfan/resources/requestor/base.py diff --git a/src/qianfan/resources/requestor/console_requestor.py b/python/qianfan/resources/requestor/console_requestor.py similarity index 100% rename from src/qianfan/resources/requestor/console_requestor.py rename to python/qianfan/resources/requestor/console_requestor.py diff --git a/src/qianfan/resources/requestor/openapi_requestor.py b/python/qianfan/resources/requestor/openapi_requestor.py similarity index 100% rename from src/qianfan/resources/requestor/openapi_requestor.py rename to python/qianfan/resources/requestor/openapi_requestor.py diff --git a/src/qianfan/tests/__init__.py b/python/qianfan/resources/tools/__init__.py similarity index 100% rename from src/qianfan/tests/__init__.py rename to python/qianfan/resources/tools/__init__.py diff --git a/src/qianfan/resources/tools/tokenizer.py b/python/qianfan/resources/tools/tokenizer.py similarity index 100% rename from src/qianfan/resources/tools/tokenizer.py rename to python/qianfan/resources/tools/tokenizer.py diff --git a/src/qianfan/resources/tools/utils.py b/python/qianfan/resources/tools/utils.py similarity index 100% rename from src/qianfan/resources/tools/utils.py rename to python/qianfan/resources/tools/utils.py diff --git a/src/qianfan/resources/typing.py b/python/qianfan/resources/typing.py similarity index 100% rename from src/qianfan/resources/typing.py rename to python/qianfan/resources/typing.py diff --git a/src/qianfan/tests/dataset/__init__.py b/python/qianfan/tests/__init__.py similarity index 100% rename from src/qianfan/tests/dataset/__init__.py rename to python/qianfan/tests/__init__.py diff --git a/src/qianfan/tests/auth_test.py b/python/qianfan/tests/auth_test.py similarity index 100% rename from src/qianfan/tests/auth_test.py rename to python/qianfan/tests/auth_test.py diff --git a/src/qianfan/tests/chat_completion_test.py b/python/qianfan/tests/chat_completion_test.py similarity index 100% rename from src/qianfan/tests/chat_completion_test.py rename to python/qianfan/tests/chat_completion_test.py diff --git a/src/qianfan/tests/completion_test.py b/python/qianfan/tests/completion_test.py similarity index 100% rename from src/qianfan/tests/completion_test.py rename to python/qianfan/tests/completion_test.py diff --git a/src/qianfan/tests/config_test.py b/python/qianfan/tests/config_test.py similarity index 100% rename from src/qianfan/tests/config_test.py rename to python/qianfan/tests/config_test.py diff --git a/src/qianfan/tests/conftest.py b/python/qianfan/tests/conftest.py similarity index 100% rename from src/qianfan/tests/conftest.py rename to python/qianfan/tests/conftest.py diff --git a/src/qianfan/tests/data_api_test.py b/python/qianfan/tests/data_api_test.py similarity index 100% rename from src/qianfan/tests/data_api_test.py rename to python/qianfan/tests/data_api_test.py diff --git a/src/qianfan/tests/langchain/__init__.py b/python/qianfan/tests/dataset/__init__.py similarity index 100% rename from src/qianfan/tests/langchain/__init__.py rename to python/qianfan/tests/dataset/__init__.py diff --git a/src/qianfan/tests/dataset/data_source_test.py b/python/qianfan/tests/dataset/data_source_test.py similarity index 100% rename from src/qianfan/tests/dataset/data_source_test.py rename to python/qianfan/tests/dataset/data_source_test.py diff --git a/src/qianfan/tests/dataset/dataset_test.py b/python/qianfan/tests/dataset/dataset_test.py similarity index 100% rename from src/qianfan/tests/dataset/dataset_test.py rename to python/qianfan/tests/dataset/dataset_test.py diff --git a/src/qianfan/tests/dataset/table_test.py b/python/qianfan/tests/dataset/table_test.py similarity index 100% rename from src/qianfan/tests/dataset/table_test.py rename to python/qianfan/tests/dataset/table_test.py diff --git a/src/qianfan/tests/embedding_test.py b/python/qianfan/tests/embedding_test.py similarity index 100% rename from src/qianfan/tests/embedding_test.py rename to python/qianfan/tests/embedding_test.py diff --git a/src/qianfan/tests/finetune_test.py b/python/qianfan/tests/finetune_test.py similarity index 100% rename from src/qianfan/tests/finetune_test.py rename to python/qianfan/tests/finetune_test.py diff --git a/src/qianfan/tests/hub_test.py b/python/qianfan/tests/hub_test.py similarity index 100% rename from src/qianfan/tests/hub_test.py rename to python/qianfan/tests/hub_test.py diff --git a/src/qianfan/tests/image2text_test.py b/python/qianfan/tests/image2text_test.py similarity index 100% rename from src/qianfan/tests/image2text_test.py rename to python/qianfan/tests/image2text_test.py diff --git a/python/qianfan/tests/langchain/__init__.py b/python/qianfan/tests/langchain/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/qianfan/tests/langchain/agent_test.py b/python/qianfan/tests/langchain/agent_test.py similarity index 100% rename from src/qianfan/tests/langchain/agent_test.py rename to python/qianfan/tests/langchain/agent_test.py diff --git a/src/qianfan/tests/latency_test.py b/python/qianfan/tests/latency_test.py similarity index 100% rename from src/qianfan/tests/latency_test.py rename to python/qianfan/tests/latency_test.py diff --git a/src/qianfan/tests/model_test.py b/python/qianfan/tests/model_test.py similarity index 100% rename from src/qianfan/tests/model_test.py rename to python/qianfan/tests/model_test.py diff --git a/src/qianfan/tests/plugin_test.py b/python/qianfan/tests/plugin_test.py similarity index 100% rename from src/qianfan/tests/plugin_test.py rename to python/qianfan/tests/plugin_test.py diff --git a/src/qianfan/tests/prompt_class_test.py b/python/qianfan/tests/prompt_class_test.py similarity index 100% rename from src/qianfan/tests/prompt_class_test.py rename to python/qianfan/tests/prompt_class_test.py diff --git a/src/qianfan/tests/prompt_resource_test.py b/python/qianfan/tests/prompt_resource_test.py similarity index 100% rename from src/qianfan/tests/prompt_resource_test.py rename to python/qianfan/tests/prompt_resource_test.py diff --git a/src/qianfan/tests/rate_limiter_test.py b/python/qianfan/tests/rate_limiter_test.py similarity index 100% rename from src/qianfan/tests/rate_limiter_test.py rename to python/qianfan/tests/rate_limiter_test.py diff --git a/src/qianfan/tests/retry_test.py b/python/qianfan/tests/retry_test.py similarity index 100% rename from src/qianfan/tests/retry_test.py rename to python/qianfan/tests/retry_test.py diff --git a/src/qianfan/tests/service_test.py b/python/qianfan/tests/service_test.py similarity index 100% rename from src/qianfan/tests/service_test.py rename to python/qianfan/tests/service_test.py diff --git a/src/qianfan/tests/text2image_test.py b/python/qianfan/tests/text2image_test.py similarity index 100% rename from src/qianfan/tests/text2image_test.py rename to python/qianfan/tests/text2image_test.py diff --git a/src/qianfan/tests/tokenizer_test.py b/python/qianfan/tests/tokenizer_test.py similarity index 100% rename from src/qianfan/tests/tokenizer_test.py rename to python/qianfan/tests/tokenizer_test.py diff --git a/src/qianfan/tests/tool_test.py b/python/qianfan/tests/tool_test.py similarity index 100% rename from src/qianfan/tests/tool_test.py rename to python/qianfan/tests/tool_test.py diff --git a/src/qianfan/tests/trainer_test.py b/python/qianfan/tests/trainer_test.py similarity index 100% rename from src/qianfan/tests/trainer_test.py rename to python/qianfan/tests/trainer_test.py diff --git a/src/qianfan/tests/utils/__init__.py b/python/qianfan/tests/utils/__init__.py similarity index 100% rename from src/qianfan/tests/utils/__init__.py rename to python/qianfan/tests/utils/__init__.py diff --git a/src/qianfan/tests/utils/mock_server.py b/python/qianfan/tests/utils/mock_server.py similarity index 99% rename from src/qianfan/tests/utils/mock_server.py rename to python/qianfan/tests/utils/mock_server.py index dcd5b46f..a35ce82f 100644 --- a/src/qianfan/tests/utils/mock_server.py +++ b/python/qianfan/tests/utils/mock_server.py @@ -258,7 +258,7 @@ def chat(model_name): """ r = request.json request_header = request.headers - request_id = request_header[Consts.XRequestID] + request_id = request_header.get(Consts.XRequestID) if request_id == "custom_req": return json_response( { diff --git a/src/qianfan/tests/utils/utils.py b/python/qianfan/tests/utils/utils.py similarity index 100% rename from src/qianfan/tests/utils/utils.py rename to python/qianfan/tests/utils/utils.py diff --git a/src/qianfan/trainer/__init__.py b/python/qianfan/trainer/__init__.py similarity index 100% rename from src/qianfan/trainer/__init__.py rename to python/qianfan/trainer/__init__.py diff --git a/src/qianfan/trainer/actions.py b/python/qianfan/trainer/actions.py similarity index 100% rename from src/qianfan/trainer/actions.py rename to python/qianfan/trainer/actions.py diff --git a/src/qianfan/trainer/base.py b/python/qianfan/trainer/base.py similarity index 100% rename from src/qianfan/trainer/base.py rename to python/qianfan/trainer/base.py diff --git a/src/qianfan/trainer/configs.py b/python/qianfan/trainer/configs.py similarity index 100% rename from src/qianfan/trainer/configs.py rename to python/qianfan/trainer/configs.py diff --git a/src/qianfan/trainer/consts.py b/python/qianfan/trainer/consts.py similarity index 100% rename from src/qianfan/trainer/consts.py rename to python/qianfan/trainer/consts.py diff --git a/src/qianfan/trainer/event.py b/python/qianfan/trainer/event.py similarity index 100% rename from src/qianfan/trainer/event.py rename to python/qianfan/trainer/event.py diff --git a/src/qianfan/trainer/finetune.py b/python/qianfan/trainer/finetune.py similarity index 100% rename from src/qianfan/trainer/finetune.py rename to python/qianfan/trainer/finetune.py diff --git a/src/qianfan/utils/__init__.py b/python/qianfan/utils/__init__.py similarity index 100% rename from src/qianfan/utils/__init__.py rename to python/qianfan/utils/__init__.py diff --git a/src/qianfan/utils/bos_uploader.py b/python/qianfan/utils/bos_uploader.py similarity index 100% rename from src/qianfan/utils/bos_uploader.py rename to python/qianfan/utils/bos_uploader.py diff --git a/src/qianfan/utils/helper.py b/python/qianfan/utils/helper.py similarity index 100% rename from src/qianfan/utils/helper.py rename to python/qianfan/utils/helper.py diff --git a/src/qianfan/utils/logging.py b/python/qianfan/utils/logging.py similarity index 100% rename from src/qianfan/utils/logging.py rename to python/qianfan/utils/logging.py diff --git a/src/qianfan/utils/pydantic/__init__.py b/python/qianfan/utils/pydantic/__init__.py similarity index 100% rename from src/qianfan/utils/pydantic/__init__.py rename to python/qianfan/utils/pydantic/__init__.py diff --git a/src/qianfan/utils/utils.py b/python/qianfan/utils/utils.py similarity index 100% rename from src/qianfan/utils/utils.py rename to python/qianfan/utils/utils.py diff --git a/src/qianfan/version.py b/python/qianfan/version.py similarity index 85% rename from src/qianfan/version.py rename to python/qianfan/version.py index 1edb9a59..7747698d 100644 --- a/src/qianfan/version.py +++ b/python/qianfan/version.py @@ -19,4 +19,7 @@ except ImportError: # for Python<3.8 import importlib_metadata as metadata # type: ignore -VERSION = metadata.version("qianfan") +try: + VERSION = metadata.version("qianfan") +except Exception: + VERSION = "0.0.0" # means not installed by any package manager diff --git a/python/scripts/build.sh b/python/scripts/build.sh new file mode 100644 index 00000000..4806a566 --- /dev/null +++ b/python/scripts/build.sh @@ -0,0 +1,14 @@ +#!/bin/bash +set -e + +SCRIPT=$(readlink -f -f "$0") +SCRIPTPATH=$(dirname "$SCRIPT") +ROOTPATH="${SCRIPTPATH}/../../" +OUTPUT_PATH="${PWD}/output" + +# build wheel +make clean +poetry build +mkdir -p "${OUTPUT_PATH}" +mv dist/* "${OUTPUT_PATH}" +rm -rf dist diff --git a/python/scripts/build_doc.sh b/python/scripts/build_doc.sh new file mode 100644 index 00000000..42a93411 --- /dev/null +++ b/python/scripts/build_doc.sh @@ -0,0 +1,16 @@ +#!/bin/bash + +set -ex + +SCRIPT=$(readlink -f "$0") +SCRIPTPATH=$(dirname "$SCRIPT") +ROOTPATH="${SCRIPTPATH}/../../" +OUTPUT_PATH="${PWD}/output" +DOCS_PATH="${OUTPUT_PATH}/docs_tmp" + +sphinx-apidoc -f -F -M -o "${DOCS_PATH}" -t "${ROOTPATH}/docs/template" "${SCRIPTPATH}/../qianfan" "*test*" +cp "${ROOTPATH}/README.md" "${DOCS_PATH}" +cp -r "${ROOTPATH}/docs" "${DOCS_PATH}" +make -C "${DOCS_PATH}" html +mv "${DOCS_PATH}" "${OUTPUT_PATH}/docs" +rm -rf "${DOCS_PATH}" \ No newline at end of file diff --git a/src/scripts/release_github.sh b/python/scripts/release_github.sh similarity index 100% rename from src/scripts/release_github.sh rename to python/scripts/release_github.sh diff --git a/python/scripts/run_mock_server.sh b/python/scripts/run_mock_server.sh new file mode 100644 index 00000000..5a66f72d --- /dev/null +++ b/python/scripts/run_mock_server.sh @@ -0,0 +1,22 @@ +#!/bin/bash + +set -x + +SCRIPT=$(readlink -f "$0") +SCRIPTPATH=$(dirname "$SCRIPT") + +cd "${SCRIPTPATH}/../" +export PYTHONPATH="${SCRIPTPATH}/../" +nohup poetry run python "${SCRIPTPATH}/../qianfan/tests/utils/mock_server.py" > /tmp/mock_server 2>&1 & + +for i in {1..20}; do + curl 127.0.0.1:8866 > /dev/null 2>&1 + if [ $? = 0 ]; + then + exit 0 + fi + sleep 0.5 +done +echo "Start mock server failed" +cat /tmp/mock_server +exit 1 \ No newline at end of file diff --git a/src/scripts/run_test.sh b/python/scripts/run_test.sh similarity index 100% rename from src/scripts/run_test.sh rename to python/scripts/run_test.sh diff --git a/src/scripts/build.sh b/src/scripts/build.sh deleted file mode 100644 index ea9814e7..00000000 --- a/src/scripts/build.sh +++ /dev/null @@ -1,14 +0,0 @@ -#!/bin/bash -set -e - -# build wheel -make clean -poetry build -mkdir output -mv dist/* output -rm -rf dist - -# build docs -make doc -mv build/docs/_build/ ./output/docs -rm -rf build \ No newline at end of file diff --git a/src/scripts/build_doc.sh b/src/scripts/build_doc.sh deleted file mode 100644 index 57648a9a..00000000 --- a/src/scripts/build_doc.sh +++ /dev/null @@ -1,9 +0,0 @@ -#!/bin/bash - -set -e - -sphinx-apidoc -f -F -M -o build/docs -t src/qianfan/docs src/qianfan "*test*" -cp README.md build/docs -cp -r docs build/docs -cd build/docs -make html \ No newline at end of file From 67ba140e2515a1fbdf7a18b880556df7c197615d Mon Sep 17 00:00:00 2001 From: Liu Jun Date: Fri, 2 Feb 2024 16:39:48 +0800 Subject: [PATCH 7/8] support mixtral (#249) --- .github/workflows/go_ci.yml | 3 ++- .github/workflows/py_ci.yml | 2 +- python/qianfan/resources/llm/chat_completion.py | 15 +++++++++++++++ python/qianfan/resources/llm/completion.py | 15 +++++++++++++++ 4 files changed, 33 insertions(+), 2 deletions(-) diff --git a/.github/workflows/go_ci.yml b/.github/workflows/go_ci.yml index 8bc4f498..b56e67c5 100644 --- a/.github/workflows/go_ci.yml +++ b/.github/workflows/go_ci.yml @@ -2,8 +2,9 @@ on: push: branches: ['main'] pull_request: - path: + paths: - "go/**" + - "python/qianfan/tests/utils/mock_server.py" workflow_dispatch: name: Go CI diff --git a/.github/workflows/py_ci.yml b/.github/workflows/py_ci.yml index 10ac628a..54ac871d 100644 --- a/.github/workflows/py_ci.yml +++ b/.github/workflows/py_ci.yml @@ -2,7 +2,7 @@ on: push: branches: ['main'] pull_request: - path: + paths: - "python/**" workflow_dispatch: diff --git a/python/qianfan/resources/llm/chat_completion.py b/python/qianfan/resources/llm/chat_completion.py index 6cf1f940..78b0534f 100644 --- a/python/qianfan/resources/llm/chat_completion.py +++ b/python/qianfan/resources/llm/chat_completion.py @@ -330,6 +330,21 @@ def _supported_models(cls) -> Dict[str, QfLLMInfo]: "tool_choice", }, ), + "Mixtral-8x7B-Instruct": QfLLMInfo( + endpoint="/chat/mixtral_8x7b_instruct", + required_keys={"messages"}, + optional_keys={ + "stream", + "user_id", + "temperature", + "top_k", + "top_p", + "penalty_score", + "stop", + "tools", + "tool_choice", + }, + ), UNSPECIFIED_MODEL: QfLLMInfo( endpoint="", required_keys={"messages"}, diff --git a/python/qianfan/resources/llm/completion.py b/python/qianfan/resources/llm/completion.py index c0dc74d9..298bafff 100644 --- a/python/qianfan/resources/llm/completion.py +++ b/python/qianfan/resources/llm/completion.py @@ -348,6 +348,21 @@ def _supported_models(cls) -> Dict[str, QfLLMInfo]: "tool_choice", }, ), + "Mixtral-8x7B-Instruct": QfLLMInfo( + endpoint="/chat/mixtral_8x7b_instruct", + required_keys={"messages"}, + optional_keys={ + "stream", + "user_id", + "temperature", + "top_k", + "top_p", + "penalty_score", + "stop", + "tools", + "tool_choice", + }, + ), UNSPECIFIED_MODEL: QfLLMInfo( endpoint="", required_keys={"prompt"}, From 7af9a5df8b089475b423ca80e1ae2b43eedc6659 Mon Sep 17 00:00:00 2001 From: Dobiichi-Origami <56953648+Dobiichi-Origami@users.noreply.github.com> Date: Fri, 2 Feb 2024 16:41:02 +0800 Subject: [PATCH 8/8] fix thread work pool concurrent bug (#248) --- python/qianfan/resources/llm/base.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/python/qianfan/resources/llm/base.py b/python/qianfan/resources/llm/base.py index 982e9726..b6f89fa7 100644 --- a/python/qianfan/resources/llm/base.py +++ b/python/qianfan/resources/llm/base.py @@ -47,6 +47,7 @@ # This constant is used to express no model is spcified, # so that SDK still can get the requirements of API from _supported_models() UNSPECIFIED_MODEL = "UNSPECIFIED_MODEL" +MAX_WORKER_THREAD_COUNT = 100000 class BatchRequestFuture(object): @@ -62,16 +63,19 @@ def __init__( """ Init batch request future """ - future_list: List[Future[Union[QfResponse, Iterator[QfResponse]]]] = [] max_workers = worker_num if worker_num else len(tasks) + 1 + if max_workers > MAX_WORKER_THREAD_COUNT: + max_workers = MAX_WORKER_THREAD_COUNT + + self._future_list: List[Future[Union[QfResponse, Iterator[QfResponse]]]] = [] self._executor = ThreadPoolExecutor(max_workers=max_workers) + self._finished_count = 0 + self._task_count = len(tasks) + self._lock = threading.Lock() for task in tasks: future = self._executor.submit(task) future.add_done_callback(self._future_callback) - future_list.append(future) - self._future_list = future_list - self._finished_count = 0 - self._lock = threading.Lock() + self._future_list.append(future) def _future_callback( self, fn: Future[Union[QfResponse, Iterator[QfResponse]]] @@ -81,7 +85,7 @@ def _future_callback( """ with self._lock: self._finished_count += 1 - if self._finished_count == len(self._future_list): + if self._finished_count == self._task_count: log_info("All tasks finished, exeutor will be shutdown") self._executor.shutdown(wait=False)