ctree() - How to get the list of splitting conditions for each terminal node?

Sriram Murali · Jan 29, 2014

I have output from ctree() (party package) that looks like the following. How do I get the list of splitting conditions for each terminal node, such as sns <= 0, dta <= 1; sns <= 0, dta > 1; and so on?

1) sns <= 0; criterion = 1, statistic = 14655.021
  2) dta <= 1; criterion = 1, statistic = 3286.389
   3)*  weights = 153682 
  2) dta > 1
   4)*  weights = 289415 
1) sns > 0
  5) dta <= 2; criterion = 1, statistic = 1882.439
   6)*  weights = 245457 
  5) dta > 2
   7) dta <= 6; criterion = 1, statistic = 1170.813
     8)*  weights = 328582 
   7) dta > 6

Thanks

Answer

David Arenburg · Mar 10, 2014

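This function should do the job. For each terminal node, it finds the non-terminal nodes on the path from the root and builds the matching split conditions: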

 CtreePathFunc <- function (ct, data) {

  ResulTable <- data.frame(Node = character(), Path = character())

  for(Node in unique(where(ct))){
  # Candidate non-terminal nodes: every node ID below the selected terminal node that is not itself terminal
  NonTerminalNodes <- setdiff(1:(Node - 1), unique(where(ct))[unique(where(ct)) < Node])


  # Getting the weights for that node
  NodeWeights <- nodes(ct, Node)[[1]]$weights


  # Finding the path
  Path <- NULL
  for (i in NonTerminalNodes){
    if(any(NodeWeights & nodes(ct, i)[[1]][2][[1]] == 1)) Path <- append(Path, i)
  }

  # Finding the splitting criteria for that path
  Path2 <- SB <- NULL

  for(i in seq_along(Path)){
    if(i == length(Path)) {
      n <- nodes(ct, Node)[[1]]
    } else {n <- nodes(ct, Path[i + 1])[[1]]}

    if(all(data[which(as.logical(n$weights)), as.character(unlist(nodes(ct,Path[i])[[1]][[5]])[length(unlist(nodes(ct,Path[i])[[1]][[5]]))])] <= as.numeric(unlist(nodes(ct,Path[i])[[1]][[5]])[3]))){
      SB <- "<="
    } else {SB <- ">"}
    Path2 <- paste(c(Path2, paste(as.character(unlist(nodes(ct,Path[i])[[1]][[5]])[length(unlist(nodes(ct,Path[i])[[1]][[5]]))]),
                                 SB,
                                 as.character(unlist(nodes(ct,Path[i])[[1]][[5]])[3]))),
                   collapse = ", ")
  }

  # Output
  ResulTable <- rbind(ResulTable, cbind(Node = Node, Path = Path2))
  }
  return(ResulTable)
}
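
The densest step in the function is the one that chooses between "<=" and ">". The idea can be shown in isolation: take the observations that fall into the next node on the path, and check whether all of them lie at or below the split point. The sketch below uses made-up data; obs, child_weights and cut_point are hypothetical stand-ins, not objects produced by party.

```r
# Hypothetical stand-ins (not real party objects): five observations and
# 0/1 weights marking which rows belong to the next node on the path.
obs <- data.frame(Temp = c(70, 75, 80, 85, 90))
child_weights <- c(1, 1, 1, 0, 0)  # the child node holds the first three rows
cut_point <- 82                    # split point of the parent node

# Rows that belong to the child node
in_child <- as.logical(child_weights)

# If every row in the child is at or below the split point, the child sits
# on the "<=" branch; otherwise it sits on the ">" branch.
direction <- if (all(obs$Temp[in_child] <= cut_point)) "<=" else ">"

condition <- paste("Temp", direction, cut_point)
condition
```

In the function itself, the weights and the split point come from nodes(ct, ...) rather than from hand-written vectors, but the comparison is the same.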

Testing

library(party)
airq <- subset(airquality, !is.na(Ozone))
ct <- ctree(Ozone ~ ., data = airq,  controls = ctree_control(maxsurrogate = 3))
Result <- CtreePathFunc(ct, airq)
Result 

##   Node                               Path
## 1    5 Temp <= 82, Wind > 6.9, Temp <= 77
## 2    3            Temp <= 82, Wind <= 6.9
## 3    6  Temp <= 82, Wind > 6.9, Temp > 77
## 4    9             Temp > 82, Wind > 10.3
## 5    8            Temp > 82, Wind <= 10.3
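
To make the structure behind this table easier to see, here is a minimal base R sketch that walks a hand-built toy tree and collects the conditions along each root-to-leaf path. The tree is typed in by hand to mirror the paths printed above; it is not a real ctree object, and leaf, split_node and collect_paths are hypothetical helper names, not part of party.

```r
# Toy tree, typed in by hand to mirror the paths printed above.
# A leaf holds only its node id; an inner node also holds the split
# variable, the split point, and its left (<=) and right (>) children.
leaf <- function(id) list(id = id)
split_node <- function(id, var, cut, left, right) {
  list(id = id, var = var, cut = cut, left = left, right = right)
}

toy <- split_node(1, "Temp", 82,
  left = split_node(2, "Wind", 6.9,
    left  = leaf(3),
    right = split_node(4, "Temp", 77, left = leaf(5), right = leaf(6))),
  right = split_node(7, "Wind", 10.3, left = leaf(8), right = leaf(9)))

# Depth-first walk: carry the conditions collected so far and emit one
# row per terminal node.
collect_paths <- function(node, conds = character()) {
  if (is.null(node$var)) {
    return(data.frame(Node = node$id,
                      Path = paste(conds, collapse = ", "),
                      stringsAsFactors = FALSE))
  }
  rbind(
    collect_paths(node$left,  c(conds, paste(node$var, "<=", node$cut))),
    collect_paths(node$right, c(conds, paste(node$var, ">",  node$cut)))
  )
}

toy_paths <- collect_paths(toy)
toy_paths
```

The rows come out in depth-first order (nodes 3, 5, 6, 8, 9) rather than in the order returned by where(ct), but each path string matches the corresponding row of the table above.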