Fast way to initialize tensor in torch7

I need to initialize a three-dimensional tensor in torch7 using a function of each element's indices, e.g.:

```lua
-- i, j, k are the indices of an element in the tensor
func = function(i, j, k)
  return i * j * k  -- operations that depend on i, j, k
end
```

Then I initialize the three-dimensional tensor A as follows:

```lua
for i = 1, A:size(1) do
  for j = 1, A:size(2) do
    for k = 1, A:size(3) do
      A[{i, j, k}] = func(i, j, k)
    end
  end
end
```

But this code is very slow, and I found that it takes 92% of the total time. Are there more efficient ways to initialize a three-dimensional tensor in torch7?

1 answer

See the documentation for Tensor:apply:

These functions apply a function to each element of the tensor on which the method is called (self). These methods are much faster than using the for loop in Lua.
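Tensor:apply calls the function with each element's current value, not its position, so any index has to be tracked inside the closure. One hedged alternative to carrying j and k by hand is to keep a single linear counter and decode it with integer division. This is a sketch assuming a contiguous tensor stored in row-major order (last index varying fastest); linear_to_ijk is a hypothetical helper, not part of torch7.

```lua
-- Hypothetical helper (not part of torch7): decode a 1-based linear
-- position n into 1-based (i, j, k) for a contiguous d1 x d2 x d3 tensor
-- stored in row-major order, i.e. with the last index varying fastest.
-- d1 is not needed: it only bounds n.
function linear_to_ijk(n, d2, d3)
  local r = n - 1                          -- 0-based offset
  local k = r % d3 + 1
  local j = math.floor(r / d3) % d2 + 1
  local i = math.floor(r / (d2 * d3)) + 1
  return i, j, k
end

-- In a 2 x 3 x 4 tensor the 13th stored element is A[{2, 1, 1}].
print(linear_to_ijk(13, 3, 4))  -- prints 2, 1 and 1 (tab-separated)

-- Intended use with torch7 (not executed here; B is a contiguous 3D tensor):
--   local n = 0
--   B:apply(function()
--     n = n + 1
--     return func(linear_to_ijk(n, B:size(2), B:size(3)))
--   end)
```

Compared with a carry-based closure, decoding costs a few divisions per element; in exchange there is no carry state to get wrong.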

The example in the documentation initializes a 2D tensor based on its linear (in-memory) index i. Below is an extended example for three dimensions, followed by a version for N-dimensional tensors. Using the apply method is much faster on my machine:

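```lua
require 'torch'

local A = torch.Tensor(100, 100, 1000)
local B = torch.Tensor(100, 100, 1000)

local function func(i, j, k)
  return i * j * k
end

-- Baseline: explicit triple loop with per-element indexing.
-- os.clock() returns CPU seconds as a number; subtract directly
-- (os.difftime truncates its arguments to whole seconds).
local t = os.clock()
for i = 1, A:size(1) do
  for j = 1, A:size(2) do
    for k = 1, A:size(3) do
      A[{i, j, k}] = i * j * k
    end
  end
end
print("Original time:", os.clock() - t)

-- Returns a closure for Tensor:apply that tracks (i, j, k) itself.
-- apply visits a contiguous tensor in row-major order (last index fastest),
-- so k is incremented first and carried into j, then into i.
local function forindices(T, func)
  local i, j, k = 1, 1, 0
  local d2, d3 = T:size(2), T:size(3)
  return function()
    k = k + 1
    if k > d3 then
      k = 1
      j = j + 1
      if j > d2 then
        j = 1
        i = i + 1
      end
    end
    return func(i, j, k)
  end
end

t = os.clock()
B:apply(forindices(B, func))
print("Apply method:", os.clock() - t)
```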
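To check the index bookkeeping without torch installed, the same carry logic can be driven by a plain loop standing in for apply. This is a sketch: forindices_sizes and fake_apply are hypothetical names, and the stand-in assumes apply visits a contiguous tensor's elements in row-major order.

```lua
-- Same carry logic as forindices, but taking the sizes directly so it can
-- run in plain Lua without torch. Hypothetical name, for illustration only.
function forindices_sizes(d2, d3, func)
  local i, j, k = 1, 1, 0
  return function()
    k = k + 1
    if k > d3 then      -- last index wrapped: carry into j
      k = 1
      j = j + 1
      if j > d2 then    -- middle index wrapped: carry into i
        j = 1
        i = i + 1
      end
    end
    return func(i, j, k)
  end
end

-- Stand-in for Tensor:apply on a contiguous tensor (an assumption):
-- call f once per element, in storage order, and collect the results.
function fake_apply(numel, f)
  local out = {}
  for n = 1, numel do
    out[n] = f()
  end
  return out
end

local visited = fake_apply(2 * 3 * 4, forindices_sizes(3, 4, function(i, j, k)
  return i .. "," .. j .. "," .. k
end))
print(visited[1], visited[5], visited[24])  -- prints 1,1,1  1,2,1  2,3,4
```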

EDIT

This version works for a tensor with any number of dimensions:

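```lua
-- Lua 5.1/LuaJIT has a global unpack; Lua 5.2+ provides table.unpack
local unpack = unpack or table.unpack

-- Fill tensor A in place with f(i1, ..., iN), for any number of dimensions.
function tabulate(A, f)
  local idx = {}
  local ndims = A:dim()
  local dim = A:size()
  -- start one step before the first element: {1, ..., 1, 0}
  for i = 1, ndims - 1 do
    idx[i] = 1
  end
  idx[ndims] = 0
  return A:apply(function()
    -- advance the last index, carrying into earlier ones on overflow
    for i = ndims, 1, -1 do
      idx[i] = idx[i] + 1
      if idx[i] <= dim[i] then break end
      idx[i] = 1
    end
    return f(unpack(idx))
  end)
end

-- usage for the 3D case
tabulate(A, function(i, j, k) return i * j * k end)
```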
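The carry loop in tabulate behaves like an odometer. As a torch-free sketch, the same counter can be driven by a plain table of sizes instead of A:size(); make_counter is a hypothetical name used only for illustration.

```lua
-- Lua 5.1/LuaJIT (used by torch7) has a global unpack; Lua 5.2+ moved it.
local unpack = table.unpack or unpack

-- Hypothetical helper: the N-dimensional counter from tabulate. Each call
-- advances the last index and carries into earlier dimensions on overflow.
function make_counter(sizes)
  local ndims = #sizes
  local idx = {}
  for d = 1, ndims - 1 do idx[d] = 1 end
  idx[ndims] = 0                 -- first call bumps this to 1
  return function()
    for d = ndims, 1, -1 do
      idx[d] = idx[d] + 1
      if idx[d] <= sizes[d] then break end
      idx[d] = 1                 -- wrapped: reset and carry left
    end
    return unpack(idx)
  end
end

-- Enumerate a 2 x 3 array: (1,1) (1,2) (1,3) (2,1) (2,2) (2,3)
local next_index = make_counter({2, 3})
for _ = 1, 6 do
  print(next_index())
end
```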
