How does the Decision Tree Algorithm work?
Decision Tree ... my favorite ML algorithm ... the "Hello World" of ML algorithms ... simple yet powerful!
In my career as a Lead Data Scientist, I have noticed that aspiring data scientists, or those who have just started their career in the data science domain, are often asked to explain the ML algorithm they like most. Many of them choose the decision tree. Although they have read about the algorithm and used it extensively in several projects, without clear steps in mind and without preparation for follow-up questions they struggle to explain it, and their confidence gets shattered. So, to help them understand and explain it in an easy manner, I am writing this article.
If you can't explain it simply, you don't understand it well enough! - Albert Einstein
Let me explain it in simple steps. First, I will explain what the decision tree algorithm is and how it works (sufficient to answer questions like "How does the algorithm work?"). Then I will provide material to create a simple decision tree in Excel, R, and Python (for better understanding) on a toy data set that I have created. This will help you understand the working of the decision tree algorithm. Finally, I will provide some follow-up questions (food for the brain and for further exploration of the algorithm) which might be asked to check your understanding of the algorithm.
Table of Contents
What is a decision tree algorithm and how does it work?
Decision Tree in Excel, R, and Python
2.1 Decision Tree in Excel
2.2 Decision Tree in R
2.3 Decision Tree in Python
Follow-up Questions
Data & Code on GitHub
Conclusion
1. What is a decision tree algorithm and how does it work?
A decision tree is all about creating a tree from the given labeled data.
Here, the idea is to represent the labeled data as a tree, where each internal node (decision node) represents a condition on the attributes of the data given to that node, each branch represents an outcome of that condition, and each leaf node represents a class label.
To build the tree, we need to decide which question to ask and when to ask it.
To find these questions (conditions) and their order, we provide all the labeled data to the root node, iterate through all the attributes of the features in the data set, and calculate the impurity/information gain for each attribute.
To calculate impurity, we can use entropy (a measure of uncertainty) or the Gini index (another measure of impurity); the attribute whose split reduces impurity the most gives the highest information gain.
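To make these two measures concrete, here is a minimal Python sketch (standard library only, class labels assumed to be strings such as "Yes"/"No") that computes the entropy and the Gini index of a set of labels:
from collections import Counter
from math import log2

def entropy(labels):
    # Entropy = sum over classes of p * log2(1/p); 0 means the node is pure.
    n = len(labels)
    return sum((c / n) * log2(n / c) for c in Counter(labels).values())

def gini(labels):
    # Gini index = 1 - sum over classes of p^2; 0 means the node is pure.
    n = len(labels)
    return 1 - sum((c / n) ** 2 for c in Counter(labels).values())

print(entropy(["Yes", "Yes", "Yes", "Yes"]))   # 0.0 -> pure node, no uncertainty
print(entropy(["Yes", "No", "Yes", "No"]))     # 1.0 -> maximum uncertainty
print(gini(["Yes", "No", "Yes", "No"]))        # 0.5 -> maximum Gini impurity for 2 classes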
After splitting the root node based on the chosen condition, the same process is repeated recursively on the newly created nodes with the remaining rows, until no attribute or no instance remains.
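Here is a rough Python sketch, under the same assumptions as above, of that greedy recursive procedure: score every attribute by the weighted impurity its split leaves behind, split on the best one, and recurse until a node is pure or no attribute remains. It is an illustration of the idea, not production code.
from collections import Counter
from math import log2

def entropy(labels):
    n = len(labels)
    return sum((c / n) * log2(n / c) for c in Counter(labels).values())

def split_impurity(rows, labels, attribute):
    # Weighted average entropy of the child nodes created by splitting on `attribute`.
    n = len(rows)
    total = 0.0
    for value in set(row[attribute] for row in rows):
        child = [lab for row, lab in zip(rows, labels) if row[attribute] == value]
        total += (len(child) / n) * entropy(child)
    return total

def build_tree(rows, labels, attributes):
    # Stop when the node is pure or no attribute is left; the leaf predicts the majority class.
    if entropy(labels) == 0 or not attributes:
        return Counter(labels).most_common(1)[0][0]
    # Greedy choice: lowest remaining impurity == highest information gain.
    best = min(attributes, key=lambda a: split_impurity(rows, labels, a))
    tree = {}
    for value in set(row[best] for row in rows):
        idx = [i for i, row in enumerate(rows) if row[best] == value]
        tree[(best, value)] = build_tree([rows[i] for i in idx],
                                         [labels[i] for i in idx],
                                         [a for a in attributes if a != best])
    return tree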
2. Decision Tree in Excel, R, and Python
As I have already mentioned, if you want to understand a concept in depth, a better way is to build it from scratch and see how it works. I have created a decision tree in Excel from scratch. Let's see how the steps mentioned above help us create the decision tree in Excel. I am also giving code written in R and Python for better understanding.
2.1 Decision Tree in Excel
For this experiment, I have created the following toy data, with two features, Age and Income, and the target variable Loan Approved.
Now let's start applying the steps mentioned earlier: pass all the data to the root node, calculate the impurity for each attribute, and pick the best split.
After the first pass, as shown in the figure below, you can see that the total impurity for the Income feature is the smallest, which means its information gain is the highest. Hence, we can split the tree based on its attributes.
As we can see, for the Income feature the instances with the attribute values medium or low are already classified, so we can use the condition/question Income == medium or low.
After the first split, let's apply the recursion step mentioned earlier.
So, we pass the remaining rows of data to the two nodes created by the condition Income == medium or low. The first branch, where the condition is true, is already classified, as its entropy is 0. To the second branch (node), we pass the remaining rows for further classification.
After the second split, we get the condition Age == Senior. The first branch of that node, where the condition is true, is already classified with a single negative instance (entropy 0), and the second branch is also classified, with two positive instances.
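Putting the two splits together, the finished tree can be read as a short set of if/else rules. The sketch below simply transcribes those rules in Python (the exact spelling of the attribute values, e.g. lower-case "medium", is an assumption based on the description above):
def predict_loan_approved(age, income):
    # The tree learned above, written out as plain rules.
    if income in ("medium", "low"):   # first split: Income == medium or low
        return "Yes"
    if age == "Senior":               # second split, on the high-income branch
        return "No"
    return "Yes"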
Please remember: for a leaf node the entropy is 0 and no splitting happens there; only decision nodes, where entropy > 0, undergo the splitting process.
The following figure helps us visualize the whole classification process -
Please refer to the legend on the right side of the image for better understanding.
2.2 Decision Tree in R
Following is the sample code and output for the toy data in R.
Code
# ---------------------------------------
# 1. INSTALLATION
# ---------------------------------------
# install.packages("rpart")
# install.packages("rpart.plot")
# ---------------------------------------
# 2. IMPORT
# ---------------------------------------
library("rpart")
library("rpart.plot")
# ---------------------------------------
# 3. GET DATA
# ---------------------------------------
compLA <- read.csv("DT Loan_approval.csv", header = TRUE)
# ---------------------------------------
# 4. FIT MODEL
# ---------------------------------------
fit <- rpart(formula = Loan.Approved ~ Age + Income,
             method = "class",                       # classification tree
             data = compLA,
             control = rpart.control(minsplit = 1,   # min. observations needed to attempt a split
                                     minbucket = 1,  # min. observations in a terminal (leaf) node
                                     cp = 0.01))     # complexity parameter
summary(fit)
# ---------------------------------------
# 5. VISUALIZATION
# ---------------------------------------
rpart.plot(fit)
Output
2.3 Decision Tree in Python
Following is the sample code and output for the toy data in Python -
Code
# ---------------------------------------
# 1. INSTALLATION
# ---------------------------------------
# ! conda install pandas
# ! conda install scikit-learn
# ! conda install matplotlib
# ---------------------------------------
# 2. IMPORT
# ---------------------------------------
import pandas as pd
from sklearn import tree
from sklearn.preprocessing import LabelEncoder
import matplotlib.pyplot as plt
# ---------------------------------------
# 3. GET DATA
# ---------------------------------------
data = pd.read_csv("DT Loan_approval.csv")
data = data.astype(str)
le_Age = LabelEncoder()
le_Income = LabelEncoder()
le_Loan_Approved = LabelEncoder()
data['en_Age'] = le_Age.fit_transform(data['Age'])
data['en_Income'] = le_Income.fit_transform(data['Income'])
data['en_Loan_Approved']= le_Loan_Approved.fit_transform(data['Loan Approved'])
# ---------------------------------------
# 4. FIT MODEL
# ---------------------------------------
clf = tree.DecisionTreeClassifier(random_state=0)
clf = clf.fit(data[['en_Age', 'en_Income']], data['en_Loan_Approved'])
# ---------------------------------------
# 5. VISUALIZATION
# ---------------------------------------
feature_list = ['Age', 'Income']
# LabelEncoder assigns codes alphabetically, so 'No' -> 0 and 'Yes' -> 1;
# class_names must follow that order.
target_list = ['No', 'Yes']
fig, axes = plt.subplots(nrows=1, ncols=1, figsize=(5, 4), dpi=200)
tree.plot_tree(clf,
               feature_names=feature_list,
               class_names=target_list,
               fontsize=6,
               filled=True)
plt.show()
Output
3. Follow-up Questions
Q1. What type of decision trees are available?
There are 2 types of decision trees available -
Categorical Variable Decision Tree - when the target variable is categorical
Continuous Variable Decision Tree - when the target variable is continuous
Q2. What will be the impact of outliers on the decision tree?
Outliers have little or no impact on a decision tree classifier. After going through the section above, you can see why: such instances simply end up being handled by separate branches of the tree.
Q3. How/when do we get to know that the decision tree is over-fitting?
As the depth of the tree grows, it starts over-fitting the given data. If we don't control the tree with parameters like max_depth, min_samples_split, max_leaf_nodes, etc., it will give you 100% accuracy on the training data, because in the worst case it generates one leaf node per observation.
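A quick way to see this is to fit an unconstrained tree and score it on the very data it was trained on. The sketch below uses scikit-learn's built-in iris data rather than the toy set above:
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)

# No max_depth, min_samples_split, or max_leaf_nodes set: the tree keeps
# splitting until every leaf is pure, so it memorizes the training data.
clf = DecisionTreeClassifier(random_state=0).fit(X, y)
print(clf.get_depth(), clf.score(X, y))   # expect a deep tree and a training accuracy of 1.0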
Q4. What steps one can take to avoid over-fitting?
Set conditions on the parameters to reduce the tree size (see the sketch after this list). Please note that the parameters below are from the scikit-learn tree package.
max_depth: to control the depth of the tree.
min_samples_split: minimum number of samples required to split an internal node or decision node.
max_leaf_nodes: to grow a tree with max_leaf_nodes in best-first fashion.
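A minimal sketch of the constrained version (the values 3, 4, and 8 are illustrative only; in practice they are tuned, for example with cross-validation):
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)

clf = DecisionTreeClassifier(max_depth=3,          # limit how deep the tree can grow
                             min_samples_split=4,  # need at least 4 samples to split a node
                             max_leaf_nodes=8,     # grow at most 8 leaves, best-first
                             random_state=0).fit(X, y)
print(clf.get_depth(), clf.get_n_leaves())         # depth <= 3, leaves <= 8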
Tree Pruning - Pruning reduces the size of decision trees by removing parts of the tree that do not provide power to classify instances. Decision trees are among the machine learning algorithms most susceptible to over-fitting, and effective pruning can reduce this likelihood. In R, for tree pruning, we can use the prune function from the rpart library.
Description: Determines a nested sequence of subtrees of the supplied rpart object by recursively snipping off the least important splits, based on the complexity parameter (cp).
Usage: prune(tree, cp, ...)
Arguments:
tree: fitted model object of class "rpart". This is assumed to be the result of some function that produces an object with the same-named components as that returned by the rpart function.
cp: Complexity parameter to which the rpart object will be trimmed.
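The usage above follows the rpart documentation. If you are working in Python instead, scikit-learn offers an analogous mechanism, cost-complexity pruning via the ccp_alpha parameter; here is a minimal sketch on the built-in iris data:
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)

# cost_complexity_pruning_path returns the effective alphas at which subtrees
# get pruned away; a larger ccp_alpha gives a smaller (more pruned) tree.
path = DecisionTreeClassifier(random_state=0).cost_complexity_pruning_path(X, y)
for alpha in path.ccp_alphas:
    pruned = DecisionTreeClassifier(random_state=0, ccp_alpha=alpha).fit(X, y)
    print(f"ccp_alpha={alpha:.4f}  leaves={pruned.get_n_leaves()}")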
Q5. What are other decision tree based algorithms?
Random Forest - Bagging (ensemble technique). Here we train homogeneous weak learners independently and in parallel, and average their results.
XGBoost - Boosting (ensemble technique). Here we train homogeneous weak learners sequentially, in a very adaptive way (each base model depends on the previous ones), and combine them.
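To make the contrast concrete, here is a minimal scikit-learn sketch of both ideas, with RandomForestClassifier for bagging and GradientBoostingClassifier standing in for the boosting family that XGBoost belongs to (XGBoost itself is a separate library), again on the built-in iris data:
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier
from sklearn.model_selection import cross_val_score

X, y = load_iris(return_X_y=True)

# Bagging: many trees trained independently on bootstrap samples, predictions averaged.
rf = RandomForestClassifier(n_estimators=100, random_state=0)

# Boosting: trees trained sequentially, each new tree correcting its predecessors.
gb = GradientBoostingClassifier(n_estimators=100, random_state=0)

print(cross_val_score(rf, X, y, cv=5).mean())
print(cross_val_score(gb, X, y, cv=5).mean())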
Q6. What are the advantages and disadvantages of a decision tree algorithm?
Advantages
Applicable to both classification and regression - decision trees can be used for classification or regression, and the data type is not an issue.
Explainability - Easy to understand and easy to explain with simple rules.
No assumptions about the distribution of the data or the relationship between features - for example, linear regression assumes a linear relationship between the dependent and independent variables, but that is not the case with a decision tree. This is why it is called a non-parametric method.
Non-linearity - it can handle non-linear relationships between the features and the target variable.
Disadvantages
Over-fitting - as discussed earlier.
Not the best choice when working with continuous variables.
4. Data & Code on GitHub
Please refer to the following Github repository for data and code files -
5. Conclusion
The decision tree algorithm is simple and powerful. As a data scientist one should know how it works and should be able to explain it with simple steps. More advanced ensemble models are also based on it, so it has a lot of value in the data science domain.
I'd recommend modifying the code given here so the decision tree works with your own data set, and having fun building a simple, explainable classifier for use in your projects.
If you like this article and wanted to connect then the following are ways to connect with me -
👉 LinkedIn
Thanks for reading the article! Please let me know your queries and suggestion through comments!
Happy Learning...! 😃