Skip to content

Commit

Permalink
Request publish support
Browse files Browse the repository at this point in the history
  • Loading branch information
Bret Ambrose committed Apr 1, 2024
1 parent 1734d1e commit c4313c8
Showing 1 changed file with 221 additions and 22 deletions.
243 changes: 221 additions & 22 deletions source/request-response/request_response_client.c
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include <inttypes.h>

#define MQTT_RR_CLIENT_OPERATION_TABLE_DEFAULT_SIZE 50
#define MQTT_RR_CLIENT_RESPONSE_TABLE_DEFAULT_SIZE 50

enum aws_mqtt_request_response_operation_type {
AWS_MRROT_REQUEST,
Expand Down Expand Up @@ -152,6 +153,48 @@ static void s_aws_rr_operation_list_topic_filter_entry_hash_element_destroy(void
s_aws_rr_operation_list_topic_filter_entry_destroy(value);
}

struct aws_rr_response_path_entry {
struct aws_allocator *allocator;

size_t ref_count;

struct aws_byte_cursor topic_cursor;
struct aws_byte_buf topic;

struct aws_byte_buf correlation_token_json_path;
};

static struct aws_rr_response_path_entry *s_aws_rr_response_path_entry_new(
struct aws_allocator *allocator,
struct aws_byte_cursor topic,
struct aws_byte_cursor correlation_token_json_path) {
struct aws_rr_response_path_entry *entry = aws_mem_calloc(allocator, 1, sizeof(struct aws_rr_response_path_entry));

entry->allocator = allocator;
entry->ref_count = 1;
aws_byte_buf_init_copy_from_cursor(&entry->topic, allocator, topic);
entry->topic_cursor = aws_byte_cursor_from_buf(&entry->topic);

aws_byte_buf_init_copy_from_cursor(&entry->correlation_token_json_path, allocator, correlation_token_json_path);

return entry;
}

static void s_aws_rr_response_path_entry_destroy(struct aws_rr_response_path_entry *entry) {
if (entry == NULL) {
return;
}

aws_byte_buf_clean_up(&entry->topic);
aws_byte_buf_clean_up(&entry->correlation_token_json_path);

aws_mem_release(entry->allocator, entry);
}

static void s_aws_rr_response_path_table_hash_element_destroy(void *value) {
s_aws_rr_response_path_entry_destroy(value);
}

/* All operations have an internal ref to the client they are a part of */

/*
Expand Down Expand Up @@ -484,7 +527,18 @@ struct aws_mqtt_request_response_client {
*/
struct aws_priority_queue operations_by_timeout;

struct aws_hash_table operation_lists_by_subscription_filter;
/*
* Map from cursor (topic filter) -> list of streaming operations using that filter
*/
struct aws_hash_table streaming_operation_subscription_lists;

/*
* Map from cursor (topic) -> request response path (topic, correlation token json path)
*
* We don't garbage collect this table over the course of normal client operation. We only clean it up
* when the client is shutting down.
*/
struct aws_hash_table request_response_paths;
};

