
Decision Tree Regression 

Decision Tree Regression is a powerful machine learning algorithm used for predicting numerical values. It works by splitting data into smaller groups based on conditions, making it easy to understand and interpret.

Imagine you want to predict house prices based on factors like area, number of rooms, and location. Instead of using a single equation like linear regression, Decision Tree Regression splits the data step-by-step, creating a tree-like structure for better predictions.

What is Decision Tree Regression?

A Decision Tree is a model that divides data into smaller regions using conditions. It consists of:

  • Root Node: The starting point that contains all data.
  • Decision Nodes: Conditions that split the data further.
  • Leaf Nodes: The final predictions.
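These parts are exposed on a fitted scikit-learn tree via its `tree_` attribute. A minimal sketch on a toy dataset (the three ages and prices below are invented for illustration):

```python
import numpy as np
from sklearn.tree import DecisionTreeRegressor

# Toy data: car age (years) vs. price ($) -- values invented for illustration
X = np.array([[1], [5], [9]])
y = np.array([30000.0, 20000.0, 12000.0])

tree = DecisionTreeRegressor(random_state=0).fit(X, y).tree_

print(tree.node_count)  # all nodes: root + decision nodes + leaf nodes
print(tree.n_leaves)    # leaf nodes only (the final predictions)
```

With three distinct samples and no depth limit, the tree grows one leaf per sample, so there are 3 leaves and 2 splitting nodes.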

Simple Example to Understand

Let’s say we want to predict a car's price based on its age. (Sample Dataset)

Car Age (Years)    Car Price ($)
      1                30,000
      2                28,000
      3                25,000
      5                20,000
      8                15,000
     10                10,000

A Decision Tree Regression model might split the data like this:

  1. If car age ≤ 3 years, price is $27,666 (average of first 3 cars).
  2. If car age > 3 years and ≤ 8 years, price is $17,500 (average of 5 & 8-year-old cars).
  3. If car age > 8 years, price is $10,000.

                 (Car Age ≤ 3?)
                  /          \
               Yes            No
                /              \
       Price = 27,666     (Car Age ≤ 8?)
                           /          \
                        Yes            No
                         /              \
                Price = 17,500    Price = 10,000

This is how the tree is formed, making predictions easier to interpret compared to complex equations.
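The same tree can be written as plain nested conditions, which also makes the three leaf averages easy to verify:

```python
# The illustrative car-price tree above as nested conditions
def predict_price(car_age):
    if car_age <= 3:                          # root node
        return (30000 + 28000 + 25000) / 3    # leaf: mean of the first 3 cars
    if car_age <= 8:                          # decision node
        return (20000 + 15000) / 2            # leaf: mean of the 5- and 8-year-old cars
    return 10000                              # leaf: the 10-year-old car

print(round(predict_price(2), 2))  # 27666.67
print(predict_price(6))            # 17500.0
```

Every input follows exactly one path from the root to a leaf, and the leaf's stored average is the prediction.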

Advantages of Decision Tree Regression

  1. Easy to Understand & Interpret – Decision trees mimic human decision-making, making them highly interpretable.
  2. Handles Non-Linear Data – Can capture complex relationships without requiring linear assumptions.
  3. Feature Selection Built-In – Splits naturally favor the most informative features, so uninformative ones contribute little to the model.
  4. Works with Missing Data – Some implementations (e.g., CART with surrogate splits, or scikit-learn 1.3+) can handle missing values without requiring imputation.
  5. Requires Less Data Preprocessing – No need for feature scaling or normalization.
  6. Performs Well on Small Datasets – Works efficiently even with a limited amount of data.
  7. Handles Both Categorical & Numerical Data – Versatile for various types of datasets.
  8. Fast Training & Prediction – Decision trees train quickly and make predictions in real time.
  9. Can Be Used for Both Regression & Classification – A flexible algorithm for multiple ML tasks.
  10. Useful for Identifying Important Variables – Helps understand which factors impact predictions the most.
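Points 3 and 10 can be checked directly in scikit-learn, where a fitted tree exposes `feature_importances_`. A minimal sketch on synthetic data (the feature names and the price formula below are invented for illustration):

```python
import numpy as np
from sklearn.tree import DecisionTreeRegressor

rng = np.random.default_rng(0)
n = 200
age = rng.uniform(1, 10, n)      # strongly drives the synthetic price
noise = rng.uniform(0, 1, n)     # irrelevant feature
X = np.column_stack([age, noise])
y = 32000 - 2200 * age + rng.normal(0, 500, n)

model = DecisionTreeRegressor(max_depth=3, random_state=0).fit(X, y)
for name, imp in zip(["age", "noise"], model.feature_importances_):
    print(f"{name}: {imp:.3f}")  # importances sum to 1; 'age' should dominate
```

The irrelevant feature receives near-zero importance because splitting on it rarely reduces the squared error.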

Application Areas

  • Real Estate: Predicting house prices based on size, location, and age.
  • Finance: Predicting stock prices and credit scoring.
  • Healthcare: Diagnosing diseases based on symptoms.
  • Marketing: Understanding customer behavior for product recommendations.
  • Automobile Industry: Estimating car resale values.

Python Implementation of Decision Tree Regression

Now, let's see how to implement Decision Tree Regression in Python using sklearn.

# Import required libraries
import numpy as np
import matplotlib.pyplot as plt
from sklearn.tree import DecisionTreeRegressor

# Step 1: Define the Dataset (Car Age vs Price)
X = np.array([1, 2, 3, 5, 8, 10]).reshape(-1, 1)  # Car Age (Independent Variable)
y = np.array([30000, 28000, 25000, 20000, 15000, 10000])  # Car Price (Dependent Variable)

# Step 2: Train the Decision Tree Model
regressor = DecisionTreeRegressor(random_state=0)  # Create the model
regressor.fit(X, y)  # Train the model

# Step 3: Make Predictions
predicted_price = regressor.predict([[4]])  # Predict for a 4-year-old car
print(f"Predicted Price of 4-year-old car: ${predicted_price[0]:,.2f}")

# Step 4: Visualize Decision Tree Regression
X_grid = np.arange(X.min(), X.max() + 0.01, 0.1).reshape(-1, 1)  # Fine-grained grid so the step-shaped fit plots smoothly
y_pred_grid = regressor.predict(X_grid)  # Predict values for visualization

# Plot actual data points
plt.scatter(X, y, color='red', label='Actual Prices')
plt.plot(X_grid, y_pred_grid, color='blue', label='Decision Tree Prediction')

# Labels and Title
plt.xlabel("Car Age (Years)")
plt.ylabel("Car Price ($)")
plt.title("Decision Tree Regression: Car Age vs Price")
plt.legend()
plt.show()

Output: the script prints the predicted price of a 4-year-old car and displays a step-shaped plot of predicted price against car age, overlaid on the actual data points.
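To inspect the splits the model actually learned (rather than the idealized tree sketched earlier), one option is scikit-learn's `export_text`; a short sketch refitting the same dataset:

```python
import numpy as np
from sklearn.tree import DecisionTreeRegressor, export_text

# Same car age vs. price dataset as above
X = np.array([1, 2, 3, 5, 8, 10]).reshape(-1, 1)
y = np.array([30000, 28000, 25000, 20000, 15000, 10000])

regressor = DecisionTreeRegressor(random_state=0).fit(X, y)

# Prints the learned thresholds and leaf values as an indented text tree
print(export_text(regressor, feature_names=["car_age"]))
```

Because no depth limit is set, the tree grows until each training car sits in its own leaf, so the printed thresholds fall midway between neighboring training ages.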