ctree() - How to get a list of the splitting conditions for each terminal node?

I have output from ctree() (party package) that looks like the following. How can I get a list of the split conditions leading to each terminal node, for example `sns <= 0, dta <= 1`; `sns <= 0, dta > 1`; and so on?

 1) sns <= 0; criterion = 1, statistic = 14655.021
   2) dta <= 1; criterion = 1, statistic = 3286.389
     3)* weights = 153682
   2) dta > 1
     4)* weights = 289415
 1) sns > 0
   5) dta <= 2; criterion = 1, statistic = 1882.439
     6)* weights = 245457
   5) dta > 2
     7) dta <= 6; criterion = 1, statistic = 1170.813
       8)* weights = 328582
     7) dta > 6

thanks

+7
r decision-tree party
source share
4 answers

This function should do the job

  # List, for every terminal node of a party::ctree() fit, the chain of split
  # conditions that leads from the root to that node.
  #
  # Arguments:
  #   ct    Fitted tree. It is only accessed through where(ct) and
  #         nodes(ct, id) from the party package.
  #   data  Data frame whose rows line up with the node weight vectors
  #         (presumably the data the tree was fitted on -- confirm at the
  #         call site).
  #
  # Returns a data frame with one row per terminal node and two character
  # columns: Node (the terminal node id) and Path (the conditions, joined by
  # ", "). A tree without any split yields an empty Path.
  #
  # Only numeric split points are handled; the split point is coerced with
  # as.numeric().
  CtreePathFunc <- function(ct, data) {
    terminal_ids <- unique(where(ct))
    node_at <- function(id) nodes(ct, id)[[1]]

    path_for <- function(Node) {
      # Inner nodes are the ids below the terminal id that are not terminal
      # themselves. seq_len() keeps this empty for Node == 1, where
      # 1:(Node - 1) would yield c(1, 0).
      candidates <- setdiff(
        seq_len(Node - 1),
        terminal_ids[terminal_ids < Node]
      )
      node_weights <- node_at(Node)$weights

      # An inner node is on the path when it shares at least one observation
      # with the terminal node (element 2 of a node is its weight vector).
      on_path <- vapply(
        candidates,
        function(i) any(node_weights & node_at(i)[[2]] == 1),
        logical(1)
      )
      path <- candidates[on_path]

      conditions <- vapply(seq_along(path), function(i) {
        # The child reached from path[i] is the next inner node on the path,
        # or the terminal node itself at the end of the path.
        child <- if (i == length(path)) {
          node_at(Node)
        } else {
          node_at(path[i + 1])
        }
        split <- node_at(path[i])[[5]]  # element 5 is the primary split
        var_name <- as.character(split$variableName)
        split_point <- as.numeric(split$splitpoint)[1]

        # Decide the branch from the observations that ended up in the child.
        # Missing values are ignored; otherwise all() returns NA and the
        # condition below would fail.
        in_child <- which(as.logical(child$weights))
        values <- data[[var_name]][in_child]
        goes_left <- isTRUE(all(values <= split_point, na.rm = TRUE))

        paste(var_name, if (goes_left) "<=" else ">", as.character(split_point))
      }, character(1))

      paste(conditions, collapse = ", ")
    }

    paths <- vapply(terminal_ids, path_for, character(1))

    data.frame(
      Node = as.character(terminal_ids),
      Path = unname(paths),
      stringsAsFactors = FALSE
    )
  }

Testing

 # Example: extract the terminal-node paths of a tree fitted to the
 # airquality data (rows with a missing Ozone value are dropped first).
 library(party)

 airq <- subset(airquality, !is.na(Ozone))
 ct <- ctree(Ozone ~ ., data = airq, controls = ctree_control(maxsurrogate = 3))

 Result <- CtreePathFunc(ct, airq)
 Result
 ##   Node                                Path
 ## 1    5 Temp <= 82, Wind > 6.9, Temp <= 77
 ## 2    3            Temp <= 82, Wind <= 6.9
 ## 3    6  Temp <= 82, Wind > 6.9, Temp > 77
 ## 4    9             Temp > 82, Wind > 10.3
 ## 5    8            Temp > 82, Wind <= 10.3
