Machine Learning Basics: Decision Trees
A Decision Tree is a type of supervised machine learning model, and it can be used for both classification and regression problems. We will concentrate on classification in this blog. In classification the output variable is generally categorical, but in certain cases it could be numerical, like a person's age or income. We can still convert such numerical values into categories like <25, 25–50 and >50 and then work on it as a classification problem.
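For instance, here is a minimal pandas sketch of that bucketing (the ages here are made up for illustration):

```python
import pandas as pd

# Hypothetical ages, bucketed into the <25, 25-50 and >50 categories above.
ages = pd.Series([19, 34, 52, 47, 71, 25])
age_group = pd.cut(ages, bins=[0, 25, 50, float("inf")],
                   right=False, labels=["<25", "25-50", ">50"])
print(age_group.tolist())  # ['<25', '25-50', '>50', '25-50', '>50', '25-50']
```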
What is a decision tree?
When we talk about decision trees, we can visualize a tree with many branches, where each branch carries a condition (an if/else check). In the figure below, we want to predict whether it will rain based on the weather. The top node is called the root node. We have 3 types of weather: Sunny, Cloudy and Rainy, and each becomes a branch of the tree.
- If it is Cloudy, then it will rain.
- If it is Sunny and the humidity is High, there is no rain; otherwise there is rain.
- If it is Rainy and the wind is Weak, there will be rain; otherwise there is no rain (the same logic is sketched as code below).
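This tree is nothing more than nested conditionals. A minimal Python sketch of the same logic (the function name and string values are made up for illustration):

```python
def will_it_rain(weather, humidity=None, wind=None):
    """Toy weather tree: root splits on weather, then on
    humidity (Sunny branch) or wind (Rainy branch)."""
    if weather == "Cloudy":
        return True                    # Cloudy -> rain
    if weather == "Sunny":
        return humidity != "High"      # high humidity -> no rain
    if weather == "Rainy":
        return wind == "Weak"          # weak wind -> rain
    raise ValueError(f"unknown weather: {weather}")

print(will_it_rain("Sunny", humidity="High"))  # False (no rain)
print(will_it_rain("Rainy", wind="Weak"))      # True (rain)
```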
When we look at the way a decision tree classifier works, it closely mirrors how our brain makes decisions based on certain conditions. Let us consider another example. The table below shows 1000 people, broken down by gender and by age group (<50 and >50), and whether they play outdoor games. We need to predict whether a person will play. When we build the decision tree, which feature should be the root node: Age or Gender?
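For reference, here are the counts implied by the probabilities and calculations that follow (1000 people in total, 310 of whom play):

| Group | Play | Not Play | Total |
| --- | --- | --- | --- |
| Male | 300 | 200 | 500 |
| Female | 10 | 490 | 500 |
| Age < 50 | 260 | 440 | 700 |
| Age > 50 | 50 | 250 | 300 |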
How to decide which will be the root node?
We have two common criteria for deciding this:
- Gini index
- Entropy / Information Gain
Let's first compute the probabilities (NP = Not Play):
P(Male) = 500/1000 = 0.5
P(Female) = 500/1000 = 0.5
P(Age<50) = 700/1000 = 0.7
P(Age>50) = 300/1000 = 0.3
P(Play | Age<50) = 260/700
P(NP | Age<50) = 440/700
P(Play | Age>50) = 50/300
P(NP | Age>50) = 250/300
What is Gini Index?
Gini = ∑Pᵢ² where i = 1 to k (the number of classes). Strictly speaking, the Gini impurity is 1 − ∑Pᵢ²; this blog works with the ∑Pᵢ² score directly, so a higher value means purer child nodes and a better split.
Let us first split by Gender and calculate the Gini score. The first bracket below is the Female node (10 play, 490 do not) and the second is the Male node (300 play, 200 do not):
Gini(Gender) = (500/1000)·((10/500)² + (490/500)²) + (500/1000)·((300/500)² + (200/500)²) = 0.74
Now, let us split by Age and calculate the Gini score:
Gini(Age) = (700/1000)·((260/700)² + (440/700)²) + (300/1000)·((50/300)² + (250/300)²) = 0.59
Since 0.74 > 0.59, the Gender split produces purer children, so we should split on Gender rather than on Age.
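These numbers are easy to verify in plain Python, using the (play, not-play) counts from the table above:

```python
def gini_score(groups):
    """Weighted sum of p_i^2 over child nodes, as defined above
    (higher = purer). Each group is one child's class counts."""
    total = sum(sum(g) for g in groups)
    score = 0.0
    for g in groups:
        n = sum(g)
        score += (n / total) * sum((c / n) ** 2 for c in g)
    return score

# (play, not-play) counts per child node
print(gini_score([(10, 490), (300, 200)]))   # Gender -> ~0.74
print(gini_score([(260, 440), (50, 250)]))   # Age    -> ~0.59
```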
What is Entropy?
Entropy(E) = −∑Pᵢ·log₂Pᵢ where i = 1 to k (the number of classes)
Entropy is maximum (1, for a binary split) when the classes are evenly mixed (Pᵢ = 0.5), and it is 0 when a node is pure (some Pᵢ = 1).
Information Gain (IG) = Entropy(Parent) − Weighted Average Entropy(Children). We choose the split with the highest information gain, i.e. the one whose children have the lowest weighted entropy.
Entropy (Play-Parent) =
−(310/1000)·log₂(310/1000) − (690/1000)·log₂(690/1000) ≈ 0.893
Entropy (Gender-children) =
(500/1000)·(−(300/500)·log₂(300/500) − (200/500)·log₂(200/500)) + (500/1000)·(−(10/500)·log₂(10/500) − (490/500)·log₂(490/500)) = 0.5·0.971 + 0.5·0.141 ≈ 0.556
Entropy (Age-children) =
(700/1000)·(−(260/700)·log₂(260/700) − (440/700)·log₂(440/700)) + (300/1000)·(−(50/300)·log₂(50/300) − (250/300)·log₂(250/300)) = 0.7·0.952 + 0.3·0.650 ≈ 0.861
IG(Gender) = 0.893 − 0.556 = 0.337
IG(Age) = 0.893 − 0.861 = 0.032
Since splitting on Gender leaves the lower child entropy, and hence the higher information gain, we will split by Gender, agreeing with the Gini result.
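And the same check for entropy and information gain, in plain Python with the same counts:

```python
from math import log2

def entropy(counts):
    """H = -sum(p_i * log2(p_i)) over one node's class counts."""
    n = sum(counts)
    return -sum((c / n) * log2(c / n) for c in counts if c)

def info_gain(parent, children):
    """Entropy(parent) minus the size-weighted entropy of the children."""
    n = sum(parent)
    weighted = sum(sum(ch) / n * entropy(ch) for ch in children)
    return entropy(parent) - weighted

parent = (310, 690)  # (play, not-play) over all 1000 people
print(info_gain(parent, [(300, 200), (10, 490)]))   # Gender -> ~0.337
print(info_gain(parent, [(260, 440), (50, 250)]))   # Age    -> ~0.032
```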
Some Advantages:
- Easily interpretable: we can trace the decision path from the root node to any leaf
- Handles mixed data: works with categorical, numerical, boolean and other feature types
- No normalization required: splits are threshold conditions, so feature scale does not matter
- Feature importance: can be measured by how much each feature reduces the Gini score or entropy
Some Disadvantages:
- Overfitting: as the tree grows deeper, it fits the training data too closely and generalizes poorly
- Instability: a small change in the data can change the tree structure drastically
How to overcome these disadvantages?
We can overcome these disadvantages using the following two techniques:
- Truncation: stop the tree before it is completely grown by specifying stopping criteria, e.g. scikit-learn's max_depth, min_samples_split, min_samples_leaf and max_leaf_nodes.
- Pruning: once the tree is fully grown, we cut off branches that add little value. It is a bottom-up approach: we compare accuracy on a validation set before and after pruning and keep the best-performing tree (see the sketch below).
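Both techniques are available in scikit-learn. Here is a minimal sketch on a synthetic dataset (the dataset and parameter values are illustrative, not from this blog's example); note that the fitted tree also exposes feature_importances_, the impurity-based importance mentioned in the advantages above:

```python
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

X, y = make_classification(n_samples=1000, n_features=5, random_state=42)
X_tr, X_val, y_tr, y_val = train_test_split(X, y, random_state=42)

# Truncation: cap depth and leaf size so the tree stops growing early.
truncated = DecisionTreeClassifier(max_depth=3, min_samples_leaf=20,
                                   random_state=42).fit(X_tr, y_tr)

# Pruning: grow fully, then apply cost-complexity pruning (ccp_alpha)
# and keep the alpha that scores best on the validation set.
alphas = DecisionTreeClassifier(random_state=42).cost_complexity_pruning_path(
    X_tr, y_tr).ccp_alphas
best = max(
    (DecisionTreeClassifier(random_state=42, ccp_alpha=a).fit(X_tr, y_tr)
     for a in alphas),
    key=lambda t: t.score(X_val, y_val))

print("truncated accuracy:", truncated.score(X_val, y_val))
print("pruned accuracy:   ", best.score(X_val, y_val))
print("feature importances:", best.feature_importances_)
```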
Conclusion:
The Decision Tree is a fairly simple machine learning algorithm, but it can be used in a wide range of applications such as fraud detection, predictions in health care and churn prediction. We can improve model performance by moving to Random Forests or other ensemble methods, which are variants built on decision trees and usually perform better. I will cover these topics in my next blog.