From f820b67843fd04a9dd548cd940381845474208b1 Mon Sep 17 00:00:00 2001 From: Paris Morgan Date: Mon, 2 Sep 2024 12:41:15 +0200 Subject: [PATCH] Various tdb matrix cleanups --- src/include/detail/linalg/tdb_matrix.h | 14 +++----------- src/include/detail/linalg/tdb_matrix_with_ids.h | 6 +++--- src/include/detail/linalg/tdb_partitioned_matrix.h | 9 +++------ 3 files changed, 9 insertions(+), 20 deletions(-) diff --git a/src/include/detail/linalg/tdb_matrix.h b/src/include/detail/linalg/tdb_matrix.h index bcbe55943..a3bdc2619 100644 --- a/src/include/detail/linalg/tdb_matrix.h +++ b/src/include/detail/linalg/tdb_matrix.h @@ -92,13 +92,12 @@ class tdbBlockedMatrix : public MatrixBase { index_type last_row_; index_type first_col_; index_type last_col_; + + // The columns loaded into memory. Except for the last (remainder) block, + // this range will be equal to `load_blocksize_`. index_type first_resident_col_; index_type last_resident_col_; - // The number of columns loaded into memory. Except for the last (remainder) - // block, this will be equal to `blocksize_`. - index_type num_resident_cols_{0}; - // How many columns to load at a time index_type load_blocksize_{0}; @@ -234,21 +233,14 @@ class tdbBlockedMatrix : public MatrixBase { auto cell_order = schema_.cell_order(); auto tile_order = schema_.tile_order(); - if ((matrix_order_ == TILEDB_ROW_MAJOR && cell_order == TILEDB_COL_MAJOR) || (matrix_order_ == TILEDB_COL_MAJOR && cell_order == TILEDB_ROW_MAJOR)) { throw std::runtime_error("Cell order and matrix order must match"); } - if (cell_order != tile_order) { throw std::runtime_error("Cell order and tile order must match"); } - auto domain_{schema_.domain()}; - - auto row_domain{domain_.dimension(0)}; - auto col_domain{domain_.dimension(1)}; - // If non_empty_domain() is an empty vector it means that // the array is empty. Else If the user specifies a value then we use it, // otherwise we use the non-empty domain. diff --git a/src/include/detail/linalg/tdb_matrix_with_ids.h b/src/include/detail/linalg/tdb_matrix_with_ids.h index 40273a34c..06f09fedd 100644 --- a/src/include/detail/linalg/tdb_matrix_with_ids.h +++ b/src/include/detail/linalg/tdb_matrix_with_ids.h @@ -85,8 +85,8 @@ class tdbBlockedMatrixWithIds /** * @brief Construct a new tdbBlockedMatrixWithIds object, limited to - * `upper_bound` vectors. In this case, the `Matrix` is row-major, so the - * number of vectors is the number of rows. + * `upper_bound` vectors. In this case, the `Matrix` is column-major, so the + * number of vectors is the number of columns. * * @param ctx The TileDB context to use. * @param uri URI of the TileDB array to read. @@ -142,7 +142,7 @@ class tdbBlockedMatrixWithIds size_t first_col, std::optional last_col, size_t upper_bound, - TemporalPolicy temporal_policy) // noexcept + TemporalPolicy temporal_policy) requires(std::is_same_v) : Base( ctx, diff --git a/src/include/detail/linalg/tdb_partitioned_matrix.h b/src/include/detail/linalg/tdb_partitioned_matrix.h index cb0eba8b5..e9e9d9c3e 100644 --- a/src/include/detail/linalg/tdb_partitioned_matrix.h +++ b/src/include/detail/linalg/tdb_partitioned_matrix.h @@ -318,16 +318,13 @@ class tdbPartitionedMatrix auto cell_order = partitioned_vectors_schema_.cell_order(); auto tile_order = partitioned_vectors_schema_.tile_order(); - if (cell_order != tile_order) { throw std::runtime_error("Cell order and tile order must match"); } auto domain_{partitioned_vectors_schema_.domain()}; - auto array_rows{domain_.dimension(0)}; auto array_cols{domain_.dimension(1)}; - dimensions_ = (array_rows.template domain().second - array_rows.template domain().first + 1); @@ -622,10 +619,10 @@ class tdbPartitionedMatrix } void debug_tdb_partitioned_matrix( - const std::string& msg, size_t max_size = 10) { + const std::string& msg, size_t max_size = 10) const { debug_partitioned_matrix(*this, msg, max_size); - debug_vector(master_indices_, "# master_indices_", max_size); - debug_vector(relevant_parts_, "# relevant_parts_", max_size); + debug_vector(master_indices_, "# master_indices_ ", max_size); + debug_vector(relevant_parts_, "# relevant_parts_ ", max_size); debug_vector(squashed_indices_, "# squashed_indices_", max_size); std::cout << "# total_num_parts_: " << total_num_parts_ << std::endl; std::cout << "# last_resident_part_: " << last_resident_part_ << std::endl;