+9
source share

If you use the new, recommended ctree() implementation from the partykit package instead of the one in the old party package, you can use the .list.rules.party() function. It is not yet officially exported, but it can be used to extract the desired information.

 # Example: list the rule of every terminal node with partykit.
 library("partykit")

 airq <- subset(airquality, !is.na(Ozone))
 ct <- ctree(Ozone ~ ., data = airq)

 # .list.rules.party() is not exported, hence the ':::' access.
 partykit:::.list.rules.party(ct)
 ##                                     3                                     5
 ##            "Temp <= 82 & Wind <= 6.9" "Temp <= 82 & Wind > 6.9 & Temp <= 77"
 ##                                     6                                     8
 ## "Temp <= 82 & Wind > 6.9 & Temp > 77"            "Temp > 82 & Wind <= 10.3"
 ##                                     9
 ##             "Temp > 82 & Wind > 10.3"
+6
source share

I needed this function for categorical data as well. The following functions more or less answer @JoãoDaniel's question (at first I tested them only with categorical predictor variables):

 # Returns the string without leading or trailing whitespace.
 # http://stackoverflow.com/questions/2261079/how-to-trim-leading-and-trailing-whitespace-in-r
 trim <- function(x) gsub("^\\s+|\\s+$", "", x)

 # First whitespace-delimited token of a rule: the variable name.
 getVariable <- function(x) sub("(.*?)[[:space:]].*", "\\1", x)

 # Second whitespace-delimited token of a rule: the comparison symbol.
 getSimbolo <- function(x) sub("(.*?)[[:space:]](.*?)[[:space:]].*", "\\2", x)

 # Remove redundant conditions from a path.
 #
 # `elemento` is one path: rules separated by ";", each of the form
 # "<variable> <symbol> { <value> }". For every (variable, symbol) pair only
 # the LAST rule on the path is kept, because a later split on the same
 # variable in the same direction supersedes the earlier one. The surviving
 # rules are returned as one string joined by "; ", ordered by the first
 # appearance of their (variable, symbol) pair.
 getReglaFinal <- function(elemento) {
   reglas <- trim(strsplit(as.character(elemento), ";")[[1]])
   tipo_corte <- paste0(getVariable(reglas), getSimbolo(reglas))

   ultimas <- vapply(
     unique(tipo_corte),
     function(corte) max(which(tipo_corte == corte)),
     integer(1),
     USE.NAMES = FALSE
   )

   paste(reglas[ultimas], collapse = "; ")
 } # getReglaFinal

 # Describe every terminal node of a party::ctree() fit by its split path.
 # Numeric, ordered-factor and unordered-factor splits are supported.
 #
 # Argument:
 #   ct  Fitted tree, accessed through where(), nodes(), weights() and
 #       Predict() from the party package.
 #
 # Returns a list with
 #   PreTable  one row per terminal node: Node, Path (full path), Means
 #             (node prediction), Counts (sum of weights) and TruePath
 #             (path with redundant conditions removed), ordered by Means;
 #   Table     the same table with Path replaced by TruePath;
 #   SQL       a " CASE WHEN ... THEN 'Nodo <id>' ... END " expression built
 #             from the simplified paths.
 CtreePathFuncAllCat <- function(ct) {
   node_of_obs <- where(ct)
   terminal_ids <- unique(node_of_obs)
   node_at <- function(id) nodes(ct, id)[[1]]

   path_for <- function(Node) {
     # Inner nodes with an id below the terminal id; seq_len() keeps this
     # empty for a root-only tree.
     NonTerminalNodes <- setdiff(
       seq_len(Node - 1),
       terminal_ids[terminal_ids < Node]
     )
     NodeWeights <- node_at(Node)$weights

     # An inner node is on the path when it shares observations with the
     # terminal node (element 2 of a node is its weight vector).
     on_path <- vapply(
       NonTerminalNodes,
       function(i) any(NodeWeights & node_at(i)[[2]] == 1),
       logical(1)
     )
     Path <- NonTerminalNodes[on_path]

     Path2 <- NULL
     variablesNombres <- character(0)
     variablesPuntos <- list()

     for (i in seq_along(Path)) {
       n <- node_at(Path[i])
       nextNodeID <- if (i == length(Path)) Node else Path[i + 1]

       split <- n[[5]]  # element 5 is the primary split
       vec_puntos <- as.vector(split$splitpoint)
       vec_nombre <- split$variableName
       vec_niveles <- attr(split$splitpoint, "levels")

       # Explicit flag instead of the former `index == 0` sentinel, which
       # misclassified a numeric split point of exactly 0 as nominal.
       is_nominal <- TRUE

       if (length(vec_puntos) != length(vec_niveles) &&
           length(vec_niveles) != 0) {
         # Ordered factor: the split point indexes the cut level; turn it
         # into a logical mask over the levels.
         is_nominal <- FALSE
         index <- vec_puntos
         vec_puntos <- vector(length = length(vec_niveles))
         vec_puntos[index] <- TRUE
       }
       if (length(vec_niveles) == 0) {
         # Numeric split: keep the split point itself.
         is_nominal <- FALSE
         vec_puntos <- split$splitpoint
       }

       if (is_nominal) {
         # Unordered factor: vec_puntos marks the levels of one branch; the
         # right child gets the complement.
         if (nextNodeID == n$right$nodeID) {
           vec_puntos <- !vec_puntos
         } else {
           vec_puntos <- !!vec_puntos
         }
         # Intersect with earlier splits on the same variable so that only
         # the levels still reachable at this depth remain.
         if (i != 1) {
           for (j in seq_len(i - 1)) {
             if (variablesNombres[j] == vec_nombre) {
               vec_puntos <- vec_puntos * variablesPuntos[[j]]
             }
           }
           vec_puntos <- vec_puntos == 1
         }
         SB <- "="
       } else {
         SB <- if (nextNodeID == n$right$nodeID) ">" else "<="
       }

       variablesPuntos[[i]] <- vec_puntos
       variablesNombres[i] <- vec_nombre

       if (length(vec_niveles) == 0) {
         descripcion <- vec_puntos
       } else {
         descripcion <- paste(vec_niveles[vec_puntos], collapse = ", ")
       }

       regla <- paste(c(variablesNombres[i], SB, "{", descripcion, "}"),
                      collapse = " ")
       Path2 <- paste(c(Path2, regla), collapse = "; ")
     }

     if (is.null(Path2)) "" else Path2
   }

   ResulTable <- data.frame(
     Node = as.character(terminal_ids),
     Path = unname(vapply(terminal_ids, path_for, character(1))),
     stringsAsFactors = FALSE
   )

   # Per-node prediction and size, read at the first observation of each
   # node so they line up with terminal_ids. Taking unique() of the
   # predictions instead breaks as soon as two nodes share a prediction
   # (always possible for classification trees).
   first_obs <- !duplicated(node_of_obs)
   obs_counts <- vapply(weights(ct), function(w) sum(w), numeric(1))
   Counts <- as.matrix(unname(obs_counts[first_obs]))
   predicted <- drop(Predict(ct))
   Means <- as.matrix(unname(predicted[first_obs]))

   ResulTable <- data.frame(ResulTable, Means, Counts)
   ResulTable <- ResulTable[order(ResulTable$Means), ]
   ResulTable$TruePath <- vapply(
     as.character(ResulTable$Path),
     getReglaFinal,
     character(1),
     USE.NAMES = FALSE
   )

   # Translate "var op { a, b }; ..." into "var op ('a','b') AND ...";
   # the last step strips the quotes around purely numeric values.
   condicion <- gsub("\\;", " AND ", ResulTable$TruePath)
   condicion <- gsub("\\{ ", "('", condicion)
   condicion <- gsub(" \\}", "')", condicion)
   condicion <- gsub("\\, ", "','", condicion)
   condicion <- gsub("\\'([-+]?([0-9]*\\.[0-9]+|[0-9]+))\\'", "\\1", condicion)

   sql <- paste("WHEN ", condicion, " THEN ")
   sql <- paste0(sql, "'Nodo ", as.character(ResulTable$Node))
   sql <- gsub("THEN'", "THEN '", gsub(" '", "'", paste(sql, "'")))

   ResultadoFinal <- list()
   ResultadoFinal$PreTable <- ResulTable
   ResultadoFinal$Table <- ResulTable
   ResultadoFinal$Table$Path <- ResultadoFinal$Table$TruePath
   ResultadoFinal$Table$TruePath <- NULL
   ResultadoFinal$SQL <- paste(" CASE ",
                               paste(sql, sep = "", collapse = " "),
                               " END ", collapse = "")
   ResultadoFinal
 } # CtreePathFuncAllCat

