Skip to content

Commit

Permalink
Refactor parquet::encryption::AesEncryptor
Browse files Browse the repository at this point in the history
  • Loading branch information
mapleFU committed Jul 11, 2024
1 parent c777ac8 commit 03315fc
Show file tree
Hide file tree
Showing 6 changed files with 23 additions and 29 deletions.
15 changes: 6 additions & 9 deletions cpp/src/parquet/encryption/encryption_internal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -420,23 +420,20 @@ AesDecryptor::AesDecryptorImpl::AesDecryptorImpl(ParquetCipher::type alg_id, int
}
}

AesEncryptor* AesEncryptor::Make(ParquetCipher::type alg_id, int key_len, bool metadata,
std::vector<AesEncryptor*>* all_encryptors) {
return Make(alg_id, key_len, metadata, true /*write_length*/, all_encryptors);
std::unique_ptr<AesEncryptor> AesEncryptor::Make(ParquetCipher::type alg_id, int key_len,
bool metadata) {
return Make(alg_id, key_len, metadata, true /*write_length*/);
}

AesEncryptor* AesEncryptor::Make(ParquetCipher::type alg_id, int key_len, bool metadata,
bool write_length,
std::vector<AesEncryptor*>* all_encryptors) {
std::unique_ptr<AesEncryptor> AesEncryptor::Make(ParquetCipher::type alg_id, int key_len,
bool metadata, bool write_length) {
if (ParquetCipher::AES_GCM_V1 != alg_id && ParquetCipher::AES_GCM_CTR_V1 != alg_id) {
std::stringstream ss;
ss << "Crypto algorithm " << alg_id << " is not supported";
throw ParquetException(ss.str());
}

AesEncryptor* encryptor = new AesEncryptor(alg_id, key_len, metadata, write_length);
if (all_encryptors != nullptr) all_encryptors->push_back(encryptor);
return encryptor;
return std::make_unique<AesEncryptor>(alg_id, key_len, metadata, write_length);
}

AesDecryptor::AesDecryptor(ParquetCipher::type alg_id, int key_len, bool metadata,
Expand Down
9 changes: 4 additions & 5 deletions cpp/src/parquet/encryption/encryption_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,12 +52,11 @@ class PARQUET_EXPORT AesEncryptor {
explicit AesEncryptor(ParquetCipher::type alg_id, int key_len, bool metadata,
bool write_length = true);

static AesEncryptor* Make(ParquetCipher::type alg_id, int key_len, bool metadata,
std::vector<AesEncryptor*>* all_encryptors);
static std::unique_ptr<AesEncryptor> Make(ParquetCipher::type alg_id, int key_len,
bool metadata);

static AesEncryptor* Make(ParquetCipher::type alg_id, int key_len, bool metadata,
bool write_length,
std::vector<AesEncryptor*>* all_encryptors);
static std::unique_ptr<AesEncryptor> Make(ParquetCipher::type alg_id, int key_len,
bool metadata, bool write_length);

~AesEncryptor();

Expand Down
9 changes: 4 additions & 5 deletions cpp/src/parquet/encryption/encryption_internal_nossl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -70,14 +70,13 @@ void AesDecryptor::WipeOut() { ThrowOpenSSLRequiredException(); }

AesDecryptor::~AesDecryptor() {}

AesEncryptor* AesEncryptor::Make(ParquetCipher::type alg_id, int key_len, bool metadata,
std::vector<AesEncryptor*>* all_encryptors) {
std::unique_ptr<AesEncryptor> AesEncryptor::Make(ParquetCipher::type alg_id, int key_len,
bool metadata) {
return NULLPTR;
}

AesEncryptor* AesEncryptor::Make(ParquetCipher::type alg_id, int key_len, bool metadata,
bool write_length,
std::vector<AesEncryptor*>* all_encryptors) {
std::unique_ptr<AesEncryptor> AesEncryptor::Make(ParquetCipher::type alg_id, int key_len,
bool metadata, bool write_length) {
return NULLPTR;
}

Expand Down
10 changes: 5 additions & 5 deletions cpp/src/parquet/encryption/internal_file_encryptor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ InternalFileEncryptor::InternalFileEncryptor::GetColumnEncryptor(
return encryptor;
}

int InternalFileEncryptor::MapKeyLenToEncryptorArrayIndex(int key_len) {
int InternalFileEncryptor::MapKeyLenToEncryptorArrayIndex(int key_len) const {
if (key_len == 16)
return 0;
else if (key_len == 24)
Expand All @@ -150,8 +150,8 @@ encryption::AesEncryptor* InternalFileEncryptor::GetMetaAesEncryptor(
int key_len = static_cast<int>(key_size);
int index = MapKeyLenToEncryptorArrayIndex(key_len);
if (meta_encryptor_[index] == nullptr) {
meta_encryptor_[index].reset(
encryption::AesEncryptor::Make(algorithm, key_len, true, &all_encryptors_));
meta_encryptor_[index] = encryption::AesEncryptor::Make(algorithm, key_len, true);
all_encryptors_.push_back(meta_encryptor_[index].get());
}
return meta_encryptor_[index].get();
}
Expand All @@ -161,8 +161,8 @@ encryption::AesEncryptor* InternalFileEncryptor::GetDataAesEncryptor(
int key_len = static_cast<int>(key_size);
int index = MapKeyLenToEncryptorArrayIndex(key_len);
if (data_encryptor_[index] == nullptr) {
data_encryptor_[index].reset(
encryption::AesEncryptor::Make(algorithm, key_len, false, &all_encryptors_));
data_encryptor_[index] = encryption::AesEncryptor::Make(algorithm, key_len, false);
all_encryptors_.push_back(data_encryptor_[index].get());
}
return data_encryptor_[index].get();
}
Expand Down
2 changes: 1 addition & 1 deletion cpp/src/parquet/encryption/internal_file_encryptor.h
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ class InternalFileEncryptor {
encryption::AesEncryptor* GetDataAesEncryptor(ParquetCipher::type algorithm,
size_t key_len);

int MapKeyLenToEncryptorArrayIndex(int key_len);
int MapKeyLenToEncryptorArrayIndex(int key_len) const;
};

} // namespace parquet
7 changes: 3 additions & 4 deletions cpp/src/parquet/metadata.cc
Original file line number Diff line number Diff line change
Expand Up @@ -649,9 +649,9 @@ class FileMetaData::FileMetaDataImpl {
std::string key = file_decryptor_->GetFooterKey();
std::string aad = encryption::CreateFooterAad(file_decryptor_->file_aad());

auto aes_encryptor = encryption::AesEncryptor::Make(
file_decryptor_->algorithm(), static_cast<int>(key.size()), true,
false /*write_length*/, nullptr);
auto aes_encryptor = encryption::AesEncryptor::Make(file_decryptor_->algorithm(),
static_cast<int>(key.size()),
true, false /*write_length*/);

std::shared_ptr<Buffer> encrypted_buffer = std::static_pointer_cast<ResizableBuffer>(
AllocateBuffer(file_decryptor_->pool(),
Expand All @@ -662,7 +662,6 @@ class FileMetaData::FileMetaDataImpl {
encrypted_buffer->mutable_data());
// Delete AES encryptor object. It was created only to verify the footer signature.
aes_encryptor->WipeOut();
delete aes_encryptor;
return 0 ==
memcmp(encrypted_buffer->data() + encrypted_len - encryption::kGcmTagLength,
tag, encryption::kGcmTagLength);
Expand Down

0 comments on commit 03315fc

Please sign in to comment.