Octave / Matlab: efficient calculation of the Frobenius inner product?

I have two matrices A and B, and I want to get:

trace(A*B) 

If I'm not mistaken, this is called the Frobenius inner product.

My concern is efficiency. I'm afraid this straightforward approach will first compute the whole matrix product (my matrices have thousands of rows and columns) and only then take the trace of the result, while the operation I really need is much simpler. Is there a function or syntax that does this efficiently?

+8
matrix matlab octave
3 answers

That's right: summing the element-wise products is faster:

    n = 1000;
    A = randn(n);
    B = randn(n);
    tic
    sum(sum(A .* B));
    toc
    tic
    sum(diag(A * B'));
    toc

    Elapsed time is 0.010015 seconds.
    Elapsed time is 0.130514 seconds.
+5

sum(sum(A.*B)) avoids doing the full matrix multiplication.
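The identity behind this one-liner can be written out explicitly. This is a short derivation, assuming real-valued A and B of the same size m-by-n (the symbols m, n, i, j are introduced here for illustration):

```latex
\operatorname{trace}(A B^{T})
  = \sum_{i=1}^{m} (A B^{T})_{ii}
  = \sum_{i=1}^{m} \sum_{j=1}^{n} A_{ij} B_{ij}
```

The right-hand side needs only mn multiplications, whereas forming A B^T first costs on the order of m^2 n. Note that this matches trace(A*B'). For trace(A*B) exactly as written in the question, the sum pairs A(i,j) with B(j,i), so the element-wise form would be sum(sum(A.*B.')).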

+2

How about using vector multiplication?

 (A(:)')*B(:) 

Runtime Check

Comparison of the four options, with A and B of size 1000-by-1000:
1. Vector inner product: A(:)'*B(:) (this answer) took only 0.0011 sec.
2. Element-wise multiplication: sum(sum(A.*B)) (John's answer) took 0.0035 sec.
3. Trace: trace(A*B') (proposed by the OP) took 0.054 sec.
4. Sum of the diagonal: sum(diag(A*B')) (the option rejected by John) took 0.055 sec.

Take-home message: Matlab is extremely efficient when it comes to matrix/vector products. Using the vector inner product is about 3 times faster than the already efficient element-wise multiplication solution.
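To see that these variants really compute the same number, here is a minimal sanity check on tiny matrices. It is only a sketch: the 2-by-2 values are made-up test data, and the variable names f1, f2, f3, g1, g2 are introduced here for illustration. It also shows the transposed form needed if you want trace(A*B) exactly as written in the question.

```octave
% Made-up 2-by-2 test data (small integers, so the comparisons are exact).
A = [1 2; 3 4];
B = [5 6; 7 8];

% Frobenius inner product: three equivalent forms for real matrices.
f1 = trace(A * B.');      % forms the full matrix product first
f2 = sum(sum(A .* B));    % element-wise product, then sum everything
f3 = A(:).' * B(:);       % stack columns into vectors, take a dot product

% trace(A*B) without a transpose pairs A(i,j) with B(j,i) instead.
g1 = trace(A * B);
g2 = sum(sum(A .* B.'));

assert(f1 == f2 && f2 == f3);
assert(g1 == g2);
```

The code uses .' (plain transpose) rather than ' (conjugate transpose); the two coincide for real matrices, which is the case assumed throughout this thread.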


Verification code: the code used to measure the execution times.

    t = zeros(1,4);
    n = 1000;   % size of matrices
    it = 100;   % average results over XX trials
    for ii = 1:it
        % random inputs
        A = rand(n);
        B = rand(n);
        % John's rejected solution
        tic; n1 = sum(diag(A*B')); t(1) = t(1) + toc;
        % element-wise solution
        tic; n2 = sum(sum(A.*B)); t(2) = t(2) + toc;
        % MOST efficient solution - using vector product
        tic; n3 = A(:)'*B(:); t(3) = t(3) + toc;
        % using trace
        tic; n4 = trace(A*B'); t(4) = t(4) + toc;
        % make sure everything is correct
        assert(abs(n1-n2)<1e-8 && abs(n3-n4)<1e-8 && abs(n1-n4)<1e-8);
    end
    t./it


+1