Here is the test:

 library(party)

 # With ordered factors
 TreeModel1 <- ctree(PB ~ ME + SYMPT + HIST + BSE + DECT, data = mammoexp)
 Result2 <- CtreePathFuncAllCat(TreeModel1)
 Result2
 ##$PreTable
 ##  Node                                                Path    Means Counts
 ##3    7    DECT > { Somewhat likely }; SYMPT > { Disagree } 6.526316    114
 ##2    6   DECT > { Somewhat likely }; SYMPT <= { Disagree } 7.640000    175
 ##1    4 DECT <= { Somewhat likely }; DECT > { Not likely } 8.161905    105
 ##4    3 DECT <= { Somewhat likely }; DECT <= { Not likely } 9.833333     18
 ##                                             TruePath
 ##3    DECT > { Somewhat likely }; SYMPT > { Disagree }
 ##2   DECT > { Somewhat likely }; SYMPT <= { Disagree }
 ##1 DECT <= { Somewhat likely }; DECT > { Not likely }
 ##4                             DECT <= { Not likely }
 ##
 ##$Table
 ##  Node                                                Path    Means Counts
 ##3    7    DECT > { Somewhat likely }; SYMPT > { Disagree } 6.526316    114
 ##2    6   DECT > { Somewhat likely }; SYMPT <= { Disagree } 7.640000    175
 ##1    4 DECT <= { Somewhat likely }; DECT > { Not likely } 8.161905    105
 ##4    3                             DECT <= { Not likely } 9.833333     18
 ##
 ##$SQL
 ##[1] " CASE WHEN DECT > ('Somewhat likely') AND SYMPT > ('Disagree') THEN 'Nodo 7' WHEN DECT > ('Somewhat likely') AND SYMPT <= ('Disagree') THEN 'Nodo 6' WHEN DECT <= ('Somewhat likely') AND DECT > ('Not likely') THEN 'Nodo 4' WHEN DECT <= ('Not likely') THEN 'Nodo 3' END "

 # With unordered factors
 TreeModel2 <- ctree(count ~ spray, data = InsectSprays)
 plot(TreeModel2, type = "simple")
 Result2 <- CtreePathFuncAllCat(TreeModel2)
 Result2
 ##$PreTable
 ##Node                                  Path     Means Counts            TruePath
 ##2    5 spray = { C, D, E }; spray = { C, E }  2.791667     24    spray = { C, E }
 ##3    4    spray = { C, D, E }; spray = { D }  4.916667     12       spray = { D }
 ##1    2                   spray = { A, B, F } 15.500000     36 spray = { A, B, F }
 ##
 ##$Table
 ##Node                Path     Means Counts
 ##2    5    spray = { C, E }  2.791667     24
 ##3    4       spray = { D }  4.916667     12
 ##1    2 spray = { A, B, F } 15.500000     36
 ##
 ##$SQL
 ##[1] " CASE WHEN spray = ('C','E') THEN 'Nodo 5' WHEN spray = ('D') THEN 'Nodo 4' WHEN spray = ('A','B','F') THEN 'Nodo 2' END "

 # With continuous variables
 airq <- subset(airquality, !is.na(Ozone))
 TreeModel3 <- ctree(Ozone ~ ., data = airq,
                     controls = ctree_control(maxsurrogate = 3))
 Result2 <- CtreePathFuncAllCat(TreeModel3)
 Result2
 ##$PreTable
 ##  Node                                            Path    Means Counts
 ##1    5 Temp <= { 82 }; Wind > { 6.9 }; Temp <= { 77 } 18.47917     48
 ##3    6  Temp <= { 82 }; Wind > { 6.9 }; Temp > { 77 } 31.14286     21
 ##4    9                 Temp > { 82 }; Wind > { 10.3 } 48.71429      7
 ##2    3                Temp <= { 82 }; Wind <= { 6.9 } 55.60000     10
 ##5    8                Temp > { 82 }; Wind <= { 10.3 } 81.63333     30
 ##                                       TruePath
 ##1                 Temp <= { 77 }; Wind > { 6.9 }
 ##3 Temp <= { 82 }; Wind > { 6.9 }; Temp > { 77 }
 ##4                Temp > { 82 }; Wind > { 10.3 }
 ##2               Temp <= { 82 }; Wind <= { 6.9 }
 ##5               Temp > { 82 }; Wind <= { 10.3 }
 ##
 ##$Table
 ##  Node                                           Path    Means Counts
 ##1    5                 Temp <= { 77 }; Wind > { 6.9 } 18.47917     48
 ##3    6 Temp <= { 82 }; Wind > { 6.9 }; Temp > { 77 } 31.14286     21
 ##4    9                Temp > { 82 }; Wind > { 10.3 } 48.71429      7
 ##2    3               Temp <= { 82 }; Wind <= { 6.9 } 55.60000     10
 ##5    8               Temp > { 82 }; Wind <= { 10.3 } 81.63333     30
 ##
 ##$SQL
 ##[1] " CASE WHEN Temp <= (77) AND Wind > (6.9) THEN 'Nodo 5' WHEN Temp <= (82) AND Wind > (6.9) AND Temp > (77) THEN 'Nodo 6' WHEN Temp > (82) AND Wind > (10.3) THEN 'Nodo 9' WHEN Temp <= (82) AND Wind <= (6.9) THEN 'Nodo 3' WHEN Temp > (82) AND Wind <= (10.3) THEN 'Nodo 8' END "

