Preliminaries

Before fitting a GBM, decide on the following hyperparameters; a minimal call with these settings spelled out is sketched after the list:

  • distribution: the loss function; for regression tasks, "gaussian" (squared error) is commonly used

  • n.trees: number of trees (default = 100)

  • shrinkage: shrinkage factor or learning rate (default = 0.1)

  • bag.fraction: fraction of the training data used for learning (default = 0.5)

  • cv.folds: number of folds for cross-validation (default = 0, i.e., no CV error returned)

  • interaction.depth: depth of individual trees (default = 1)
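
As a reference, here is a minimal (not-run) sketch showing how these hyperparameters map onto a gbm() call; the data frame train and the response Y are placeholders.

# A sketch of a gbm() call with the defaults above spelled out;
# 'train' and 'Y' are placeholder names.
library(gbm)
fit = gbm(Y ~ ., data = train,
          distribution = "gaussian",  # loss function
          n.trees = 100,              # number of trees
          shrinkage = 0.1,            # learning rate
          bag.fraction = 0.5,         # subsampling rate
          cv.folds = 0,               # no CV error by default
          interaction.depth = 1)      # depth of each tree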

Case Study: Housing Data

We first split the Boston Housing dataset into a 70% training set and a 30% test set.

library(gbm)
## Loaded gbm 2.1.8.1
url = "https://liangfgithub.github.io/Data/HousingData.csv"
mydata = read.csv(url)
n = nrow(mydata)
ntest = round(n * 0.3)  # reserve 30% of the observations for testing
set.seed(1234)
test.id = sample(1:n, ntest)  # row indices of the test set

Fit a GBM

myfit1 = gbm(Y ~ . , data = mydata[-test.id, ], 
            distribution = "gaussian", 
            n.trees = 100,
            shrinkage = 1, 
            interaction.depth = 3, 
            bag.fraction = 1,
            cv.folds = 5)
myfit1
## gbm(formula = Y ~ ., distribution = "gaussian", data = mydata[-test.id, 
##     ], n.trees = 100, interaction.depth = 3, shrinkage = 1, bag.fraction = 1, 
##     cv.folds = 5)
## A gradient boosted model with gaussian loss function.
## 100 iterations were performed.
## The best cross-validation iteration was 20.
## There were 14 predictors of which 13 had non-zero influence.

Optimal Stopping Point

Plot the CV error to find the optimal number of trees and prevent overfitting. In our case, the optimal stopping point is about 20 trees.

opt.size = gbm.perf(myfit1)
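
Here gbm.perf() plots the error curves and returns the number of trees that minimizes the CV error. Since we set cv.folds = 5, the same quantity can be read off the CV error curve stored in the fitted object; a minimal sketch:

# The per-iteration CV error is stored in the fitted object
# (available because cv.folds > 1), so the optimal size can
# also be computed directly; this should match opt.size.
which.min(myfit1$cv.error)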

Performance Evaluation

We evaluate the model's performance on the test data against the number of trees. It turns out that the boosted model with opt.size trees performs reasonably well, even though it does not achieve the smallest test error.

size = 1:myfit1$n.trees
test.err = rep(0, length(size))
for(i in 1:length(size)){
    y.pred = predict(myfit1, mydata[test.id, ], n.trees = size[i])
    test.err[i] = sum((mydata$Y[test.id] - y.pred)^2)  # test SSE at this size
}
plot(test.err, type = "n")
lines(size, test.err, lwd = 2)
abline(v = opt.size, lwd = 2, lty = 2, col = "blue")
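
As a side note, the loop above can be vectorized: predict() for gbm objects accepts a vector of tree counts and returns a matrix of predictions, one column per size. A minimal sketch of the same computation:

# One predict() call for all sizes: pred.mat has one column per
# value in 'size'; column-wise squared-error sums reproduce test.err.
pred.mat = predict(myfit1, mydata[test.id, ], n.trees = size)
test.err2 = colSums((mydata$Y[test.id] - pred.mat)^2)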

Fit Another GBM

We then fit another GBM, this time with shrinkage (a learning rate of 0.1 instead of 1). Because each tree now contributes less, we need more trees, so we set n.trees = 1000; we also set the subsampling rate (bag.fraction) to 50%. In this scenario, the optimal stopping point is over 200 trees.

myfit2 = gbm(Y ~ . , data = mydata[-test.id, ], 
            distribution = "gaussian", 
            n.trees = 1000,
            shrinkage = 0.1, 
            interaction.depth = 3, 
            bag.fraction = 0.5,
            cv.folds = 5)
opt.size = gbm.perf(myfit2)

size = 1:myfit2$n.trees
test.err = rep(0, length(size))
for(i in 1:length(size)){
    y.pred = predict(myfit2, mydata[test.id, ], n.trees = size[i])
    test.err[i] = sum((mydata$Y[test.id] - y.pred)^2)
}    
plot(test.err, type = "n")
lines(size, test.err, lwd = 2)
abline(v = opt.size, lwd = 2, lty = 2, col = "blue")
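
Predictions from the second model should then be made with the CV-selected number of trees; a minimal sketch reporting the test MSE:

# Score the test set at the CV-chosen size and compute the
# mean squared error.
y.pred = predict(myfit2, mydata[test.id, ], n.trees = opt.size)
mean((mydata$Y[test.id] - y.pred)^2)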

Variable Importance

GBM in R provides variable importance measures similar to those of Random Forest. You can access them via the summary() function, with either method = relative.influence (the default) or method = permutation.test.gbm.

par(mfrow=c(1, 2))
summary(myfit1, cBars = 10,
  method = relative.influence, 
  las = 2)
##             var     rel.inf
## lstat     lstat 55.48719511
## rm           rm 23.49623645
## crim       crim  9.92391857
## lon         lon  2.61359901
## dis         dis  1.97426669
## nox         nox  1.67166826
## lat         lat  1.57319028
## age         age  1.08397273
## ptratio ptratio  0.97920710
## tax         tax  0.57316429
## chas       chas  0.28845125
## zn           zn  0.14235518
## rad         rad  0.12114019
## indus     indus  0.07163489

summary(myfit1, cBars = 10,
  method = permutation.test.gbm, 
  las = 2)

##        var    rel.inf
## 1    lstat 46.9921210
## 2       rm 18.1468970
## 3     crim 12.9336936
## 4      lon  5.4035500
## 5      dis  5.0083208
## 6      age  3.9477929
## 7      nox  2.4677146
## 8      lat  2.3663326
## 9  ptratio  0.8346349
## 10     tax  0.7929412
## 11   indus  0.4274400
## 12    chas  0.2313825
## 13     rad  0.2285759
## 14      zn  0.2186029
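
If only the numeric scores are needed, the bar plot can be suppressed; a minimal sketch using summary()'s plotit argument and the underlying relative.influence() function:

# Importance scores as a data frame, without plotting.
imp = summary(myfit1, plotit = FALSE)
# Unnormalized relative influence, computed directly.
relative.influence(myfit1, n.trees = myfit1$n.trees)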