Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Parameters in clickhouse-cpp #394

Merged
merged 9 commits into from
Feb 17, 2025
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 80 additions & 0 deletions clickhouse/base/wire_format.cpp
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#include <assert.h>
#include "wire_format.h"

#include "input.h"
Expand Down Expand Up @@ -99,4 +100,83 @@ bool WireFormat::SkipString(InputStream& input) {
return false;
}

const char quoted_chars[] = {'\0', '\b', '\t', '\n', '\'', '\\'};

inline const char* find_quoted_chars(const char* start, const char* end) {
while (start < end) {
char c = *start;
for (unsigned i = 0; i < sizeof(quoted_chars); i++) {
if (quoted_chars[i] == c) return start;
}
start++;
}
return nullptr;
}

void WireFormat::WriteQuotedString(OutputStream& output, std::optional<std::string_view> opt) {
if (!opt) //NULL
{
WriteVarint64(output, 5);
WriteAll(output, "'\\\\N'", 5);
return;
}
std::string_view value = *opt;
auto size = value.size();
const char* start = value.data();
const char* end = start + size;
const char* quoted_char = find_quoted_chars(start, end);
if (quoted_char == nullptr) {
WriteVarint64(output, size + 2);
WriteAll(output, "'", 1);
WriteAll(output, start, size);
WriteAll(output, "'", 1);
return;
}

// calculate quoted chars count
int quoted_count = 1;
const char* next_quoted_char = quoted_char + 1;
while ((next_quoted_char = find_quoted_chars(next_quoted_char, end))) {
quoted_count++;
next_quoted_char++;
}
WriteVarint64(output, size + 2 + 3 * quoted_count); // length

WriteAll(output, "'", 1);

do {
auto write_size = quoted_char - start;
WriteAll(output, start, write_size);
WriteAll(output, "\\", 1);
char c = quoted_char[0];
switch (c) {
case '\0':
WriteAll(output, "x00", 3);
break;
case '\b':
WriteAll(output, "x08", 3);
break;
case '\t':
WriteAll(output, "\\\\t", 3);
break;
case '\n':
WriteAll(output, "\\\\\n", 3);
break;
case '\'':
WriteAll(output, "x27", 3);
break;
case '\\':
WriteAll(output, "\\\\\\", 3);
break;
default:
assert(false);
WriteAll(output, "x3F", 3); // out ?
}
start = quoted_char + 1;
quoted_char = find_quoted_chars(start, end);
} while (quoted_char);

WriteAll(output, start, end - start);
WriteAll(output, "'", 1);
}
}
2 changes: 2 additions & 0 deletions clickhouse/base/wire_format.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

#include <string>
#include <cstdint>
#include <optional>