Update! Now the function supports a combination of categorical and numeric variables!

+3
source share

Here is the CtreePathFunc function rewritten mostly in the style of the Hadleyverse (which I think is easier to understand). It also handles categorical variables.

 library(magrittr)

 # Split point of a node split in comparable form: the level name(s) selected
 # by the split point when it carries a "levels" attribute, otherwise the
 # split point as a number.
 readSplitter <- function(nodeSplit) {
   splitPoint <- nodeSplit$splitpoint
   if ("levels" %>% is_in(splitPoint %>% attributes %>% names)) {
     splitPoint %>% attr("levels") %>% .[splitPoint]
   } else {
     splitPoint %>% as.numeric
   }
 }

 # Row indices of the observations with non-zero weight in the node that
 # follows path[pathNumber]: the next inner node on the path, or the terminal
 # node itself at the end of the path.
 hasWeigths <- function(ct, path, terminalNode, pathNumber) {
   # Plain if/else: the condition is a scalar, ifelse() is not needed.
   childNode <- if (pathNumber == length(path)) {
     terminalNode
   } else {
     path[pathNumber + 1]
   }
   ct %>%
     nodes(childNode) %>%
     .[[1]] %>%
     use_series("weights") %>%
     as.logical %>%
     which
 }

 # Condition imposed by the split of inner node path[pathNumber] on the
 # observations that reach the terminal node.
 dataFilter <- function(ct, dts, path, terminalNode, pathNumber) {
   whichWeights <- hasWeigths(ct, path, terminalNode, pathNumber)
   nodes(ct, path[pathNumber])[[1]][[5]] %>%
     buildDataFilter(dts, whichWeights)
 }

 # S3 generic dispatching on the class of the node split.
 buildDataFilter <- function(nodeSplit, ...) UseMethod("buildDataFilter")

 # Nominal split: report the levels actually observed in the child node.
 buildDataFilter.nominalSplit <- function(nodeSplit, dts, whichWeights) {
   varName <- nodeSplit$variableName
   includedLevels <- dts[whichWeights, varName] %>% unique
   paste(
     varName,
     "==",
     includedLevels %>% paste(collapse = ", ") %>% paste0("{", ., "}")
   )
 }

 # Ordered split: "<=" when every observation in the child node lies at or
 # below the split point, ">" otherwise.
 buildDataFilter.orderedSplit <- function(nodeSplit, dts, whichWeights) {
   varName <- nodeSplit$variableName
   splitter <- nodeSplit %>% readSplitter
   # Missing values are ignored; without na.rm the comparison yields NA and
   # the operator would be printed as "NA".
   goesLeft <- dts[whichWeights, varName] %>%
     is_weakly_less_than(splitter) %>%
     all(na.rm = TRUE)
   paste(varName, if (isTRUE(goesLeft)) "<=" else ">", splitter)
 }

 # For every terminal node of the tree `ct` (fitted on data frame `dts`),
 # build the conjunction of split conditions leading to it. Returns a data
 # frame with columns Node and Path, the conditions joined by " & ". A tree
 # without any split yields an empty Path.
 readTerminalNodePaths <- function(ct, dts) {
   nodeWeights <- function(Node) nodes(ct, Node)[[1]]$weights
   sgmnts <- ct %>% where %>% unique
   nodesFirstTreeWeightIsOne <- function(node) nodes(ct, node)[[1]][2][[1]] == 1

   # Take the inner nodes smaller than the selected terminal node.
   # seq_len() keeps this empty for Node == 1, where 1:(Node - 1) would
   # yield c(1, 0).
   innerNodes <- function(Node) {
     setdiff(seq_len(Node - 1), sgmnts[sgmnts < Node])
   }

   # Inner nodes sharing at least one observation with the terminal node.
   pathForTerminalNode <- function(terminalNode) {
     candidates <- innerNodes(terminalNode)
     onPath <- vapply(
       candidates,
       function(innerNode) {
         any(nodeWeights(terminalNode) & nodesFirstTreeWeightIsOne(innerNode))
       },
       logical(1)
     )
     candidates[onPath]
   }

   # Find the split criteria
   sgmnts %>%
     lapply(function(terminalNode) {
       path <- terminalNode %>% pathForTerminalNode
       # seq_along() instead of seq(length(path)), which counts 1, 0 for an
       # empty path.
       path %>%
         seq_along %>%
         lapply(function(nodeNumber) {
           dataFilter(ct, dts, path, terminalNode, nodeNumber)
         }) %>%
         unlist %>%
         paste(collapse = " & ") %>%
         data.frame(Node = terminalNode, Path = .)
     }) %>%
     do.call(what = rbind)
 }

