Skip to content

Commit

Permalink
Make torch.Generator serializable
Browse files Browse the repository at this point in the history
  • Loading branch information
adria-p committed Apr 19, 2016
1 parent 1812606 commit b33715c
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 4 deletions.
29 changes: 25 additions & 4 deletions Generator.c
Original file line number Diff line number Diff line change
@@ -1,9 +1,5 @@
#include <general.h>

static const struct luaL_Reg torch_Generator_table_ [] = {
{NULL, NULL}
};

int torch_Generator_new(lua_State *L)
{
THGenerator *gen = THGenerator_new();
Expand All @@ -18,6 +14,31 @@ int torch_Generator_free(lua_State *L)
return 0;
}

static int torch_Generator_write(lua_State *L)
{
THGenerator *gen = luaT_checkudata(L, 1, torch_Generator);
THFile *file = luaT_checkudata(L, 2, "torch.File");

THFile_writeByteRaw(file, (unsigned char *)gen, sizeof(THGenerator));
return 0;
}

static int torch_Generator_read(lua_State *L)
{
THGenerator *gen = luaT_checkudata(L, 1, torch_Generator);
THFile *file = luaT_checkudata(L, 2, "torch.File");

THFile_readByteRaw(file, (unsigned char *)gen, sizeof(THGenerator));
return 0;
}


static const struct luaL_Reg torch_Generator_table_ [] = {
{"write", torch_Generator_write},
{"read", torch_Generator_read},
{NULL, NULL}
};

#define torch_Generator_factory torch_Generator_new

void torch_Generator_init(lua_State *L)
Expand Down
14 changes: 14 additions & 0 deletions test/test.lua
Original file line number Diff line number Diff line change
Expand Up @@ -2483,6 +2483,20 @@ function torchtest.RNGStateAliasing()
mytester:assertTensorEq(target_value, forked_value, 1e-16, "RNG has not forked correctly.")
end

function torchtest.serializeGenerator()
local generator = torch.Generator()
torch.manualSeed(generator, 123)
local differentGenerator = torch.Generator()
torch.manualSeed(differentGenerator, 124)
local serializedGenerator = torch.serialize(generator)
local deserializedGenerator = torch.deserialize(serializedGenerator)
local generated = torch.random(generator)
local differentGenerated = torch.random(differentGenerator)
local deserializedGenerated = torch.random(deserializedGenerator)
mytester:asserteq(generated, deserializedGenerated, 'torch.Generator changed internal state after being serialized')
mytester:assertne(generated, differentGenerated, 'Generators with different random seed should not produce the same output')
end

function torchtest.testBoxMullerState()
torch.manualSeed(123)
local odd_number = 101
Expand Down

0 comments on commit b33715c

Please sign in to comment.