Skip to content

Commit

Permalink
fix serialization and deserialization
Browse files Browse the repository at this point in the history
  • Loading branch information
jperez999 committed Feb 25, 2025
1 parent 69815df commit 950e564
Showing 1 changed file with 13 additions and 24 deletions.
37 changes: 13 additions & 24 deletions cpp/include/raft/sparse/matrix/preprocessing.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include <raft/core/device_mdspan.hpp>
#include <raft/core/resource/cuda_stream.hpp>
#include <raft/core/resources.hpp>
#include <raft/core/serialize.hpp>
#include <raft/sparse/matrix/detail/preprocessing.cuh>

#include <map>
Expand Down Expand Up @@ -214,16 +215,10 @@ void SparseEncoder<ValueType, IndexType>::save(raft::resources const& handle, st
auto featIdCount_md = raft::make_device_vector<IndexType, int64_t>(handle, numFeats);
raft::copy(featIdCount_md.data_handle(), featIdCount, numFeats, stream);
std::ofstream saveFile(save_path);
if (saveFile.is_open()) {
std::ostringstream oss;
saveFile << numFeats << " ";
saveFile << numRows << " ";
saveFile << fullIdLen << " ";
for (int i = 0; i < numFeats; i++) {
saveFile << featIdCount[i] << " ";
}
saveFile.close();
}
raft::serialize_scalar<IndexType>(handle, saveFile, numFeats);
raft::serialize_scalar<IndexType>(handle, saveFile, numRows);
raft::serialize_scalar<IndexType>(handle, saveFile, fullIdLen);
raft::serialize_mdspan<IndexType>(handle, saveFile, featIdCount_md.view());
}

template <typename ValueType, typename IndexType>
Expand Down Expand Up @@ -547,22 +542,16 @@ template <typename ValueType, typename IndexType>
SparseEncoder<ValueType, IndexType>* loadSparseEncoder(raft::resources const& handle,
std::string save_path)
{
cudaStream_t stream = raft::resource::get_cuda_stream(handle);
std::ifstream loadFile(save_path, std::ios_base::in);
IndexType num_feats, num_rows, fullIdLen;
loadFile >> num_feats;
loadFile >> num_rows;
loadFile >> fullIdLen;
IndexType val;
std::vector<IndexType> vals;
while (loadFile >> val) {
vals.push_back(val);
}
auto featIdCount_h = raft::make_host_vector<IndexType, int64_t>(num_feats);
raft::copy(featIdCount_h.data_handle(), vals.data(), vals.size(), stream);
loadFile.close();
auto num_feats = deserialize_scalar<IndexType>(handle, loadFile);
auto num_rows = deserialize_scalar<IndexType>(handle, loadFile);
auto fullIdLen = deserialize_scalar<IndexType>(handle, loadFile);

auto featIdCount_md = raft::make_host_vector<IndexType, int64_t>(num_feats);
deserialize_mdspan<IndexType>(handle, loadFile, featIdCount_md.view());

return new SparseEncoder<ValueType, IndexType>(
handle, featIdCount_h.data_handle(), num_feats, num_rows, fullIdLen);
handle, featIdCount_md.data_handle(), num_feats, num_rows, fullIdLen);
}

} // namespace raft::sparse::matrix

0 comments on commit 950e564

Please sign in to comment.