Testing

 library(party)     # ctree(); the path helpers also rely on nodes()/where()
 library(magrittr)

 # Divide the first `proportion` of the elements of `vctr` by `divideBy` and
 # return the modified vector. Used to make the simulated predictors
 # informative about the response.
 shiftFirstPart <- function(vctr, divideBy, proportion = .5) {
   # seq_len() so that a rounded count of 0 selects nothing instead of
   # counting 1, 0.
   firstPart <- vctr %>%
     length %>%
     multiply_by(proportion) %>%
     round %>%
     seq_len
   vctr[firstPart] %<>% divide_by(divideBy)
   vctr
 }

 set.seed(11)
 n <- 13000

 # cut() spells the argument `include.lowest`; the former `include_lowest`
 # was silently swallowed by `...` and had no effect.
 gdt <- data.frame(
   is_buyer = runif(n) %>%
     shiftFirstPart(1.5) %>%
     round %>%
     factor(labels = c("no", "yes")),
   age = runif(n) %>%
     shiftFirstPart(1.5) %>%
     cut(breaks = c(0, .3, .6, 1), include.lowest = TRUE,
         ordered_result = TRUE, labels = c("low", "mid", "high")),
   city = runif(n) %>%
     shiftFirstPart(1.5) %>%
     cut(breaks = c(0, .3, .6, 1), include.lowest = TRUE,
         labels = c("Chigaco", "Boston", "Memphis")),
   point = runif(n) %>% shiftFirstPart(1.2)
 )

 gct <- ctree(is_buyer ~ ., data = gdt)
 readTerminalNodePaths(gct, gdt)
0
source share

All Articles