Why is the runtime to construct a decision tree O(mn log n)?

iltp38 · Dec 10, 2015 · Viewed 9.8k times

When m is the number of features and n is the number of samples, the scikit-learn documentation (http://scikit-learn.org/stable/modules/tree.html) states that the runtime to construct a binary decision tree is O(mn log n).

I understand that the log(n) comes from the average height of the tree after splitting. I understand that at each split you have to look at each of the m features and choose the best one to split on, and that this is done by calculating a "best metric" (in my case, Gini impurity) over the n samples at that node. However, to find the best split, doesn't that mean you would have to look at every possible way to partition the samples for each feature? And wouldn't that be something like 2^(n-1) * m rather than just mn? Am I thinking about this wrong? Any advice would help. Thank you.
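
(A quick check of that count, added for illustration: n samples can be split into two nonempty groups in 2^(n-1) - 1 ways, which grows very fast.)

```python
# Quick sanity check on the exponential count in the question:
# n samples can be split into two nonempty groups in 2**(n-1) - 1 ways.
for n in (4, 8, 16, 32):
    print(n, 2 ** (n - 1) - 1)
```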

Answer

templatetypedef · Dec 10, 2015

One way to build a decision tree would be, at each node, to do something like this:

  • For each possible feature to split on:
    • Find the best possible split for that feature.
    • Determine the "goodness" of this split.
  • Of all the options tried above, take the best and use it for the split.
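
A minimal sketch of that loop, assuming continuous features and Gini impurity; the function names and structure here are made up for illustration and are not scikit-learn's actual code:

```python
import numpy as np

def gini(y):
    """Gini impurity of a label vector."""
    _, counts = np.unique(y, return_counts=True)
    p = counts / counts.sum()
    return 1.0 - np.sum(p ** 2)

def best_split_naive(X, y):
    """For each feature, try every threshold between consecutive distinct
    values and keep the (feature, threshold) pair with the lowest
    weighted impurity."""
    n, m = X.shape
    best = (None, None, float("inf"))             # (feature, threshold, score)
    for j in range(m):                            # loop over all m features
        values = np.unique(X[:, j])               # sorted distinct values
        for t in (values[:-1] + values[1:]) / 2:  # midpoints as candidate cuts
            left = X[:, j] <= t
            score = (left.sum() * gini(y[left]) +
                     (~left).sum() * gini(y[~left])) / n
            if score < best[2]:
                best = (j, t, score)
    return best
```

Note that the inner loop recomputes the impurity from scratch for every candidate cut, so a node costs roughly O(mn^2) this way in the worst case; the technique described next is what brings the per-feature cost down.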

The question is how to perform each step. If you have continuous data, a common technique for finding the best possible split is to sort the data into ascending order along that feature, consider all possible partition points between consecutive data points, and take the one that minimizes the entropy. This sorting step takes time O(n log n), which dominates the runtime. Since we're doing that for each of the O(m) features, the runtime ends up working out to O(mn log n) total work done per node.
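
Here is a sketch of that sort-and-scan idea (illustrative only, not scikit-learn's implementation). It uses entropy, as in the answer, with running class counts, so after the O(n log n) sort all n - 1 candidate cuts are evaluated in a single linear sweep:

```python
import numpy as np
from collections import Counter

def entropy(counts, total):
    """Shannon entropy from a mapping of class -> count."""
    p = np.array([c / total for c in counts.values() if c > 0])
    return float(-np.sum(p * np.log2(p)))

def best_split_sorted(x, y):
    """Best threshold for one continuous feature x against labels y.

    Sort once (O(n log n)), then sweep left to right, moving one sample
    at a time from the right partition to the left and updating the
    class counts incrementally, so each of the n - 1 candidate cuts
    costs constant-ish work instead of a full O(n) recount."""
    order = np.argsort(x)                  # the O(n log n) step
    x, y = x[order], y[order]
    n = len(y)
    left, right = Counter(), Counter(y.tolist())
    best_t, best_score = None, float("inf")
    for i in range(n - 1):                 # n - 1 possible cut points
        left[y[i]] += 1
        right[y[i]] -= 1
        if x[i] == x[i + 1]:               # no valid cut between equal values
            continue
        score = ((i + 1) * entropy(left, i + 1) +
                 (n - 1 - i) * entropy(right, n - 1 - i)) / n
        if score < best_score:
            best_t, best_score = (x[i] + x[i + 1]) / 2, score
    return best_t, best_score
```

Calling something like best_split_sorted once per feature reproduces the bound above: the O(n log n) sort dominates the O(n) sweep, and repeating it for all m features gives O(mn log n) per node.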