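Benchmark comparing the plain MATLAB (`interlinks`), MEX (`interlink_mex`), and `interlinks_bsx` implementations: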
n = 2000;
n_labels = 800;
W = rand(n, n);
% Symmetric adjacency matrix of logicals.
% NOTE(review): each off-diagonal entry of W*W' is a sum of n products of
% U(0,1) values (mean ~n/4 = 500 here), so virtually every entry exceeds .5
% and W is almost entirely true. A sparser symmetric W would make the
% equality checks at the end more discriminating.
W = W * W' > .5;
Wd = double(W);
ci = floor(rand(n, 1) * n_labels) + 1; % random ids from 1 to n_labels
% Relabel the ids to consecutive integers 1..k (k = number of distinct ids).
[~, ~, IC] = unique(ci);
fprintf('base avg fun time = %g \n', timeit(@() interlinks(W, IC)));
fprintf('mex avg fun time = %g \n', timeit(@() interlink_mex(W, IC)));
% interlinks_bsx requires a symmetric W (function from @aarbelle).
fprintf('bsx avg fun time = %g \n', timeit(@() interlinks_bsx(Wd, IC')));
% All three implementations should produce identical results.
x1 = interlinks(W, IC);
x2 = interlink_mex(W, IC);
x3 = interlinks_bsx(Wd, IC');
fprintf('norm(x1 - x2) = %g\n', norm(x1 - x2));
fprintf('norm(x1 - x3) = %g\n', norm(x1 - x3));
Test results with these settings:
base avg fun time = 4.94275
mex avg fun time = 0.0373092
bsx avg fun time = 0.126406
norm(x1 - x2) = 0
norm(x1 - x3) = 0
For small values of n_labels the bsx function performs very well, but with a large enough n_labels the MEX function becomes faster.
C++ code
Save it in a file, for example interlink_mex.cpp, and compile it with `mex interlink_mex.cpp`. You need a C++ compiler installed and configured for MEX (see `mex -setup`).
#include "mex.h"
#include "matrix.h"
#include <math.h>
void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[]) {
if(nrhs != 2)
mexErrMsgTxt("Invalid number of inputs. Shoudl be 2 input argument.");
if(nlhs != 1)
mexErrMsgTxt("Invalid number of outputs. Should be 1 output arguments.");
if(!mxIsLogical(prhs[0])) {
mexErrMsgTxt("First argument should be a logical array (i.e. type logical)");
}
if(!mxIsDouble(prhs[1])) {
mexErrMsgTxt("Second argument should be an array of type double");
}
const mxArray *W = prhs[0];
const mxArray *ci = prhs[1];
size_t W_m = mxGetM(W);
size_t W_n = mxGetN(W);
if(W_m != W_n)
mexErrMsgTxt("Rows and columns of W are not equal");
size_t ci_n = mxGetNumberOfElements(ci);
mxLogical *W_data = mxGetLogicals(W);
double *ci_data = mxGetPr(ci);
size_t *ci_data_size_t = (size_t*) mxCalloc(ci_n, sizeof(size_t));
size_t ncomms = 0;
double intpart;
for(size_t i = 0; i < ci_n; i++) {
double x = ci_data[i];
if(x < 1 || x > 65536 || modf(x, &intpart) != 0.0) {
mexErrMsgTxt("Input ci is not all integers from 1 to a maximum value of 65536 (can edit source code to change this)");
}
size_t xx = (size_t) x;
if(xx > ncomms)
ncomms = xx;
ci_data_size_t[i] = xx - 1;
}
mxArray *mcd = mxCreateDoubleMatrix(ncomms, ncomms, mxREAL);
double *mcd_data = mxGetPr(mcd);
for(size_t i = 0; i < W_n; i++) {
size_t ii = ci_data_size_t[i];
for(size_t j = 0; j < W_n; j++) {
size_t jj = ci_data_size_t[j];
mcd_data[ii + jj * ncomms] += (W_data[i + j * W_m] != 0);
}
}
for(size_t i = 0; i < ncomms * ncomms; i+= ncomms + 1)
mcd_data[i]/=2;
mxFree(ci_data_size_t);
plhs[0] = mcd;
}