From b33715ccec92f1a9fb7318534a589e2ff20cb108 Mon Sep 17 00:00:00 2001 From: Adria Puigdomenech Date: Tue, 19 Apr 2016 11:10:01 +0100 Subject: [PATCH] Make torch.Generator serializable --- Generator.c | 29 +++++++++++++++++++++++++---- test/test.lua | 14 ++++++++++++++ 2 files changed, 39 insertions(+), 4 deletions(-) diff --git a/Generator.c b/Generator.c index 06ec6d00..8cf5ba66 100644 --- a/Generator.c +++ b/Generator.c @@ -1,9 +1,5 @@ #include -static const struct luaL_Reg torch_Generator_table_ [] = { - {NULL, NULL} -}; - int torch_Generator_new(lua_State *L) { THGenerator *gen = THGenerator_new(); @@ -18,6 +14,31 @@ int torch_Generator_free(lua_State *L) return 0; } +static int torch_Generator_write(lua_State *L) +{ + THGenerator *gen = luaT_checkudata(L, 1, torch_Generator); + THFile *file = luaT_checkudata(L, 2, "torch.File"); + + THFile_writeByteRaw(file, (unsigned char *)gen, sizeof(THGenerator)); + return 0; +} + +static int torch_Generator_read(lua_State *L) +{ + THGenerator *gen = luaT_checkudata(L, 1, torch_Generator); + THFile *file = luaT_checkudata(L, 2, "torch.File"); + + THFile_readByteRaw(file, (unsigned char *)gen, sizeof(THGenerator)); + return 0; +} + + +static const struct luaL_Reg torch_Generator_table_ [] = { + {"write", torch_Generator_write}, + {"read", torch_Generator_read}, + {NULL, NULL} +}; + #define torch_Generator_factory torch_Generator_new void torch_Generator_init(lua_State *L) diff --git a/test/test.lua b/test/test.lua index fe197c0d..20ca0355 100644 --- a/test/test.lua +++ b/test/test.lua @@ -2483,6 +2483,20 @@ function torchtest.RNGStateAliasing() mytester:assertTensorEq(target_value, forked_value, 1e-16, "RNG has not forked correctly.") end +function torchtest.serializeGenerator() + local generator = torch.Generator() + torch.manualSeed(generator, 123) + local differentGenerator = torch.Generator() + torch.manualSeed(differentGenerator, 124) + local serializedGenerator = torch.serialize(generator) + local deserializedGenerator = torch.deserialize(serializedGenerator) + local generated = torch.random(generator) + local differentGenerated = torch.random(differentGenerator) + local deserializedGenerated = torch.random(deserializedGenerator) + mytester:asserteq(generated, deserializedGenerated, 'torch.Generator changed internal state after being serialized') + mytester:assertne(generated, differentGenerated, 'Generators with different random seed should not produce the same output') +end + function torchtest.testBoxMullerState() torch.manualSeed(123) local odd_number = 101