Add support for torch.HalfTensor (#874)
* Add support for torch.HalfTensor.

* Improvements/Simplifications for torch.HalfTensor.

Improvements/Simplifications:
1) Defines the half type as TH_Half so it does not conflict with the
cutorch version. Previously, both were defined as the same "half" type,
which required careful ordering of includes to ensure the type was
defined only once and would have affected all downstream projects.
2) No longer generates math functions that are not actually defined
on torch.HalfTensor, e.g. maskedFill and map.
3) Adds tests for all available torch.HalfTensor functions.
4) Allows compiling without TH_GENERIC_USE_HALF (so if a problem turns
up, the definition can simply be unset in CMakeLists.txt rather than
backing out the change).
5) Some simplifications: removes a new copy optimization and some
TH_HALF literal definitions.

Limitations:
Because math functions are not defined, some "non-math" operators on
torch.HalfTensor give an error message, e.g. __index__/__newindex__
with a ByteTensor apply a mask, but masked operations aren't
implemented. These limitations aren't always obvious (e.g. for
documentation purposes), but they always produce an error message.
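
A minimal sketch of what this looks like in practice (the pcall pattern
and error behavior are illustrative, not verbatim output):

  local h = torch.FloatTensor(4):half()       -- convert via the new :half()
  local mask = torch.ByteTensor(4):fill(1)
  local ok, err = pcall(function() return h[mask] end)  -- masked __index__
  print(ok)  -- false: masked selection isn't implemented for half, so it errors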

* Rename TH_HALF to THHalf.
gchanan authored and soumith committed Dec 29, 2016
1 parent 7ca7ec9 commit a0c0b78
Showing 40 changed files with 872 additions and 197 deletions.
2 changes: 2 additions & 0 deletions CMakeLists.txt
@@ -25,6 +25,8 @@ IF(MSVC)
ADD_DEFINITIONS(-D_CRT_SECURE_NO_DEPRECATE=1)
ENDIF(MSVC)

ADD_DEFINITIONS(-DTH_GENERIC_USE_HALF=1)

# OpenMP support?
SET(WITH_OPENMP ON CACHE BOOL "OpenMP support if available?")
IF (APPLE AND CMAKE_COMPILER_IS_GNUCC)
145 changes: 79 additions & 66 deletions FFI.lua
@@ -15,14 +15,16 @@ local function checkArgumentType(expected, actual, fn, ud, level)
end

if ok then

local Real2real = {
Byte='unsigned char',
Char='char',
Short='short',
Int='int',
Long='long',
Float='float',
Double='double'
Double='double',
Half='THHalf'
}

-- Allocator
@@ -32,6 +34,14 @@ typedef struct THAllocator {
void* (*realloc)(void*, void*, ptrdiff_t);
void (*free)(void*, void*);
} THAllocator;
]]

-- Half
ffi.cdef[[
typedef struct {
unsigned short x;
} __THHalf;
typedef __THHalf THHalf;
]]

-- Storage
@@ -76,7 +86,7 @@ typedef struct THRealTensor
long *size;
long *stride;
int nDimension;
THRealStorage *storage;
ptrdiff_t storageOffset;
int refcount;
@@ -88,7 +98,8 @@ typedef struct THRealTensor
cdefs = cdefs:gsub('Real', Real):gsub('real', real)
ffi.cdef(cdefs)

local Tensor = torch.getmetatable(string.format('torch.%sTensor', Real))
local Tensor_type = string.format('torch.%sTensor', Real)
local Tensor = torch.getmetatable(Tensor_type)
local Tensor_tt = ffi.typeof('TH' .. Real .. 'Tensor**')

rawset(Tensor,
@@ -107,75 +118,77 @@ typedef struct THRealTensor
end)

-- faster apply (contiguous case)
local apply = Tensor.apply
rawset(Tensor,
"apply",
function(self, func)
if self:isContiguous() and self.data then
local self_d = self:data()
for i=0,self:nElement()-1 do
local res = func(tonumber(self_d[i])) -- tonumber() required for long...
if res then
self_d[i] = res
if Tensor_type ~= 'torch.HalfTensor' or torch.hashalfmath() then
local apply = Tensor.apply
rawset(Tensor,
"apply",
function(self, func)
if self:isContiguous() and self.data then
local self_d = self:data()
for i=0,self:nElement()-1 do
local res = func(tonumber(self_d[i])) -- tonumber() required for long...
if res then
self_d[i] = res
end
end
return self
else
return apply(self, func)
end
return self
else
return apply(self, func)
end
end)

-- faster map (contiguous case)
local map = Tensor.map
rawset(Tensor,
"map",
function(self, src, func)
checkArgument(torch.isTensor(src), "map", 1, "tensor expected")
checkArgumentType(self:type(), src:type(), "map", 1)

if self:isContiguous() and src:isContiguous() and self.data and src.data then
local self_d = self:data()
local src_d = src:data()
assert(src:nElement() == self:nElement(), 'size mismatch')
for i=0,self:nElement()-1 do
local res = func(tonumber(self_d[i]), tonumber(src_d[i])) -- tonumber() required for long...
if res then
self_d[i] = res
end)

-- faster map (contiguous case)
local map = Tensor.map
rawset(Tensor,
"map",
function(self, src, func)
checkArgument(torch.isTensor(src), "map", 1, "tensor expected")
checkArgumentType(self:type(), src:type(), "map", 1)

if self:isContiguous() and src:isContiguous() and self.data and src.data then
local self_d = self:data()
local src_d = src:data()
assert(src:nElement() == self:nElement(), 'size mismatch')
for i=0,self:nElement()-1 do
local res = func(tonumber(self_d[i]), tonumber(src_d[i])) -- tonumber() required for long...
if res then
self_d[i] = res
end
end
return self
else
return map(self, src, func)
end
return self
else
return map(self, src, func)
end
end)

-- faster map2 (contiguous case)
local map2 = Tensor.map2
rawset(Tensor,
"map2",
function(self, src1, src2, func)
checkArgument(torch.isTensor(src1), "map", 1, "tensor expected")
checkArgument(torch.isTensor(src2), "map", 2, "tensor expected")
checkArgumentType(self:type(), src1:type(), "map", 1)
checkArgumentType(self:type(), src2:type(), "map", 2)

if self:isContiguous() and src1:isContiguous() and src2:isContiguous() and self.data and src1.data and src2.data then
local self_d = self:data()
local src1_d = src1:data()
local src2_d = src2:data()
assert(src1:nElement() == self:nElement(), 'size mismatch')
assert(src2:nElement() == self:nElement(), 'size mismatch')
for i=0,self:nElement()-1 do
local res = func(tonumber(self_d[i]), tonumber(src1_d[i]), tonumber(src2_d[i])) -- tonumber() required for long...
if res then
self_d[i] = res
end)

-- faster map2 (contiguous case)
local map2 = Tensor.map2
rawset(Tensor,
"map2",
function(self, src1, src2, func)
checkArgument(torch.isTensor(src1), "map", 1, "tensor expected")
checkArgument(torch.isTensor(src2), "map", 2, "tensor expected")
checkArgumentType(self:type(), src1:type(), "map", 1)
checkArgumentType(self:type(), src2:type(), "map", 2)

if self:isContiguous() and src1:isContiguous() and src2:isContiguous() and self.data and src1.data and src2.data then
local self_d = self:data()
local src1_d = src1:data()
local src2_d = src2:data()
assert(src1:nElement() == self:nElement(), 'size mismatch')
assert(src2:nElement() == self:nElement(), 'size mismatch')
for i=0,self:nElement()-1 do
local res = func(tonumber(self_d[i]), tonumber(src1_d[i]), tonumber(src2_d[i])) -- tonumber() required for long...
if res then
self_d[i] = res
end
end
return self
else
return map2(self, src1, src2, func)
end
return self
else
return map2(self, src1, src2, func)
end
end)
end)
end
end

-- torch.data
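
With the THHalf cdef above in place, a contiguous HalfTensor exposes its
raw storage through the FFI as one struct (a single unsigned short) per
element. A quick sketch, assuming LuaJIT is available; 0x3C00 is the IEEE
binary16 bit pattern of 1.0:

  local t = torch.FloatTensor({1.0}):half()  -- via Tensor.half(), added below
  local p = t:data()                         -- pointer to the cdef'd THHalf structs
  print(p[0].x)                              -- 15360 (0x3C00), the bits of half 1.0
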
18 changes: 15 additions & 3 deletions Tensor.lua
@@ -5,14 +5,14 @@ local Storage = {}
local Tensor = {}

-- types
local types = {'Byte', 'Char', 'Short', 'Int', 'Long', 'Float', 'Double'}
local types = {'Byte', 'Char', 'Short', 'Int', 'Long', 'Float', 'Half', 'Double'}

-- Lua 5.2 compatibility
local log10 = math.log10 or function(x) return math.log(x, 10) end

-- tostring() functions for Tensor and Storage
local function Storage__printformat(self)
if self:size() == 0 then
return "", nil, 0
end
local intMode = true
@@ -277,6 +277,10 @@ function Tensor.double(self)
return self:type('torch.DoubleTensor')
end

function Tensor.half(self)
return self:type('torch.HalfTensor')
end

function Tensor.real(self)
return self:type(torch.getdefaulttensortype())
end
@@ -556,6 +560,14 @@ torch.permute = Tensor.permute
for _,type in ipairs(types) do
local metatable = torch.getmetatable('torch.' .. type .. 'Tensor')
for funcname, func in pairs(Tensor) do
rawset(metatable, funcname, func)
if funcname ~= 'totable' or type ~= 'Half' or torch.hashalfmath() then
rawset(metatable, funcname, func)
else
local function Tensor__totable(self)
local host_tensor = self:float()
return host_tensor:totable()
end
rawset(torch.getmetatable('torch.HalfTensor'), 'totable', Tensor__totable)
end
end
end
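
A small sketch of the round-trip these Tensor.lua changes enable, assuming
half math is unavailable so totable() takes the :float() fallback above
(values chosen to be exactly representable in binary16):

  local f = torch.FloatTensor({0.5, 1.5})
  local h = f:half()            -- new conversion method
  print(torch.type(h))          -- torch.HalfTensor
  print(h:totable()[2])         -- 1.5, recovered through the float fallback
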
60 changes: 9 additions & 51 deletions TensorMath.lua
@@ -6,56 +6,7 @@ local interface = wrap.CInterface.new()
local method = wrap.CInterface.new()
local argtypes = wrap.CInterface.argtypes

argtypes['ptrdiff_t'] = {

helpname = function(arg)
return 'ptrdiff_t'
end,

declare = function(arg)
-- if it is a number we initialize here
local default = tonumber(tostring(arg.default)) or 0
return string.format("%s arg%d = %g;", 'ptrdiff_t', arg.i, default)
end,

check = function(arg, idx)
return string.format("lua_isnumber(L, %d)", idx)
end,

read = function(arg, idx)
return string.format("arg%d = (%s)lua_tonumber(L, %d);", arg.i, 'ptrdiff_t', idx)
end,

init = function(arg)
-- otherwise do it here
if arg.default then
local default = tostring(arg.default)
if not tonumber(default) then
return string.format("arg%d = %s;", arg.i, default)
end
end
end,

carg = function(arg)
return string.format('arg%d', arg.i)
end,

creturn = function(arg)
return string.format('arg%d', arg.i)
end,

precall = function(arg)
if arg.returned then
return string.format('lua_pushnumber(L, (lua_Number)arg%d);', arg.i)
end
end,

postcall = function(arg)
if arg.creturned then
return string.format('lua_pushnumber(L, (lua_Number)arg%d);', arg.i)
end
end
}
argtypes['ptrdiff_t'] = wrap.types.ptrdiff_t

interface:print([[
#include "TH.h"
@@ -216,6 +167,7 @@ local reals = {ByteTensor='unsigned char',
IntTensor='int',
LongTensor='long',
FloatTensor='float',
HalfTensor='half',
DoubleTensor='double'}

local accreals = {ByteTensor='long',
@@ -224,11 +176,12 @@ local accreals = {ByteTensor='long',
IntTensor='long',
LongTensor='long',
FloatTensor='double',
HalfTensor='float',
DoubleTensor='double'}

for _,Tensor in ipairs({"ByteTensor", "CharTensor",
"ShortTensor", "IntTensor", "LongTensor",
"FloatTensor", "DoubleTensor"}) do
"FloatTensor", "HalfTensor", "DoubleTensor"}) do

local real = reals[Tensor]
local accreal = accreals[Tensor]
@@ -257,6 +210,7 @@ for _,Tensor in ipairs({"ByteTensor", "CharTensor",
end
end

if Tensor ~= 'HalfTensor' then
wrap("zero",
cname("zero"),
{{name=Tensor, returned=true}})
@@ -1030,6 +984,7 @@ static void THTensor_random1__(THTensor *self, THGenerator *gen, long b)
cname("nonzero"),
{{name="IndexTensor", default=true, returned=true},
{name=Tensor}})
end -- ~= HalfTensor

if Tensor == 'ByteTensor' then
-- Logical accumulators only apply to ByteTensor
@@ -1483,6 +1438,9 @@ void torch_TensorMath_init(lua_State *L)
torch_IntTensorMath_init(L);
torch_LongTensorMath_init(L);
torch_FloatTensorMath_init(L);
#if TH_NATIVE_HALF
torch_HalfTensorMath_init(L);
#endif
torch_DoubleTensorMath_init(L);
luaT_setfuncs(L, torch_TensorMath__, 0);
}
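
One observable consequence of gating the wrap() calls above is that the
skipped math methods never exist on torch.HalfTensor at all; a sketch
(the exact error wording will vary):

  local h = torch.FloatTensor(2):half()
  local ok = pcall(function() h:zero() end)
  print(ok)  -- false: zero() is not generated for HalfTensor, so the call fails
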
1 change: 1 addition & 0 deletions Tester.lua
@@ -687,6 +687,7 @@ local typesMatching = {
['torch.LongStorage'] = torch.LongTensor,
['torch.FloatStorage'] = torch.FloatTensor,
['torch.DoubleStorage'] = torch.DoubleTensor,
['torch.HalfStorage'] = torch.HalfTensor,
}

--[[ Tests for storage equality.
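
With torch.HalfStorage added to typesMatching, the Tester's
storage-equality check can resolve a HalfStorage to torch.HalfTensor for
comparison. A hypothetical sketch, assuming torch.Tester's eq accepts
storages as the surrounding storage-equality tests suggest:

  local tester = torch.Tester()
  local a = torch.HalfStorage({1, 2})
  local b = torch.HalfStorage({1, 2})
  tester:eq(a, b)  -- resolved through the HalfStorage -> HalfTensor mapping
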
6 changes: 5 additions & 1 deletion generic/Storage.c
@@ -41,7 +41,7 @@ static int torch_Storage_(new)(lua_State *L)
THStorage_(free)(storage);
luaL_error(L, "element at index %d is not a number", i);
}
THStorage_(set)(storage, i-1, (real)lua_tonumber(L, -1));
THStorage_(set)(storage, i-1, LUA_NUMBER_TO_REAL(lua_tonumber(L, -1)));
lua_pop(L, 1);
}
}
@@ -131,6 +131,10 @@ static int torch_Storage_(copy)(lua_State *L)
THStorage_(copyFloat)(storage, src);
else if( (src = luaT_toudata(L, 2, "torch.DoubleStorage")) )
THStorage_(copyDouble)(storage, src);
#if TH_GENERIC_USE_HALF
else if( (src = luaT_toudata(L, 2, "torch.HalfStorage")) )
THStorage_(copyHalf)(storage, src);
#endif
else
luaL_typerror(L, 2, "torch.*Storage");
lua_settop(L, 1);
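
The new copyHalf branch makes storage conversion symmetric with the other
numeric types; a small round-trip sketch (values exactly representable in
binary16):

  local s = torch.FloatStorage({1, 2, 3})
  local h = torch.HalfStorage(3):copy(s)      -- Float -> Half (copyFloat)
  local back = torch.FloatStorage(3):copy(h)  -- Half -> Float via copyHalf above
  print(back[2])                              -- 2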