namespace clickhouse {

Expand All @@ -22,6 +23,7 @@ class WireFormat {
static void WriteFixed(OutputStream& output, const T& value);
static void WriteBytes(OutputStream& output, const void* buf, size_t len);
static void WriteString(OutputStream& output, std::string_view value);
static void WriteQuotedString(OutputStream& output, std::optional<std::string_view> value);
static void WriteUInt64(OutputStream& output, const uint64_t value);
static void WriteVarint64(OutputStream& output, uint64_t value);

Expand Down
50 changes: 46 additions & 4 deletions clickhouse/client.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,13 @@
#define DBMS_MIN_REVISION_WITH_DISTRIBUTED_DEPTH 54448
#define DBMS_MIN_REVISION_WITH_INITIAL_QUERY_START_TIME 54449
#define DBMS_MIN_REVISION_WITH_INCREMENTAL_PROFILE_EVENTS 54451
#define DBMS_MIN_REVISION_WITH_PARALLEL_REPLICAS 54453
#define DBMS_MIN_REVISION_WITH_CUSTOM_SERIALIZATION 54454 // Client can get some fields in JSon format
#define DBMS_MIN_PROTOCOL_VERSION_WITH_ADDENDUM 54458 // send quota key after handshake
#define DBMS_MIN_PROTOCOL_REVISION_WITH_QUOTA_KEY 54458 // the same
#define DBMS_MIN_PROTOCOL_VERSION_WITH_PARAMETERS 54459

#define DMBS_PROTOCOL_REVISION DBMS_MIN_REVISION_WITH_INCREMENTAL_PROFILE_EVENTS
#define DMBS_PROTOCOL_REVISION DBMS_MIN_PROTOCOL_VERSION_WITH_PARAMETERS

namespace clickhouse {

Expand Down Expand Up @@ -433,6 +438,11 @@ bool Client::Impl::Handshake() {
if (!ReceiveHello()) {
return false;
}

if (server_info_.revision >= DBMS_MIN_PROTOCOL_VERSION_WITH_ADDENDUM) {
WireFormat::WriteString(*output_, std::string());
}

return true;
}

Expand Down Expand Up @@ -502,7 +512,7 @@ bool Client::Impl::ReceivePacket(uint64_t* server_packet) {
return false;
}
}
if constexpr (DMBS_PROTOCOL_REVISION >= DBMS_MIN_REVISION_WITH_CLIENT_WRITE_INFO)
if (server_info_.revision >= DBMS_MIN_REVISION_WITH_CLIENT_WRITE_INFO)
{
if (!WireFormat::ReadUInt64(*input_, &info.written_rows)) {
return false;
Expand Down Expand Up @@ -589,7 +599,7 @@ bool Client::Impl::ReceivePacket(uint64_t* server_packet) {

bool Client::Impl::ReadBlock(InputStream& input, Block* block) {
// Additional information about block.
if constexpr (DMBS_PROTOCOL_REVISION >= DBMS_MIN_REVISION_WITH_BLOCK_INFO) {
if (server_info_.revision >= DBMS_MIN_REVISION_WITH_BLOCK_INFO) {
uint64_t num;
BlockInfo info;

Expand Down Expand Up @@ -635,6 +645,16 @@ bool Client::Impl::ReadBlock(InputStream& input, Block* block) {
if (!WireFormat::ReadString(input, &type)) {
return false;
}

if (server_info_.revision >= DBMS_MIN_REVISION_WITH_CUSTOM_SERIALIZATION) {
uint8_t custom_format_len;
if (!WireFormat::ReadFixed(input, &custom_format_len)) {
return false;
}
if (custom_format_len > 0) {
throw UnimplementedError(std::string("unsupported custom serialization"));
}
}

if (ColumnRef col = CreateColumnByType(type, create_column_settings)) {
if (num_rows && !col->Load(&input, num_rows)) {
Expand All @@ -653,7 +673,7 @@ bool Client::Impl::ReadBlock(InputStream& input, Block* block) {
bool Client::Impl::ReceiveData() {
Block block;

if constexpr (DMBS_PROTOCOL_REVISION >= DBMS_MIN_REVISION_WITH_TEMPORARY_TABLES) {
if (server_info_.revision >= DBMS_MIN_REVISION_WITH_TEMPORARY_TABLES) {
if (!WireFormat::SkipString(*input_)) {
return false;
}
Expand Down Expand Up @@ -793,6 +813,11 @@ void Client::Impl::SendQuery(const Query& query) {
throw UnimplementedError(std::string("Can't send open telemetry tracing context to a server, server version is too old"));
}
}
if (server_info_.revision >= DBMS_MIN_REVISION_WITH_PARALLEL_REPLICAS) {
WireFormat::WriteUInt64(*output_, 0);
WireFormat::WriteUInt64(*output_, 0);
WireFormat::WriteUInt64(*output_, 0);
}
}

/// Per query settings
Expand All @@ -817,6 +842,18 @@ void Client::Impl::SendQuery(const Query& query) {
WireFormat::WriteUInt64(*output_, Stages::Complete);
WireFormat::WriteUInt64(*output_, compression_);
WireFormat::WriteString(*output_, query.GetText());

//Send params after query text
if (server_info_.revision >= DBMS_MIN_PROTOCOL_VERSION_WITH_PARAMETERS) {
for(const auto& [name, value] : query.GetParams()) {
// params is like query settings
WireFormat::WriteString(*output_, name);
WireFormat::WriteVarint64(*output_, 2); // Custom
WireFormat::WriteQuotedString(*output_, value);
}
WireFormat::WriteString(*output_, std::string()); // empty string after last param
}

// Send empty block as marker of
// end of data
SendData(Block());
Expand All @@ -842,6 +879,11 @@ void Client::Impl::WriteBlock(const Block& block, OutputStream& output) {
WireFormat::WriteString(output, bi.Name());
WireFormat::WriteString(output, bi.Type()->GetName());

if (server_info_.revision >= DBMS_MIN_REVISION_WITH_CUSTOM_SERIALIZATION) {
// TODO: custom serialization
WireFormat::WriteFixed<uint8_t>(output, 0);
}

// Empty columns are not serialized and occupy exactly 0 bytes.
// ref https://github.com/ClickHouse/ClickHouse/blob/39b37a3240f74f4871c8c1679910e065af6bea19/src/Formats/NativeWriter.cpp#L163
const bool containsData = block.GetRowCount() > 0;
Expand Down
15 changes: 15 additions & 0 deletions clickhouse/query.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ struct QuerySettingsField {
};

using QuerySettings = std::unordered_map<std::string, QuerySettingsField>;
using QueryParamValue = std::optional<std::string>;
using QueryParams = std::unordered_map<std::string, QueryParamValue>;

struct Profile {
uint64_t rows = 0;
Expand Down Expand Up @@ -115,6 +117,18 @@ class Query : public QueryEvents {
return *this;
}

inline const QueryParams& GetParams() const { return query_params_; }

inline Query& SetParams(QueryParams query_params) {
query_params_ = std::move(query_params);
return *this;
}

inline Query& SetParam(const std::string& name, const QueryParamValue& value) {
query_params_[name] = value;
return *this;
}

inline const std::optional<open_telemetry::TracingContext>& GetTracingContext() const {
return tracing_context_;
}
Expand Down Expand Up @@ -219,6 +233,7 @@ class Query : public QueryEvents {
const std::string query_id_;
std::optional<open_telemetry::TracingContext> tracing_context_;
QuerySettings query_settings_;
QueryParams query_params_;
ExceptionCallback exception_cb_;
ProgressCallback progress_cb_;
SelectCallback select_cb_;
Expand Down
73 changes: 73 additions & 0 deletions tests/simple/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,77 @@ inline void GenericExample(Client& client) {
client.Execute("DROP TEMPORARY TABLE test_client");
}

inline void ParamExample(Client& client) {
/// Create a table.
client.Execute("CREATE TEMPORARY TABLE IF NOT EXISTS test_client (id UInt64, name String)");

{
Query query("insert into test_client values ({id: UInt64}, {name: String})");

query.SetParam("id", "1").SetParam("name", "NAME");
client.Execute(query);

query.SetParam("id", "123").SetParam("name", "FromParam");
client.Execute(query);

query.SetParam("id", "333")
.SetParam("name",
std::string("A\000A\001A\002A\003A\004A\005A\006A\007A\010A\011A\012A\013A\014A\015A\016A\017A\020A\021A\022A\023A\024A\025A\026A\027A\030A\031A\032A\033A\034"
"A\035A\036A\037A", 65));
client.Execute(query);

unsigned char big_string[128 - 32];
for (unsigned int i = 0; i < sizeof(big_string); i++) big_string[i] = i + 32;
query.SetParam("id", "444")
.SetParam("name",
std::string((char*)big_string, sizeof(big_string)));
client.Execute(query);

query.SetParam("id", "555")
.SetParam("name", "utf8Русский");
client.Execute(query);
}

/// Select values inserted in the previous step.
Query query ("SELECT id, name, length(name) FROM test_client where id > {a: Int32}");
query.SetParam("a", "4");
SelectCallback cb([](const Block& block)
{
std::cout << PrettyPrintBlock{block} << std::endl;
});
query.OnData(cb);
client.Select(query);
/// Delete table.
client.Execute("DROP TEMPORARY TABLE test_client");
}

inline void ParamNullExample(Client& client) {
client.Execute("CREATE TEMPORARY TABLE IF NOT EXISTS test_client (id UInt64, name Nullable(String))");

Query query("insert into test_client values ({id: UInt64}, {name: Nullable(String)})");

query.SetParam("id", "123").SetParam("name", QueryParamValue());
client.Execute(query);

query.SetParam("id", "456").SetParam("name", "String Value");
client.Execute(query);

client.Select("SELECT id, name FROM test_client", [](const Block& block) {
for (size_t c = 0; c < block.GetRowCount(); ++c) {
std::cerr << block[0]->As<ColumnUInt64>()->At(c) << " ";

auto col_string = block[1]->As<ColumnNullable>();
if (col_string->IsNull(c)) {
std::cerr << "\\N\n";
} else {
std::cerr << col_string->Nested()->As<ColumnString>()->At(c) << "\n";
}
}
});

client.Execute("DROP TEMPORARY TABLE test_client");
}

inline void NullableExample(Client& client) {
/// Create a table.
client.Execute("CREATE TEMPORARY TABLE IF NOT EXISTS test_client (id Nullable(UInt64), date Nullable(Date))");
Expand Down Expand Up @@ -478,6 +549,8 @@ inline void IPExample(Client &client) {
}

static void RunTests(Client& client) {
ParamExample(client);
ParamNullExample(client);
ArrayExample(client);
CancelableExample(client);
DateExample(client);
Expand Down
42 changes: 42 additions & 0 deletions ut/client_ut.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1488,3 +1488,45 @@ TEST(SimpleClientTest, issue_335_reconnects_count) {
<< "\tThere was no attempt to connect to endpoint " << endpoint;
}
}

TEST_P(ClientCase, QueryParameters) {
const std::string table_name = "test_clickhouse_cpp_query_parameter";
client_->Execute("CREATE TEMPORARY TABLE IF NOT EXISTS " + table_name + " (id UInt64, name String)");
{
Query query("insert into " + table_name + " values ({id: UInt64}, {name: String})");

query.SetParam("id", "1").SetParam("name", "NAME");
client_->Execute(query);

query.SetParam("id", "123").SetParam("name", "FromParam");
client_->Execute(query);

query.SetParam("id", "333")
.SetParam("name", std::string("A\000A\001A\002A\003A\004A\005A\006A\007A\010A\011A\012A\013A\014A\015A\016A\017A\020A\021A\022A"
"\023A\024A\025A\026A\027A\030A\031A\032A\033A\034"
"A\035A\036A\037A",
65));
client_->Execute(query);

unsigned char big_string[128 - 32];
for (unsigned int i = 0; i < sizeof(big_string); i++) big_string[i] = i + 32;
query.SetParam("id", "444").SetParam("name", std::string((char*)big_string, sizeof(big_string)));
client_->Execute(query);

query.SetParam("id", "555").SetParam("name", "utf8Русский");
client_->Execute(query);
}

Query query("SELECT id, name, length(name) FROM " + table_name + " where id > {a: Int32}");
query.SetParam("a", "4");
size_t total_count = 0;
SelectCallback cb([&total_count](const Block& block) {
total_count += block.GetRowCount();
//std::cout << PrettyPrintBlock{block} << std::endl;
});
query.OnData(cb);
client_->Select(query);
EXPECT_EQ(4u, total_count);

client_->Execute("DROP TEMPORARY TABLE " + table_name);
}
Loading