diff --git a/src/secret.c b/src/secret.c index e7a119c..cb3264e 100644 --- a/src/secret.c +++ b/src/secret.c @@ -203,40 +203,7 @@ static void mbedtls_sha3_finish(mbedtls_sha3_context *ctx, } -// secretbase - internal auxiliary functions ----------------------------------- - -static int validate_bitlength(const SEXP bits) { - - const int size = Rf_asInteger(bits); - if (size < 8 || size > (1 << 24)) - Rf_error("'bits' must be between 8 and 2^24"); - return size; - -} - -static mbedtls_sha3_id id_from_size(size_t size) { - - mbedtls_sha3_id id; - switch (size) { - case 224: - id = MBEDTLS_SHA3_224; - break; - case 256: - id = MBEDTLS_SHA3_256; - break; - case 384: - id = MBEDTLS_SHA3_384; - break; - case 512: - id = MBEDTLS_SHA3_512; - break; - default: - id = MBEDTLS_SHA3_SHAKE256; - break; - } - return id; - -} +// secretbase - internals ------------------------------------------------------ static void hash_bytes(R_outpstream_t stream, void *src, int len) { @@ -266,105 +233,121 @@ static SEXP hash_to_char(unsigned char *buf, const size_t sz) { } -static SEXP create_object(const int type, unsigned char *buf, const size_t sz) { - - SEXP out; - switch (type) { - case 0: - out = Rf_allocVector(RAWSXP, sz); - memcpy(STDVEC_DATAPTR(out), buf, sz); - break; - case 1: - out = hash_to_char(buf, sz); - break; - default: - out = Rf_allocVector(INTSXP, sz / sizeof(int)); - memcpy(STDVEC_DATAPTR(out), buf, sz); - break; - } - return out; - -} - -// secretbase - exported functions --------------------------------------------- - -SEXP secretbase_sha3(SEXP x, SEXP bits, SEXP convert) { +static SEXP secretbase_sha3_impl(const SEXP x, const SEXP bits, + const SEXP convert, const int file) { const int conv = LOGICAL(convert)[0]; - const int size = validate_bitlength(bits); + const int size = Rf_asInteger(bits); + if (size < 8 || size > (1 << 24)) + Rf_error("'bits' must be between 8 and 2^24"); const size_t outlen = (size_t) (size / 8); unsigned char output[outlen]; + SEXP out; - mbedtls_sha3_id id = id_from_size(size); mbedtls_sha3_context ctx; + mbedtls_sha3_id id; + switch (size) { + case 224: + id = MBEDTLS_SHA3_224; + break; + case 256: + id = MBEDTLS_SHA3_256; + break; + case 384: + id = MBEDTLS_SHA3_384; + break; + case 512: + id = MBEDTLS_SHA3_512; + break; + default: + id = MBEDTLS_SHA3_SHAKE256; + break; + } mbedtls_sha3_init(&ctx); mbedtls_sha3_starts(&ctx, id); - switch (TYPEOF(x)) { - case STRSXP: - if (XLENGTH(x) == 1 && ATTRIB(x) == R_NilValue) { - const char *s = CHAR(STRING_ELT(x, 0)); - mbedtls_sha3_update(&ctx, (const uint8_t *) s, strlen(s)); - goto finish; + if (file) { + + const char *filepath = R_ExpandFileName(CHAR(STRING_ELT(x, 0))); + unsigned char buf[SB_BUF_SIZE]; + size_t cur; + + FILE *fp = fopen(filepath, "rb"); + if (fp == NULL) + Rf_error("file not found or no read permission"); + while ((cur = fread(buf, 1, sizeof(buf), fp))) { + mbedtls_sha3_update(&ctx, buf, cur); } - break; - case RAWSXP: - if (ATTRIB(x) == R_NilValue) { - mbedtls_sha3_update(&ctx, (const uint8_t *) STDVEC_DATAPTR(x), (size_t) XLENGTH(x)); - goto finish; + fclose(fp); + + } else { + + switch (TYPEOF(x)) { + case STRSXP: + if (XLENGTH(x) == 1 && ATTRIB(x) == R_NilValue) { + const char *s = CHAR(STRING_ELT(x, 0)); + mbedtls_sha3_update(&ctx, (const uint8_t *) s, strlen(s)); + goto finish; + } + break; + case RAWSXP: + if (ATTRIB(x) == R_NilValue) { + mbedtls_sha3_update(&ctx, (const uint8_t *) STDVEC_DATAPTR(x), (size_t) XLENGTH(x)); + goto finish; + } + break; } - break; + + secretbase_context sctx; + sctx.ctx = &ctx; + sctx.skip = SB_SERIAL_HEADERS; + + struct R_outpstream_st output_stream; + R_InitOutPStream( + &output_stream, + (R_pstream_data_t) &sctx, + R_pstream_xdr_format, + SB_R_SERIAL_VER, + NULL, + hash_bytes, + NULL, + R_NilValue + ); + R_Serialize(x, &output_stream); + } - secretbase_context sctx; - sctx.ctx = &ctx; - sctx.skip = SB_SERIAL_HEADERS; - - struct R_outpstream_st output_stream; - R_InitOutPStream( - &output_stream, - (R_pstream_data_t) &sctx, - R_pstream_xdr_format, - SB_R_SERIAL_VER, - NULL, - hash_bytes, - NULL, - R_NilValue - ); - R_Serialize(x, &output_stream); - finish: mbedtls_sha3_finish(&ctx, output, outlen); - return create_object(conv, output, outlen); + switch (conv) { + case 0: + out = Rf_allocVector(RAWSXP, outlen); + memcpy(STDVEC_DATAPTR(out), output, outlen); + break; + case 1: + out = hash_to_char(output, outlen); + break; + default: + out = Rf_allocVector(INTSXP, outlen / sizeof(int)); + memcpy(STDVEC_DATAPTR(out), output, outlen); + break; + } + + return out; } -SEXP secretbase_sha3_file(SEXP x, SEXP bits, SEXP convert) { - - const char *file = R_ExpandFileName(CHAR(STRING_ELT(x, 0))); - const int conv = LOGICAL(convert)[0]; - const int size = validate_bitlength(bits); - const size_t outlen = (size_t) (size / 8); - unsigned char output[outlen]; - unsigned char buf[SB_BUF_SIZE]; - size_t cur; +// secretbase - exported functions --------------------------------------------- + +SEXP secretbase_sha3(SEXP x, SEXP bits, SEXP convert) { - mbedtls_sha3_id id = id_from_size(size); - mbedtls_sha3_context ctx; - mbedtls_sha3_init(&ctx); - mbedtls_sha3_starts(&ctx, id); + return secretbase_sha3_impl(x, bits, convert, 0); - FILE *fp = fopen(file, "rb"); - if (fp == NULL) - Rf_error("file not found or accessible"); - while ((cur = fread(buf, 1, sizeof(buf), fp))) { - mbedtls_sha3_update(&ctx, buf, cur); - } - fclose(fp); - - mbedtls_sha3_finish(&ctx, output, outlen); +} + +SEXP secretbase_sha3_file(SEXP x, SEXP bits, SEXP convert) { - return create_object(conv, output, outlen); + return secretbase_sha3_impl(x, bits, convert, 1); } diff --git a/src/secret.h b/src/secret.h index ccf0633..08a1daa 100644 --- a/src/secret.h +++ b/src/secret.h @@ -59,8 +59,8 @@ typedef struct mbedtls_sha3_context { } mbedtls_sha3_context; typedef struct secretbase_context_s { - mbedtls_sha3_context *ctx; int skip; + mbedtls_sha3_context *ctx; } secretbase_context; SEXP secretbase_sha3(SEXP, SEXP, SEXP); diff --git a/tests/tests.R b/tests/tests.R index f5c14c6..7d95e1a 100644 --- a/tests/tests.R +++ b/tests/tests.R @@ -37,4 +37,4 @@ hash_func <- function(file, string) { sha3file(file) } test_equal(hash_func(tempfile(), "secret base"), "a721d57570e7ce366adee2fccbe9770723c6e3622549c31c7cab9dbb4a795520") -test_error(hash_func("", ""), "file not found or accessible") +test_error(hash_func("", ""), "file not found or no read permission")