[Feat] add support for bm25 and tfidf #2567

Open. jperez999 wants to merge 15 commits into base: branch-25.04.
Conversation


@jperez999 jperez999 commented Feb 3, 2025

This PR adds support for TF-IDF and BM25 preprocessing of sparse matrices. It supports encoding values in raft device COO or CSR sparse matrices, and it splits the statistical recording (fit) phase from the transformation phase, which allows fitting data in batches and then transforming a target input. The class also supports exporting/loading, so you can fit in one place and transform in a separate environment. This builds on #2353
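
For orientation, the intended workflow looks roughly like this. This is a hypothetical sketch: the SparseEncoder class, the fit method, and the bm25_on flag are visible in the diff hunks below, but the save/loadSparseEncoder/transform names and signatures are guesses and may not match the final API.

```cpp
raft::resources handle;
raft::sparse::matrix::SparseEncoder<float, int> encoder(/*feats=*/10000);

encoder.fit(handle, batch_coo);        // statistical recording; callable once per batch
encoder.save(handle, "encoder.txt");   // export the fitted state

// ... later, possibly in a different environment ...
auto* loaded =
  raft::sparse::matrix::loadSparseEncoder<float, int>(handle, "encoder.txt");
loaded->transform(handle, target_csr, out_values, /*bm25_on=*/true);
// batch_coo, target_csr, and out_values are placeholders for user-provided data.
```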

@jperez999 jperez999 requested review from a team as code owners February 3, 2025 21:41
@jperez999 jperez999 self-assigned this Feb 3, 2025
@jperez999 jperez999 requested a review from cjnolet February 3, 2025 21:42
@jperez999 jperez999 added the enhancement, feature request, and non-breaking labels Feb 3, 2025
saveFile << fullIdLen << " ";
// serialize_mdspan<IndexType>(handle, oss, featIdCount_md.view());
for (int i = 0; i < vocabSize; i++) {
  saveFile << featIdCount[i] << " ";

jperez999 (Author):
Could not use serialize_mdspan; it does not work well when reading from a file that has other information on it.

jperez999 (Author):
It would probably be better to use a map here to save only the indexes/values that are non-zero. That would save time and space. Thoughts?
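
A minimal sketch of that idea (hypothetical, not part of the PR), reusing the names from the hunk above and writing only the non-zero entries as index/value pairs:

```cpp
#include <map>

// Collect only the non-zero counts, then write the pair count followed by the pairs.
std::map<IndexType, IndexType> nonZero;
for (int i = 0; i < vocabSize; i++) {
  if (featIdCount[i] != 0) { nonZero[i] = featIdCount[i]; }
}
saveFile << nonZero.size() << " ";
for (auto const& [idx, val] : nonZero) {
  saveFile << idx << " " << val << " ";
}
```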

Member:
> Could not use serialize_mdspan; it does not work well when reading from a file that has other information on it.

How did serialize_mdspan not work well? We have code like https://github.com/rapidsai/cuvs/blob/49298b22956fdf4d3966825ae9aa41e1aa94975b/cpp/src/neighbors/detail/dataset_serialize.hpp#L79-L83 that serializes both scalar values and mdspan values.

One of the advantages of serialize_mdspan is that it's compatible with numpy; we're also using it for most other serialization, so it would be nice to be consistent.
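
A minimal sketch of that pattern (assumed API, following the cuvs dataset_serialize.hpp code linked above, with the names from the commented-out line in the hunk):

```cpp
#include <raft/core/serialize.hpp>

// Write the scalar length first, then the counts array; serialize_mdspan writes a
// numpy-compatible header, so the array can also be read back with numpy.
raft::serialize_scalar(handle, oss, fullIdLen);
raft::serialize_mdspan(handle, oss, featIdCount_md.view());
```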

jperez999 (Author):
fixed this, thanks for the tip


raft::sparse::op::coo_sort(
  nnz, nnz, nnz, d_cols.data_handle(), d_rows.data_handle(), d_vals.data_handle(), stream);
IndexType* counts;
cudaMallocManaged(&counts, nnz * sizeof(IndexType));

Member:
Can we use RMM managed_memory_resource instead here?
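
A minimal sketch of that suggestion (RMM API as I understand it; the exact constructor signature varies across RMM versions):

```cpp
#include <rmm/device_uvector.hpp>
#include <rmm/mr/device/managed_memory_resource.hpp>

rmm::mr::managed_memory_resource managed_mr;
// The allocation is freed automatically when `counts` goes out of scope,
// so the explicit cudaFree below would no longer be needed.
rmm::device_uvector<IndexType> counts(nnz, stream, &managed_mr);
```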

* The array holding the feature(column) occurrence counts for all fitted inputs.
* @param[in] counts
* The array representing value changes in rows input.
* @param[in] out_values

Member:
nitpick:

Suggested change:
- * @param[in] out_values
+ * @param[out] out_values

jperez999 (Author):
fixed

template <typename ValueType = float, typename IndexType = int>
class SparseEncoder {
 private:
  int* featIdCount;

Member:
Should this be IndexType instead of int?

Suggested change:
- int* featIdCount;
+ IndexType* featIdCount;

jperez999 (Author):
fixed


* */
template <typename ValueType, typename IndexType>
void SparseEncoder<ValueType, IndexType>::fit(raft::resources& handle,

Member:
nitpick: handle should be const

Suggested change:
- void SparseEncoder<ValueType, IndexType>::fit(raft::resources& handle,
+ void SparseEncoder<ValueType, IndexType>::fit(raft::resources const& handle,

jperez999 (Author):
fixed

void SparseEncoder<ValueType, IndexType>::_fit_feats(IndexType* cols,
                                                     IndexType* counts,
                                                     IndexType nnz,
                                                     IndexType* results)

Member:
This function should accept a raft handle: since it launches CUDA kernels, it needs to use the handle's CUDA stream when launching them.

jperez999 (Author), Feb 25, 2025:
Fixed, but my kernels don't use the handle. I wonder if I am missing something else?
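
For reference, the pattern the review is pointing at looks roughly like this (a sketch; raft::resource::get_cuda_stream is the accessor I'd expect, and the kernel and launch-config names come from the hunks in this thread):

```cpp
#include <raft/core/resource/cuda_stream.hpp>

cudaStream_t stream = raft::resource::get_cuda_stream(handle);
// Launch the kernels on the handle's stream rather than the default stream.
raft::sparse::matrix::detail::_scan<<<num_blocks, blockSize, 0, stream>>>(cols, nnz, counts);
raft::sparse::matrix::detail::_fit_compute_occurs<<<num_blocks, blockSize, 0, stream>>>(
  cols, nnz, counts, results, numFeats);
```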

Comment on lines +232 to +234
raft::sparse::matrix::detail::_scan<<<blockSize, num_blocks>>>(cols, nnz, counts);
raft::sparse::matrix::detail::_fit_compute_occurs<<<blockSize, num_blocks>>>(
  cols, nnz, counts, results, numFeats);

Member:
I think we can probably simplify this.

This seems like it is doing something like np.bincount: counting up the document frequencies of each term so that we can compute the idf for tf-idf/bm25.

I think we can avoid having the _scan/_fit_compute_occurs kernels here entirely (and also remove the requirement for the coo_sort that is performed previously) by using cub. For example, other code inside of raft is using cub and HistogramEven to perform the bincount operation (like

/**
 * @brief function to calculate the bincounts of number of samples in every label
 * @tparam DataT: type of the data samples
 * @tparam LabelT: type of the labels
 * @param labels: the pointer to the array containing labels for every data sample (1 x nRows)
 * @param binCountArray: pointer to the 1D array that contains the count of samples per cluster
 * (1 x nLabels)
 * @param nRows: number of data samples
 * @param nUniqueLabels: number of Labels
 * @param workspace: device buffer containing workspace memory
 * @param stream: the cuda stream where to launch this kernel
 */
template <typename DataT, typename LabelT>
void countLabels(const LabelT* labels,
                 DataT* binCountArray,
                 int nRows,
                 int nUniqueLabels,
                 rmm::device_uvector<char>& workspace,
                 cudaStream_t stream)
{
  int num_levels            = nUniqueLabels + 1;
  LabelT lower_level        = 0;
  LabelT upper_level        = nUniqueLabels;
  size_t temp_storage_bytes = 0;
  rmm::device_uvector<int> countArray(nUniqueLabels, stream);
  RAFT_CUDA_TRY(cub::DeviceHistogram::HistogramEven(nullptr,
                                                    temp_storage_bytes,
                                                    labels,
                                                    binCountArray,
                                                    num_levels,
                                                    lower_level,
                                                    upper_level,
                                                    nRows,
                                                    stream));
  workspace.resize(temp_storage_bytes, stream);
  RAFT_CUDA_TRY(cub::DeviceHistogram::HistogramEven(workspace.data(),
                                                    temp_storage_bytes,
                                                    labels,
                                                    binCountArray,
                                                    num_levels,
                                                    lower_level,
                                                    upper_level,
                                                    nRows,
                                                    stream));
}
etc)

@rhdong rhdong changed the title from "add support for bm25 and tfidf" to "[Feat] add support for bm25 and tfidf" on Feb 24, 2025
TEST_P(SparsePreprocessBigTF, Result) { Run(true, false); }

const std::vector<SparsePreprocessInputs<float, int>> sparse_preprocess_inputs = {
{

Member:
Since "std::vector" is used and CUDA kernels are involved in the main codes, it is strongly recommended to add test cases for multiple settings, especially the corner cases.Since "std::vector" is used and CUDA kernels are involved in the main codes, it is strongly recommended to add test cases for multiple settings, especially the corner cases.

Member:
More cases, better sleep 😜

template <typename ValueType, typename IndexType>
SparseEncoder<ValueType, IndexType>::SparseEncoder(int feats) : numFeats(feats)
{
  cudaMallocManaged(&featIdCount, feats * sizeof(int));

Member:
May need RAFT_CUDA_TRY:

Suggested change:
- cudaMallocManaged(&featIdCount, feats * sizeof(int));
+ RAFT_CUDA_TRY(cudaMallocManaged(&featIdCount, feats * sizeof(int)));

Comment on lines +399 to +401
auto d_rows = raft::make_device_vector<IndexType, int64_t>(handle, nnz);
auto d_cols = raft::make_device_vector<IndexType, int64_t>(handle, nnz);
auto d_vals = raft::make_device_vector<ValueType, int64_t>(handle, nnz);

rhdong (Member), Feb 25, 2025:
If the d_xxx buffers are temporary data, it would be better to use get_workspace_resource as a memory pool for better performance; please refer to https://github.com/rapidsai/raft/blob/branch-25.04/cpp/include/raft/cluster/detail/kmeans_balanced.cuh#L329-L330
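
A minimal sketch of that suggestion (assumed raft API, following the kmeans_balanced example linked above; the exact return type of get_workspace_resource varies by raft version):

```cpp
#include <raft/core/resource/cuda_stream.hpp>
#include <raft/core/resource/device_memory_resource.hpp>
#include <rmm/device_uvector.hpp>

auto stream = raft::resource::get_cuda_stream(handle);
auto mr     = raft::resource::get_workspace_resource(handle);
// Temporaries draw from the handle's workspace pool instead of default device memory.
rmm::device_uvector<IndexType> d_rows(nnz, stream, mr);
rmm::device_uvector<IndexType> d_cols(nnz, stream, mr);
rmm::device_uvector<ValueType> d_vals(nnz, stream, mr);
```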

cudaMemset(counts, 0, nnz * sizeof(IndexType));
_fit_feats(handle, d_cols.data_handle(), counts, nnz, featIdCount);
cudaFree(counts);
cudaDeviceSynchronize();

rhdong (Member), Feb 25, 2025:
cudaFree will sync the device implicitly (before releasing a pointer), so there is no need to call cudaDeviceSynchronize separately. Furthermore, I believe cudaMallocManaged has low performance; maybe you can try pinned host memory via cudaMallocHost when you're pretty sure the device memory is not sufficient.
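
A sketch of the redundancy being pointed out, applied to the hunk above:

```cpp
RAFT_CUDA_TRY(cudaMemset(counts, 0, nnz * sizeof(IndexType)));
_fit_feats(handle, d_cols.data_handle(), counts, nnz, featIdCount);
RAFT_CUDA_TRY(cudaFree(counts));  // cudaFree already synchronizes the device here
// cudaDeviceSynchronize();       // redundant; safe to remove
```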

raft::device_uvector_policy,
raft::PRESERVING> csr_in,
float* results,
bool bm25_on,

rhdong (Member), Feb 25, 2025:
If we add more algorithm options in a future extension, we may need an enum to avoid changing the public API (after release).
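
A hypothetical sketch of that idea (names invented for illustration, not part of the PR):

```cpp
enum class SparseEncoderAlgo { TFIDF, BM25 };

// The `bool bm25_on` parameter could then become, e.g.:
// void transform(..., SparseEncoderAlgo algo = SparseEncoderAlgo::TFIDF);
```

New enumerators can then be added later without breaking existing call sites, whereas a bool would force a signature change.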

columns.data_handle(),
values.data_handle(),
stream);
raft::sparse::matrix::detail::_scan<<<num_blocks, blockSize>>>(rows.data_handle(), nnz, counts);

Member:
Since we have a stream, we should apply it in the kernel launches:

Suggested change:
- raft::sparse::matrix::detail::_scan<<<num_blocks, blockSize>>>(rows.data_handle(), nnz, counts);
+ raft::sparse::matrix::detail::_scan<<<num_blocks, blockSize, 0, stream>>>(rows.data_handle(), nnz, counts);

values.data_handle(),
stream);
raft::sparse::matrix::detail::_scan<<<num_blocks, blockSize>>>(rows.data_handle(), nnz, counts);
raft::sparse::matrix::detail::_transform<<<num_blocks, blockSize>>>(rows.data_handle(),

Member:
Suggested change:
- raft::sparse::matrix::detail::_transform<<<num_blocks, blockSize>>>(rows.data_handle(),
+ raft::sparse::matrix::detail::_transform<<<num_blocks, blockSize, 0, stream>>>(rows.data_handle(),

nnz,
numFeats,
bm25_on);
cudaFree(counts);

rhdong (Member), Feb 25, 2025:
We should be careful with these explicitly and implicitly device-synchronizing APIs, which will break the asynchronous parallel pipeline in the end-user environment and hurt performance. It would be good practice to use a stream and hand the right to synchronize on it (cudaStreamSynchronize(stream)) back to the callers (end users) as much as we can.
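
A sketch of the stream-ordered pattern described above (assumed; it swaps the raw cudaMallocManaged/cudaFree pair for a stream-ordered RMM buffer):

```cpp
#include <raft/core/resource/cuda_stream.hpp>
#include <rmm/device_uvector.hpp>

auto stream = raft::resource::get_cuda_stream(handle);
rmm::device_uvector<IndexType> counts(nnz, stream);  // stream-ordered alloc/free, no device-wide sync
RAFT_CUDA_TRY(cudaMemsetAsync(counts.data(), 0, counts.size() * sizeof(IndexType), stream));
// ... launch kernels on `stream` ...
// Leave it to the caller to decide when (or whether) to cudaStreamSynchronize(stream).
```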

auto featIdCount_h = raft::make_host_vector<IndexType, int64_t>(num_feats);
raft::copy(featIdCount_h.data_handle(), vals.data(), vals.size(), stream);
loadFile.close();
return new SparseEncoder<ValueType, IndexType>(

Member:
Consider using a smart pointer like std::shared_ptr instead of returning a raw pointer from new.
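
A minimal sketch of that suggestion (the load_encoder name is hypothetical; the hunk above suggests the loader currently returns a raw `new` pointer):

```cpp
#include <memory>

template <typename ValueType, typename IndexType>
std::unique_ptr<SparseEncoder<ValueType, IndexType>> load_encoder(
  raft::resources const& handle, std::istream& loadFile)
{
  // ... read num_feats and the feature counts as in the hunk above ...
  return std::make_unique<SparseEncoder<ValueType, IndexType>>(/* ... */);
}
```

Ownership then transfers explicitly to the caller, and the encoder is freed automatically even on early returns or exceptions.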

namespace raft::util {

template <typename T>
void print_vals(raft::resources& handle, const raft::device_vector_view<T, size_t>& out)

Member:
It seems to be unused code, which should be removed.

5,
params.n_rows,
params.n_cols,
67584);

Member:
A fixed random seed will cause the generated test data to always be the same, which may not cover certain special input situations and can thus miss potential bugs.
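
One common compromise (a hypothetical test helper, not part of the PR): draw a fresh seed per run but print it so any failure can be reproduced.

```cpp
#include <cstdint>
#include <iostream>
#include <random>

std::random_device rd;
uint64_t seed = (static_cast<uint64_t>(rd()) << 32) | rd();
std::cout << "sparse_preprocess test seed: " << seed << std::endl;  // log for reproducibility
```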

rhdong (Member) left a comment:
well done

Labels: 3 - Ready for Review, CMake, cpp, enhancement, feature request, non-breaking
Projects: Status: In Progress
4 participants