I adapted plannapus's answer so that it works with several trees per axis (and removed, along the way, some parameters I did not need):

library(ape)

#' Heatmap whose rows and columns are ordered by one or more phylogenies
#'
#' Draws `x` as a heatmap. The column tree(s) go above it and the row tree(s)
#' to its left. Tip labels go on the bottom (columns) and right (rows) margins.
#' A colour key with a count histogram goes in the top-left corner.
#'
#' @param x Numeric matrix whose row/column names contain the tip labels of
#'   `Rowp`/`Colp`; it is reordered to tip order before plotting.
#' @param Rowp A `phylo` tree, or a list / `multiPhylo` of trees, for the rows.
#'   Several trees are stacked bottom-to-top in list order, each getting a
#'   height proportional to its number of tips.
#' @param Colp Same as `Rowp`, for the columns (laid out left-to-right).
#' @param breaks Colour break points, passed to `image()` and `hist()`.
#' @param col Colours for the heatmap and the key.
#' @param denscol Colour of the histogram line drawn in the key.
#' @param respect If `TRUE`, rescale the layout along the device's longer side
#'   so the layout grid cells are square.
#' @param ... Further arguments passed to the heatmap's `image()` call.
#' @param symkey If `TRUE` (the default, and the previously hard-coded
#'   behaviour) the colour key is made symmetric around zero. It is placed
#'   after `...`, so it must be named and cannot capture positional arguments.
#' @return `NULL`, invisibly; the function is called for its drawing.
heatmap.phylo <- function(x, Rowp, Colp, breaks, col, denscol = "cyan",
                          respect = FALSE, ..., symkey = TRUE) {
  # Linearly map `low` to 0 and `high` to 1.
  scale01 <- function(x, low = min(x), high = max(x)) {
    (x - low) / (high - low)
  }

  # Normalise a single `phylo` or a list/`multiPhylo` of trees into a list of
  # trees plus the layout quantities used below. A class test replaces the
  # old reliance on `$tip` partially matching `tip.label`.
  tree_info <- function(p) {
    trees <- if (inherits(p, "phylo")) {
      list(p)
    } else {
      # `[[` (not lapply) so ape can restore compressed tip labels.
      lapply(seq_along(p), function(i) p[[i]])
    }
    lengths <- vapply(trees, function(t) length(t$tip.label), integer(1))
    info <- list(
      trees = trees,
      n = length(trees),
      tips = unlist(lapply(trees, function(t) t$tip.label)),
      lengths = lengths,
      fraction = lengths / sum(lengths)
    )
    # Tree heights are only needed (and only computed, as before) when
    # several trees share one panel and need a common depth scale.
    if (info$n > 1L) {
      info$heights <- vapply(
        trees, function(t) max(node.depth.edgelength(t)), numeric(1)
      )
      info$max_height <- max(info$heights)
    }
    info
  }

  rows <- tree_info(Rowp)
  cols <- tree_info(Colp)

  # Reorder to tip order; drop = FALSE keeps a matrix even for 1 x n input.
  x <- x[rows$tips, cols$tips, drop = FALSE]
  xl <- c(0.5, ncol(x) + 0.5)
  yl <- c(0.5, nrow(x) + 0.5)

  # Label size shrinks as the number of labels grows, capped at 1
  # (1 / log10(1) is Inf, so a single label gets cex 1).
  cex_row <- min(1, 0.2 + 1 / log10(nrow(x)))
  cex_col <- min(1, 0.2 + 1 / log10(ncol(x)))

  op <- par(no.readonly = TRUE)
  on.exit(par(op), add = TRUE)

  # Screens in NDC (left, right, bottom, top) on a 5 x 5 grid:
  #   1 = colour key (top-left)   2 = column trees (top)
  #   3 = row trees (left)        4 = heatmap (centre)
  #   5 = column labels (bottom)  6 = row labels (right)
  screen_matrix <- matrix(c(
    0, 1, 4, 5,
    1, 4, 4, 5,
    0, 1, 1, 4,
    1, 4, 1, 4,
    1, 4, 0, 1,
    4, 5, 1, 4
  ) / 5, byrow = TRUE, ncol = 4)
  if (respect) {
    # r = device height / width (NDC units per inch in x over those in y).
    r <- grconvertX(1, from = "inches", to = "ndc") /
      grconvertY(1, from = "inches", to = "ndc")
    scale <- if (r < 1) c(r, r, 1, 1) else c(1, 1, 1 / r, 1 / r)
    screen_matrix <- screen_matrix *
      matrix(scale, nrow = nrow(screen_matrix), ncol = 4, byrow = TRUE)
  }
  split.screen(screen_matrix)
  # Close the screens first (even on error), then restore par.
  on.exit(close.screen(all.screens = TRUE), add = TRUE, after = FALSE)

  # Column tree(s), drawn downwards above the heatmap.
  screen(2)
  par(mar = rep(0, 4))
  if (cols$n == 1L) {
    plot(cols$trees[[1L]], direction = "downwards", show.tip.label = FALSE,
         xaxs = "i", x.lim = xl)
  } else {
    # One sub-screen per tree, width proportional to its number of tips.
    col_screens <- split.screen(cbind(
      left = cumsum(cols$fraction) - cols$fraction,
      right = cumsum(cols$fraction),
      bottom = 0,
      top = 1
    ))
    for (i in seq_len(cols$n)) {
      screen(col_screens[i])
      # y-range of length max_height ending at this tree's height, so all
      # trees share one depth scale.
      plot(cols$trees[[i]], direction = "downwards", show.tip.label = FALSE,
           xaxs = "i", x.lim = c(0.5, 0.5 + cols$lengths[i]),
           y.lim = cols$heights[i] - cols$max_height + c(0, cols$max_height))
    }
  }

  # Row tree(s), drawn rightwards left of the heatmap. This branch previously
  # tested/looped over the *column* tree count.
  screen(3)
  par(mar = rep(0, 4))
  if (rows$n == 1L) {
    plot(rows$trees[[1L]], direction = "rightwards", show.tip.label = FALSE,
         yaxs = "i", y.lim = yl)
  } else {
    # One sub-screen per tree, stacked bottom-to-top, height proportional
    # to its number of tips.
    row_screens <- split.screen(cbind(
      left = 0,
      right = 1,
      bottom = cumsum(rows$fraction) - rows$fraction,
      top = cumsum(rows$fraction)
    ))
    for (i in seq_len(rows$n)) {
      screen(row_screens[i])
      plot(rows$trees[[i]], direction = "rightwards", show.tip.label = FALSE,
           yaxs = "i", x.lim = c(0, rows$max_height),
           y.lim = c(0.5, 0.5 + rows$lengths[i]))
    }
  }

  # Heatmap. image() plots z[i, j] at (x[i], y[j]), so transpose to put the
  # rows vertically (aligned with the row trees) and the columns horizontally.
  screen(4)
  par(mar = rep(0, 4), xpd = TRUE)
  image(seq_len(ncol(x)) - 0.5, seq_len(nrow(x)) - 0.5, t(x),
        xaxs = "i", yaxs = "i", axes = FALSE, xlab = "", ylab = "",
        breaks = breaks, col = col, ...)

  # Row labels, right of the heatmap.
  screen(6)
  par(mar = rep(0, 4))
  plot(NA, axes = FALSE, ylab = "", xlab = "", yaxs = "i",
       xlim = c(0, 2), ylim = yl)
  text(rep(0, nrow(x)), seq_len(nrow(x)), rows$tips, pos = 4, cex = cex_row)

  # Column labels, below the heatmap.
  screen(5)
  par(mar = rep(0, 4))
  plot(NA, axes = FALSE, ylab = "", xlab = "", xaxs = "i",
       ylim = c(0, 2), xlim = xl)
  text(seq_len(ncol(x)), rep(2, ncol(x)), cols$tips,
       srt = 90, adj = c(1, 0.5), cex = cex_col)

  # Colour key with a histogram of the values.
  screen(1)
  par(mar = c(2, 2, 1, 1), cex = 0.75)
  key_breaks <- breaks
  if (symkey) {
    max_raw <- max(abs(c(x, breaks)), na.rm = TRUE)
    min_raw <- -max_raw
    # Set the outermost key breaks to -/+ max|x|.
    key_breaks[1] <- -max(abs(x), na.rm = TRUE)
    key_breaks[length(key_breaks)] <- max(abs(x), na.rm = TRUE)
  } else {
    min_raw <- min(x, na.rm = TRUE)
    max_raw <- max(x, na.rm = TRUE)
  }
  z <- seq(min_raw, max_raw, length.out = length(col))
  image(z = matrix(z, ncol = 1), col = col, breaks = key_breaks,
        xaxt = "n", yaxt = "n")
  par(usr = c(0, 1, 0, 1))
  lv <- pretty(breaks)
  axis(1, at = scale01(as.numeric(lv), min_raw, max_raw), labels = lv)
  # Counts per colour bin, drawn as a step line scaled to 95% of the height.
  h <- hist(x, plot = FALSE, breaks = breaks)
  hx <- scale01(breaks, min_raw, max_raw)
  hy <- c(h$counts, h$counts[length(h$counts)])
  lines(hx, hy / max(hy) * 0.95, lwd = 1, type = "s", col = denscol)
  axis(2, at = pretty(hy) / max(hy) * 0.95, pretty(hy))
  par(cex = 0.5)
  mtext(side = 2, "Count", line = 2)
  invisible(NULL)
}

# Example: three trees on each axis, 11 tips (A-K) in total.
tree <- read.tree(
  text = "(A:1,B:1);((C:1,D:2):2,E:1);((F:1,G:1,H:2):5,((I:1,J:2):2,K:1):1);",
  comment.char = ""
)
N <- sum(vapply(seq_along(tree),
                function(i) length(tree[[i]]$tip.label), integer(1)))
set.seed(42)
m <- cor(matrix(rnorm(N * N), nrow = N))
rownames(m) <- colnames(m) <- LETTERS[seq_len(N)]
# bluered() comes from gplots, which this script does not load; build a
# similar blue-white-red ramp with grDevices instead.
heatmap.phylo(m, tree, tree,
              col = colorRampPalette(c("blue", "white", "red"))(10),
              breaks = seq(-1, 1, length.out = 11), respect = TRUE)