struct aws_mqtt_request_response_client *aws_mqtt_request_response_client_acquire_internal(
Expand Down Expand Up @@ -527,7 +581,8 @@ static void s_mqtt_request_response_client_final_destroy(struct aws_mqtt_request
aws_hash_table_clean_up(&client->operations);

aws_priority_queue_clean_up(&client->operations_by_timeout);
aws_hash_table_clean_up(&client->operation_lists_by_subscription_filter);
aws_hash_table_clean_up(&client->streaming_operation_subscription_lists);
aws_hash_table_clean_up(&client->request_response_paths);

aws_mem_release(client->allocator, client);

Expand Down Expand Up @@ -728,6 +783,41 @@ struct aws_rrc_incomplete_publish {
uint64_t operation_id;
};

static void s_aws_rrc_incomplete_publish_destroy(struct aws_rrc_incomplete_publish *user_data) {
if (user_data == NULL) {
return;
}

aws_mqtt_request_response_client_release_internal(user_data->rr_client);

aws_mem_release(user_data->allocator, user_data);
}

static void s_on_request_publish_completion(int error_code, void *userdata) {
struct aws_rrc_incomplete_publish *publish_user_data = userdata;

if (error_code != AWS_ERROR_SUCCESS) {
AWS_LOGF_ERROR(
AWS_LS_MQTT_REQUEST_RESPONSE,
"id=%p: request-response operation %" PRIu64 " failed publish step due to error %d(%s)",
(void *)publish_user_data->rr_client,
publish_user_data->operation_id,
error_code,
aws_error_debug_str(error_code));

struct aws_hash_element *element = NULL;
if (aws_hash_table_find(
&publish_user_data->rr_client->operations, &publish_user_data->operation_id, &element) ==
AWS_OP_SUCCESS &&
element != NULL) {
struct aws_mqtt_rr_client_operation *operation = element->value;
s_complete_request_operation_with_failure(operation, AWS_ERROR_MQTT_REQUEST_RESPONSE_PUBLISH_FAILURE);
}
}

s_aws_rrc_incomplete_publish_destroy(publish_user_data);
}

static void s_make_mqtt_request(
struct aws_mqtt_request_response_client *client,
struct aws_mqtt_rr_client_operation *operation) {
Expand All @@ -737,17 +827,39 @@ static void s_make_mqtt_request(

struct aws_mqtt_request_operation_options *request_options = &operation->storage.request_storage.options;

struct aws_rrc_incomplete_publish *publish_user_data =
aws_mem_calloc(client->allocator, 1, sizeof(struct aws_rrc_incomplete_publish));
publish_user_data->allocator = client->allocator;
publish_user_data->rr_client = aws_mqtt_request_response_client_acquire_internal(client);
publish_user_data->operation_id = operation->id;

struct aws_protocol_adapter_publish_options publish_options = {
.topic = request_options->publish_topic,
.payload = request_options->serialized_request,
.ack_timeout_seconds = client->config.operation_timeout_seconds,
.completion_callback_fn = s_??,
.user_data = ??,
.completion_callback_fn = s_on_request_publish_completion,
.user_data = publish_user_data,
};

if (aws_mqtt_protocol_adapter_publish(client->client_adapter, &publish_options)) {
int error_code = aws_last_error();

AWS_LOGF_ERROR(
AWS_LS_MQTT_REQUEST_RESPONSE,
"id=%p: request-response operation %" PRIu64 " synchronously failed publish step due to error %d(%s)",
(void *)publish_user_data->rr_client,
publish_user_data->operation_id,
error_code,
aws_error_debug_str(error_code));
s_complete_request_operation_with_failure(operation, AWS_ERROR_MQTT_REQUEST_RESPONSE_PUBLISH_FAILURE);
goto error;
}

return;

error:

s_aws_rrc_incomplete_publish_destroy(publish_user_data);
}

struct aws_rr_subscription_status_event_task {
Expand Down Expand Up @@ -853,7 +965,8 @@ static void s_handle_subscription_status_event_task(struct aws_task *task, void
case ARRSET_REQUEST_SUBSCRIBE_SUCCESS:
case ARRSET_REQUEST_SUBSCRIBE_FAILURE:
case ARRSET_REQUEST_SUBSCRIPTION_ENDED:
s_on_request_operation_subscription_status_event(operation, aws_byte_cursor_from_buf(&event_task->topic_filter), event_task->type);
s_on_request_operation_subscription_status_event(
operation, aws_byte_cursor_from_buf(&event_task->topic_filter), event_task->type);
break;

case ARRSET_STREAMING_SUBSCRIPTION_ESTABLISHED:
Expand Down Expand Up @@ -982,7 +1095,7 @@ static void s_aws_rr_client_protocol_adapter_incoming_publish_callback(
/* Streaming operation handling */
struct aws_hash_element *subscription_filter_element = NULL;
if (aws_hash_table_find(
&rr_client->operation_lists_by_subscription_filter, &publish_event->topic, &subscription_filter_element) ==
&rr_client->streaming_operation_subscription_lists, &publish_event->topic, &subscription_filter_element) ==
AWS_OP_SUCCESS) {
if (subscription_filter_element != NULL) {
AWS_LOGF_DEBUG(
Expand Down Expand Up @@ -1093,14 +1206,23 @@ static struct aws_mqtt_request_response_client *s_aws_mqtt_request_response_clie
s_compare_rr_operation_timeouts);

aws_hash_table_init(
&rr_client->operation_lists_by_subscription_filter,
&rr_client->streaming_operation_subscription_lists,
allocator,
MQTT_RR_CLIENT_OPERATION_TABLE_DEFAULT_SIZE,
aws_hash_byte_cursor_ptr,
aws_mqtt_byte_cursor_hash_equality,
NULL,
s_aws_rr_operation_list_topic_filter_entry_hash_element_destroy);

aws_hash_table_init(
&rr_client->request_response_paths,
allocator,
MQTT_RR_CLIENT_RESPONSE_TABLE_DEFAULT_SIZE,
aws_hash_byte_cursor_ptr,
aws_mqtt_byte_cursor_hash_equality,
NULL,
s_aws_rr_response_path_table_hash_element_destroy);

aws_linked_list_init(&rr_client->operation_queue);

aws_task_init(
Expand Down Expand Up @@ -1198,24 +1320,24 @@ static bool s_is_operation_in_list(const struct aws_mqtt_rr_client_operation *op
return aws_linked_list_node_prev_is_valid(&operation->node) && aws_linked_list_node_next_is_valid(&operation->node);
}

static int s_add_operation_to_subscription_topic_filter_table(
static int s_add_streaming_operation_to_subscription_topic_filter_table(
struct aws_mqtt_request_response_client *client,
struct aws_mqtt_rr_client_operation *operation) {

struct aws_byte_cursor topic_filter_cursor = s_aws_mqtt_rr_operation_get_subscription_topic_filter(operation);

struct aws_hash_element *element = NULL;
if (aws_hash_table_find(&client->operation_lists_by_subscription_filter, &topic_filter_cursor, &element)) {
if (aws_hash_table_find(&client->streaming_operation_subscription_lists, &topic_filter_cursor, &element)) {
return aws_raise_error(AWS_ERROR_MQTT_REQUEST_RESPONSE_INTERNAL_ERROR);
}

struct aws_rr_operation_list_topic_filter_entry *entry = NULL;
if (element == NULL) {
entry = s_aws_rr_operation_list_topic_filter_entry_new(client->allocator, topic_filter_cursor);
aws_hash_table_put(&client->operation_lists_by_subscription_filter, &entry->topic_filter_cursor, entry, NULL);
aws_hash_table_put(&client->streaming_operation_subscription_lists, &entry->topic_filter_cursor, entry, NULL);
AWS_LOGF_DEBUG(
AWS_LS_MQTT_REQUEST_RESPONSE,
"id=%p: request-response client adding topic filter '" PRInSTR "' to subscriptions table",
"id=%p: request-response client adding topic filter '" PRInSTR "' to streaming subscriptions table",
(void *)client,
AWS_BYTE_CURSOR_PRI(topic_filter_cursor));
} else {
Expand All @@ -1230,8 +1352,8 @@ static int s_add_operation_to_subscription_topic_filter_table(

AWS_LOGF_DEBUG(
AWS_LS_MQTT_REQUEST_RESPONSE,
"id=%p: request-response client adding operation %" PRIu64 " to subscription table with topic_filter '" PRInSTR
"'",
"id=%p: request-response client adding streaming operation %" PRIu64
" to streaming subscription table with topic_filter '" PRInSTR "'",
(void *)client,
operation->id,
AWS_BYTE_CURSOR_PRI(topic_filter_cursor));
Expand All @@ -1241,6 +1363,47 @@ static int s_add_operation_to_subscription_topic_filter_table(
return AWS_OP_SUCCESS;
}

static int s_add_request_operation_to_response_path_table(
struct aws_mqtt_request_response_client *client,
struct aws_mqtt_rr_client_operation *operation) {

struct aws_array_list *paths = &operation->storage.request_storage.operation_response_paths;
size_t path_count = aws_array_list_length(paths);
for (size_t i = 0; i < path_count; ++i) {
struct aws_mqtt_request_operation_response_path path;
aws_array_list_get_at(paths, &path, i);

struct aws_hash_element *element = NULL;
if (aws_hash_table_find(&client->request_response_paths, &path.topic, &element)) {
return aws_raise_error(AWS_ERROR_MQTT_REQUEST_RESPONSE_INTERNAL_ERROR);
}

if (element != NULL) {
struct aws_rr_response_path_entry *entry = element->value;
++entry->ref_count;
continue;
}

struct aws_rr_response_path_entry *entry =
s_aws_rr_response_path_entry_new(client->allocator, path.topic, path.correlation_token_json_path);
if (aws_hash_table_put(&client->request_response_paths, &entry->topic_cursor, entry, NULL)) {
return aws_raise_error(AWS_ERROR_MQTT_REQUEST_RESPONSE_INTERNAL_ERROR);
}
}

return AWS_OP_SUCCESS;
}

static int s_add_in_progress_operation_to_tracking_tables(
struct aws_mqtt_request_response_client *client,
struct aws_mqtt_rr_client_operation *operation) {
if (operation->type == AWS_MRROT_STREAMING) {
return s_add_streaming_operation_to_subscription_topic_filter_table(client, operation);
} else {
return s_add_request_operation_to_response_path_table(client, operation);
}
}

static void s_handle_operation_subscribe_result(
struct aws_mqtt_request_response_client *client,
struct aws_mqtt_rr_client_operation *operation,
Expand All @@ -1253,7 +1416,7 @@ static void s_handle_operation_subscribe_result(
return;
}

if (s_add_operation_to_subscription_topic_filter_table(client, operation)) {
if (s_add_in_progress_operation_to_tracking_tables(client, operation)) {
s_request_response_fail_operation(operation, AWS_ERROR_MQTT_REQUEST_RESPONSE_INTERNAL_ERROR);
return;
}
Expand Down Expand Up @@ -1598,8 +1761,6 @@ static void s_mqtt_rr_client_submit_operation(struct aws_task *task, void *arg,
// add appropriate client table entries
aws_hash_table_put(&client->operations, &operation->id, operation, NULL);

// NYI other tables

// add to timeout priority queue
if (operation->type == AWS_MRROT_REQUEST) {
aws_priority_queue_push_ref(
Expand Down Expand Up @@ -1636,6 +1797,49 @@ static void s_aws_mqtt_request_operation_storage_clean_up(struct aws_mqtt_reques
aws_byte_buf_clean_up(&storage->operation_data);
}

static void s_remove_request_operation_from_response_path_table(struct aws_mqtt_rr_client_operation *operation) {
if (operation->type != AWS_MRROT_REQUEST) {
return;
}

struct aws_mqtt_request_response_client *client = operation->client_internal_ref;
struct aws_array_list *paths = &operation->storage.request_storage.operation_response_paths;
size_t path_count = aws_array_list_length(paths);
for (size_t i = 0; i < path_count; ++i) {
struct aws_mqtt_request_operation_response_path path;
aws_array_list_get_at(paths, &path, i);

struct aws_hash_element *element = NULL;
if (aws_hash_table_find(&client->request_response_paths, &path.topic, &element) || element == NULL) {
AWS_LOGF_ERROR(
AWS_LS_MQTT_REQUEST_RESPONSE,
"id=%p: internal state error removing reference to response path for topic " PRInSTR,
(void *)client,
AWS_BYTE_CURSOR_PRI(path.topic));
continue;
}

struct aws_rr_response_path_entry *entry = element->value;
--entry->ref_count;

if (entry->ref_count == 0) {
AWS_LOGF_DEBUG(
AWS_LS_MQTT_REQUEST_RESPONSE,
"id=%p: removing last reference to response path for topic " PRInSTR,
(void *)client,
AWS_BYTE_CURSOR_PRI(path.topic));
aws_hash_table_remove(&client->request_response_paths, &path.topic, NULL, NULL);
} else {
AWS_LOGF_DEBUG(
AWS_LS_MQTT_REQUEST_RESPONSE,
"id=%p: removing reference to response path for topic " PRInSTR ", %zu references remain",
(void *)client,
AWS_BYTE_CURSOR_PRI(path.topic),
entry->ref_count);
}
}
}

static void s_mqtt_rr_client_destroy_operation(struct aws_task *task, void *arg, enum aws_task_status status) {
(void)task;
(void)status;
Expand All @@ -1658,12 +1862,7 @@ static void s_mqtt_rr_client_destroy_operation(struct aws_task *task, void *arg,
aws_rr_subscription_manager_release_subscription(&client->subscription_manager, &release_options);
}

/*
NYI:
Remove from correlation token table
*/
s_remove_request_operation_from_response_path_table(operation);

aws_mqtt_request_response_client_release_internal(operation->client_internal_ref);

Expand Down

0 comments on commit c4313c8

Please sign in to comment.