Note: This HTML file is based on the mlr3 book, the official documentation of the package. For a more detailed introduction to the package, check the link.

Introduction to mlr3

What is it?

The mlr3 (Lang et al. 2019) package and ecosystem provide a generic, object-oriented, and extensible framework for classification, regression, survival analysis, and other machine learning tasks for the R language.

What is it used for?

Its interface provides functionality to extend and combine existing learners, intelligently select and tune the most appropriate technique for a task, and perform large-scale comparisons that enable meta-learning.

What makes it special?

Advanced functionalities include hyperparameter tuning and feature selection. Parallelization of many operations is natively supported. But more on that later.
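As a teaser, a minimal sketch of enabling parallelization (assuming the future package is installed; mlr3 dispatches to it automatically):

library(future)
plan("multisession") # subsequent resample()/benchmark() calls run their iterations in parallel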

library(mlr3verse)  # loads most packages in the mlr3 world

Contents:

  1. Basics
  2. Performance Evaluation and Comparison
  3. Model Optimization
  4. Technical and Special Tasks

1. Basics

Below you see a typical ML workflow. mlr3 can take care of every step.

Fig. 1: Typical ML Workflow

Overview

In this section we will focus on

  • R6

  • Tasks/Data

  • Learners

  • Training and Predicting

R6

  • one of R’s more recent dialects for object-oriented programming (OO)

  • Objects are created by calling the constructor of an R6::R6Class() object, specifically the initialization method $new():
    e.g. foo = Foo$new(bar = 1) creates a new object of class Foo, setting the bar argument of the constructor to the value 1

  • Most objects in mlr3 are created through special shortcut functions, e.g. lrn("regr.rpart")

  • Objects have mutable state that is encapsulated in their fields, which can be accessed through the dollar operator.

  • Objects expose methods that allow us to inspect the object’s state, retrieve information, or perform an action that changes the internal state of the object (see the sketch below)
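To make these conventions concrete, here is a minimal sketch of a hypothetical R6 class (Foo, its field bar, and its method double() are invented for illustration):

library(R6)

Foo = R6Class("Foo",
  public = list(
    bar = NULL,                       # field: encapsulated, mutable state
    initialize = function(bar = 1) {  # constructor, invoked via Foo$new()
      self$bar = bar
    },
    double = function() {             # method: inspects the object's state
      self$bar * 2
    }
  )
)

foo = Foo$new(bar = 1)  # create an object; sets the bar field to 1
foo$bar                 # access a field with the dollar operator: 1
foo$double()            # call a method: 2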

Tasks

Tasks encapsulate the data with meta-information, such as the name of the prediction target column.

Task Types

To create a task from a data.frame(), data.table() or Matrix(), you first need to select the right task type:

  • Classification Task

  • Regression Task

  • Survival Task

  • Density, Cluster, Spatial, Ordinal Regression Tasks

Task Creation and Predefined Tasks

mlr_tasks
## <DictionaryTask> with 28 stored values
## Keys: actg, bike_sharing, boston_housing, breast_cancer, faithful,
##   gbcs, german_credit, grace, ilpd, iris, kc_housing, lung, moneyball,
##   mtcars, optdigits, penguins, penguins_simple, pima, precip, rats,
##   sonar, spam, titanic, unemployment, usarrests, whas, wine, zoo
as.data.table(mlr_tasks)
##                 key                                     label task_type  nrow
##  1:            actg                                      <NA>      surv  1151
##  2:    bike_sharing                       Bike Sharing Demand      regr 17379
##  3:  boston_housing                     Boston Housing Prices      regr   506
##  4:   breast_cancer                   Wisconsin Breast Cancer   classif   683
##  5:        faithful                                      <NA>      dens   272
##  6:            gbcs                                      <NA>      surv   686
##  7:   german_credit                             German Credit   classif  1000
##  8:           grace                                      <NA>      surv  1000
##  9:            ilpd                 Indian Liver Patient Data   classif   583
## 10:            iris                              Iris Flowers   classif   150
## 11:      kc_housing                   King County House Sales      regr 21613
## 12:            lung                                      <NA>      surv   228
## 13:       moneyball          Major League Baseball Statistics      regr  1232
## 14:          mtcars                              Motor Trends      regr    32
## 15:       optdigits Optical Recognition of Handwritten Digits   classif  5620
## 16:        penguins                           Palmer Penguins   classif   344
## 17: penguins_simple                Simplified Palmer Penguins   classif   333
## 18:            pima                      Pima Indian Diabetes   classif   768
## 19:          precip                                      <NA>      dens    70
## 20:            rats                                      <NA>      surv   300
## 21:           sonar                    Sonar: Mines vs. Rocks   classif   208
## 22:            spam                         HP Spam Detection   classif  4601
## 23:         titanic                                   Titanic   classif  1309
## 24:    unemployment                                      <NA>      surv  3343
## 25:       usarrests                                US Arrests     clust    50
## 26:            whas                                      <NA>      surv   481
## 27:            wine                              Wine Regions   classif   178
## 28:             zoo                               Zoo Animals   classif   101
##                 key                                     label task_type  nrow
##     ncol properties lgl int dbl chr fct ord pxc
##  1:   13              0   3   4   0   4   0   0
##  2:   14              2   4   4   1   2   0   0
##  3:   19              0   3  13   0   2   0   0
##  4:   10   twoclass   0   0   0   0   0   9   0
##  5:    1              0   0   1   0   0   0   0
##  6:   10              0   4   4   0   0   0   0
##  7:   21   twoclass   0   3   0   0  14   3   0
##  8:    8              0   2   4   0   0   0   0
##  9:   11   twoclass   0   4   5   0   1   0   0
## 10:    5 multiclass   0   0   4   0   0   0   0
## 11:   20              1  13   4   0   0   0   1
## 12:   10              0   7   0   0   1   0   0
## 13:   15              0   3   5   0   6   0   0
## 14:   11              0   0  10   0   0   0   0
## 15:   65   twoclass   0  64   0   0   0   0   0
## 16:    8 multiclass   0   3   2   0   2   0   0
## 17:   11 multiclass   0   3   7   0   0   0   0
## 18:    9   twoclass   0   0   8   0   0   0   0
## 19:    1              0   0   1   0   0   0   0
## 20:    5              0   2   0   0   1   0   0
## 21:   61   twoclass   0   0  60   0   0   0   0
## 22:   58   twoclass   0   0  57   0   0   0   0
## 23:   11   twoclass   0   2   2   3   2   1   0
## 24:    6              0   1   2   0   1   0   0
## 25:    4              0   2   2   0   0   0   0
## 26:   11              0   4   3   0   2   0   0
## 27:   14 multiclass   0   2  11   0   0   0   0
## 28:   17 multiclass  15   1   0   0   0   0   0
##     ncol properties lgl int dbl chr fct ord pxc

mlr3 provides the shortcut function tsk(). Here, we retrieve the Palmer penguins task, which is provided by the package palmerpenguins:

task_penguins = tsk("penguins")
print(task_penguins)
## <TaskClassif:penguins> (344 x 8): Palmer Penguins
## * Target: species
## * Properties: multiclass
## * Features (7):
##   - int (3): body_mass, flipper_length, year
##   - dbl (2): bill_depth, bill_length
##   - fct (2): island, sex

Creating a task from a dataset:

data("mtcars", package = "datasets")
data = mtcars[, 1:3]
str(data)
## 'data.frame':    32 obs. of  3 variables:
##  $ mpg : num  21 21 22.8 21.4 18.7 18.1 14.3 24.4 22.8 19.2 ...
##  $ cyl : num  6 6 4 6 8 6 8 4 4 6 ...
##  $ disp: num  160 160 108 258 360 ...
task_mtcars = as_task_regr(data, target = "mpg", id = "cars")
print(task_mtcars)
## <TaskRegr:cars> (32 x 3)
## * Target: mpg
## * Properties: -
## * Features (2):
##   - dbl (2): cyl, disp

Retrieving Tasks

The data stored in a task can be retrieved directly from fields:

task_mtcars
## <TaskRegr:cars> (32 x 3)
## * Target: mpg
## * Properties: -
## * Features (2):
##   - dbl (2): cyl, disp

More information can be obtained through methods of the object, for example:

task_mtcars$data()
##      mpg cyl  disp
##  1: 21.0   6 160.0
##  2: 21.0   6 160.0
##  3: 22.8   4 108.0
##  4: 21.4   6 258.0
##  5: 18.7   8 360.0
##  6: 18.1   6 225.0
##  7: 14.3   8 360.0
##  8: 24.4   4 146.7
##  9: 22.8   4 140.8
## 10: 19.2   6 167.6
## 11: 17.8   6 167.6
## 12: 16.4   8 275.8
## 13: 17.3   8 275.8
## 14: 15.2   8 275.8
## 15: 10.4   8 472.0
## 16: 10.4   8 460.0
## 17: 14.7   8 440.0
## 18: 32.4   4  78.7
## 19: 30.4   4  75.7
## 20: 33.9   4  71.1
## 21: 21.5   4 120.1
## 22: 15.5   8 318.0
## 23: 15.2   8 304.0
## 24: 13.3   8 350.0
## 25: 19.2   8 400.0
## 26: 27.3   4  79.0
## 27: 26.0   4 120.3
## 28: 30.4   4  95.1
## 29: 15.8   8 351.0
## 30: 19.7   6 145.0
## 31: 15.0   8 301.0
## 32: 21.4   4 121.0
##      mpg cyl  disp

In mlr3, each row (observation) has a unique identifier, stored as an integer(). These can be passed as arguments to the $data() method to select specific rows:

head(task_mtcars$row_ids)
## [1] 1 2 3 4 5 6
# retrieve data for rows with IDs 1, 5, and 10
task_mtcars$data(rows = c(1, 5, 10))
##     mpg cyl  disp
## 1: 21.0   6 160.0
## 2: 18.7   8 360.0
## 3: 19.2   6 167.6

Similarly to row IDs, target and feature columns have unique identifiers, i.e. names (stored as character()). These can be accessed via the public slots $feature_names and $target_names.

task_mtcars$feature_names
## [1] "cyl"  "disp"
task_mtcars$target_names
## [1] "mpg"

To extract the complete data from the task, one can also simply convert it to a data.table:

# show summary of entire data
summary(as.data.table(task_mtcars))
##       mpg             cyl             disp      
##  Min.   :10.40   Min.   :4.000   Min.   : 71.1  
##  1st Qu.:15.43   1st Qu.:4.000   1st Qu.:120.8  
##  Median :19.20   Median :6.000   Median :196.3  
##  Mean   :20.09   Mean   :6.188   Mean   :230.7  
##  3rd Qu.:22.80   3rd Qu.:8.000   3rd Qu.:326.0  
##  Max.   :33.90   Max.   :8.000   Max.   :472.0

Task Mutators

The mutators filter() and select() subset the current view of the task: filter() subsets based on row IDs and select() based on feature names. The underlying data is not modified.

task_penguins = tsk("penguins")
task_penguins$select(c("body_mass", "flipper_length")) # keep only these features
task_penguins$filter(1:3) # keep only these rows
task_penguins$head()
##    species body_mass flipper_length
## 1:  Adelie      3750            181
## 2:  Adelie      3800            186
## 3:  Adelie      3250            195

The methods rbind() and cbind() allow us to add extra rows and columns to a task. Again, the original data is not changed.

task_penguins$cbind(data.frame(letters = letters[1:3])) # add column letters
task_penguins$head()
##    species body_mass flipper_length letters
## 1:  Adelie      3750            181       a
## 2:  Adelie      3800            186       b
## 3:  Adelie      3250            195       c
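Analogously, a hedged sketch of $rbind(); the new rows must supply every current column of the task (including the target and the letters column we just added) with matching types:

new_row = data.frame(
  species = factor("Gentoo", levels = c("Adelie", "Chinstrap", "Gentoo")),
  body_mass = 5000L,       # integer columns, as in the task
  flipper_length = 215L,
  letters = "d"
)
task_penguins$rbind(new_row) # append the row to the task's view
task_penguins$head(4)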

Plotting Tasks

The mlr3viz package provides plotting facilities for many classes implemented in mlr3. The available plot types depend on the class, but all plots are returned as ggplot2 objects which can be easily customized.

Some examples:

# get the pima indians task
task = tsk("pima")

# subset task to only use the 3 first features
task$select(head(task$feature_names, 3))

# default plot: class frequencies
autoplot(task)

# pairs plot (requires package GGally)
autoplot(task, type = "pairs")

# duo plot (requires package GGally)
autoplot(task, type = "duo")
## Warning: Removed 5 rows containing non-finite values (stat_boxplot).
## Warning: Removed 374 rows containing non-finite values (stat_boxplot).

Of course, you can do the same for regression tasks.

# get the complete mtcars task
task = tsk("mtcars")

# subset task to only use the 3 first features
task$select(head(task$feature_names, 3))

# default plot: boxplot of target variable
autoplot(task)

# pairs plot (requires package GGally)
autoplot(task, type = "pairs")

Learners

Learners encapsulate machine learning algorithms to train models and make predictions for a task. The algorithms themselves are provided by other packages.

Objects of class Learner provide a unified interface to many popular machine learning algorithms in R. They consist of methods to train and predict a model for a Task and provide meta-information about the learners, such as the hyperparameters (which control the behavior of the learner) you can set.

Fig.: Learner Two-Stage Process (train and predict)

Predefined Learners

Basic:

  • mlr_learners_classif.featureless
  • mlr_learners_regr.featureless
  • mlr_learners_classif.rpart
  • mlr_learners_regr.rpart

In the mlr3learners package:

  • Linear and logistic regression
  • Penalized Generalized Linear Models
  • k-Nearest Neighbors regression and classification
  • Kriging
  • Linear and Quadratic Discriminant Analysis
  • Naive Bayes
  • Support Vector Machines
  • Gradient Boosting
  • Random Forests for regression, classification and survival
mlr_learners
## <DictionaryLearner> with 136 stored values
## Keys: classif.AdaBoostM1, classif.bart, classif.C50, classif.catboost,
##   classif.cforest, classif.ctree, classif.cv_glmnet, classif.debug,
##   classif.earth, classif.extratrees, classif.featureless, classif.fnn,
##   classif.gam, classif.gamboost, classif.gausspr, classif.gbm,
##   classif.glmboost, classif.glmnet, classif.IBk, classif.J48,
##   classif.JRip, classif.kknn, classif.ksvm, classif.lda,
##   classif.liblinear, classif.lightgbm, classif.LMT, classif.log_reg,
##   classif.lssvm, classif.mob, classif.multinom, classif.naive_bayes,
##   classif.nnet, classif.OneR, classif.PART, classif.qda,
##   classif.randomForest, classif.ranger, classif.rfsrc, classif.rpart,
##   classif.svm, classif.xgboost, clust.agnes, clust.ap, clust.cmeans,
##   clust.cobweb, clust.dbscan, clust.diana, clust.em, clust.fanny,
##   clust.featureless, clust.ff, clust.hclust, clust.kkmeans,
##   clust.kmeans, clust.MBatchKMeans, clust.meanshift, clust.pam,
##   clust.SimpleKMeans, clust.xmeans, dens.hist, dens.kde, dens.kde_kd,
##   dens.kde_ks, dens.locfit, dens.logspline, dens.mixed, dens.nonpar,
##   dens.pen, dens.plug, dens.spline, regr.bart, regr.catboost,
##   regr.cforest, regr.ctree, regr.cubist, regr.cv_glmnet, regr.debug,
##   regr.earth, regr.extratrees, regr.featureless, regr.fnn, regr.gam,
##   regr.gamboost, regr.gausspr, regr.gbm, regr.glm, regr.glmboost,
##   regr.glmnet, regr.IBk, regr.kknn, regr.km, regr.ksvm, regr.liblinear,
##   regr.lightgbm, regr.lm, regr.M5Rules, regr.mars, regr.mob,
##   regr.randomForest, regr.ranger, regr.rfsrc, regr.rpart, regr.rvm,
##   regr.svm, regr.xgboost, surv.akritas, surv.blackboost, surv.cforest,
##   surv.coxboost, surv.coxph, surv.coxtime, surv.ctree,
##   surv.cv_coxboost, surv.cv_glmnet, surv.deephit, surv.deepsurv,
##   surv.dnnsurv, surv.flexible, surv.gamboost, surv.gbm, surv.glmboost,
##   surv.glmnet, surv.kaplan, surv.loghaz, surv.mboost, surv.nelson,
##   surv.obliqueRSF, surv.parametric, surv.pchazard, surv.penalized,
##   surv.ranger, surv.rfsrc, surv.rpart, surv.svm, surv.xgboost

Access Learner Information

Each learner provides the following meta-information:

  • feature_types
  • packages
  • properties
  • predict_types

You can retrieve a specific learner using its ID:

learner = lrn("classif.rpart")
print(learner)
## <LearnerClassifRpart:classif.rpart>: Classification Tree
## * Model: -
## * Parameters: xval=0
## * Packages: mlr3, rpart
## * Predict Type: response
## * Feature types: logical, integer, numeric, factor, ordered
## * Properties: importance, missings, multiclass, selected_features,
##   twoclass, weights
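These meta-information fields can be read directly with the dollar operator; the values in the comments below are taken from the printout above:

learner$feature_types  # "logical" "integer" "numeric" "factor" "ordered"
learner$packages       # "mlr3" "rpart"
learner$predict_types  # "response" "prob"
learner$properties     # importance, missings, multiclass, ...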

Each learner has hyperparameters that control its behavior. The field param_set stores a description of the hyperparameters the learner has, their ranges, defaults, and current values:

learner$param_set
## <ParamSet>
##                 id    class lower upper nlevels        default value
##  1:             cp ParamDbl     0     1     Inf           0.01      
##  2:     keep_model ParamLgl    NA    NA       2          FALSE      
##  3:     maxcompete ParamInt     0   Inf     Inf              4      
##  4:       maxdepth ParamInt     1    30      30             30      
##  5:   maxsurrogate ParamInt     0   Inf     Inf              5      
##  6:      minbucket ParamInt     1   Inf     Inf <NoDefault[3]>      
##  7:       minsplit ParamInt     1   Inf     Inf             20      
##  8: surrogatestyle ParamInt     0     1       2              0      
##  9:   usesurrogate ParamInt     0     2       3              2      
## 10:           xval ParamInt     0   Inf     Inf             10     0

You can easily change the current hyperparameter values like this:

pv = learner$param_set$values
pv$cp = 0.02
learner$param_set$values = pv

or like this:

learner = lrn("classif.rpart", id = "rp", cp = 0.001)
learner$id
## [1] "rp"
learner$param_set$values
## $xval
## [1] 0
## 
## $cp
## [1] 0.001

Thresholding

For binary classifiers with predict_type = "prob", predicted probabilities are converted into class labels by comparing them against a threshold, which defaults to 0.5. In the example below, we lower the threshold to 0.2, which improves the True Positive Rate (TPR) at the cost of the True Negative Rate (TNR).

data("Sonar", package = "mlbench")
task = as_task_classif(Sonar, target = "Class", positive = "M")
learner = lrn("classif.rpart", predict_type = "prob")
pred = learner$train(task)$predict(task)

measures = msrs(c("classif.tpr", "classif.tnr")) # use msrs() to get a list of multiple measures
pred$confusion
##         truth
## response  M  R
##        M 95 10
##        R 16 87
pred$score(measures)
## classif.tpr classif.tnr 
##   0.8558559   0.8969072
pred$set_threshold(0.2)
pred$confusion
##         truth
## response   M   R
##        M 104  25
##        R   7  72
pred$score(measures)
## classif.tpr classif.tnr 
##   0.9369369   0.7422680

Training and Predicting

This section illustrates how to use tasks and learners to train a model and make predictions on a new data set. The concept is demonstrated on a supervised classification task using the penguins dataset and the rpart learner, which builds a single classification tree.

Creating Task and Learner Objects

task = tsk("penguins")
learner = lrn("classif.rpart")

Training and Test Splits

# index vectors
train_set = sample(task$nrow, 0.8 * task$nrow)
test_set = setdiff(seq_len(task$nrow), train_set)
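Since sample() draws randomly, you would typically fix a seed first to make the split reproducible; a minimal sketch:

set.seed(42) # fix the RNG state so the same split is drawn every time
train_set = sample(task$nrow, 0.8 * task$nrow)
test_set = setdiff(seq_len(task$nrow), train_set)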

As we will see later, mlr3 can also automatically create training and test sets based on different resampling strategies.

Training the learner

The field model stores the model that is produced in the training step.

# fit the classification tree using the training set of the task by calling the $train() method of learner:
learner$train(task, row_ids = train_set)
print(learner$model)
## n= 275 
## 
## node), split, n, loss, yval, (yprob)
##       * denotes terminal node
## 
## 1) root 275 157 Adelie (0.429090909 0.192727273 0.378181818)  
##   2) flipper_length< 206.5 166  49 Adelie (0.704819277 0.289156627 0.006024096)  
##     4) bill_length< 43.35 117   4 Adelie (0.965811966 0.034188034 0.000000000) *
##     5) bill_length>=43.35 49   5 Chinstrap (0.081632653 0.897959184 0.020408163) *
##   3) flipper_length>=206.5 109   6 Gentoo (0.009174312 0.045871560 0.944954128)  
##     6) bill_depth>=17.2 7   2 Chinstrap (0.142857143 0.714285714 0.142857143) *
##     7) bill_depth< 17.2 102   0 Gentoo (0.000000000 0.000000000 1.000000000) *

Predicting

After the model has been fitted to the training data, we use the test set for prediction:

prediction = learner$predict(task, row_ids = test_set)
print(prediction)
## <PredictionClassif> for 69 observations:
##     row_ids     truth  response
##           5    Adelie    Adelie
##           7    Adelie    Adelie
##           9    Adelie    Adelie
## ---                            
##         326 Chinstrap Chinstrap
##         332 Chinstrap Chinstrap
##         335 Chinstrap Chinstrap

A prediction object holds the row IDs of the test data, the respective true labels of the target column, and the respective predictions. The simplest way to extract this information is by converting the Prediction object to a data.table():

head(as.data.table(prediction)) # show first six predictions
##    row_ids  truth response
## 1:       5 Adelie   Adelie
## 2:       7 Adelie   Adelie
## 3:       9 Adelie   Adelie
## 4:      11 Adelie   Adelie
## 5:      15 Adelie   Adelie
## 6:      18 Adelie   Adelie
prediction$confusion
##            truth
## response    Adelie Chinstrap Gentoo
##   Adelie        33         1      0
##   Chinstrap      1        14      1
##   Gentoo         0         0     19

Changing the Predict Type

learner$predict_type = "prob"

# re-fit the model
learner$train(task, row_ids = train_set)

# rebuild prediction object
prediction = learner$predict(task, row_ids = test_set)

The prediction object now contains probabilities for all class labels in addition to the predicted label (the one with the highest probability):

# data.table conversion
head(as.data.table(prediction)) # show first six
##    row_ids  truth response prob.Adelie prob.Chinstrap prob.Gentoo
## 1:       5 Adelie   Adelie    0.965812     0.03418803           0
## 2:       7 Adelie   Adelie    0.965812     0.03418803           0
## 3:       9 Adelie   Adelie    0.965812     0.03418803           0
## 4:      11 Adelie   Adelie    0.965812     0.03418803           0
## 5:      15 Adelie   Adelie    0.965812     0.03418803           0
## 6:      18 Adelie   Adelie    0.965812     0.03418803           0
# directly access the predicted labels:
head(prediction$response)
## [1] Adelie Adelie Adelie Adelie Adelie Adelie
## Levels: Adelie Chinstrap Gentoo
# directly access the matrix of probabilities:
head(prediction$prob)
##        Adelie  Chinstrap Gentoo
## [1,] 0.965812 0.03418803      0
## [2,] 0.965812 0.03418803      0
## [3,] 0.965812 0.03418803      0
## [4,] 0.965812 0.03418803      0
## [5,] 0.965812 0.03418803      0
## [6,] 0.965812 0.03418803      0

Similarly to predicting probabilities for classification, many regression learners support the extraction of standard error estimates for predictions by setting the predict type to "se".
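A hedged sketch of this, assuming the regr.lm learner from the mlr3learners package (loaded as part of mlr3verse):

learner_lm = lrn("regr.lm", predict_type = "se")
learner_lm$train(tsk("mtcars"))
pred_se = learner_lm$predict(tsk("mtcars"))
head(as.data.table(pred_se)) # columns: row_ids, truth, response, se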

Plotting Predictions

Similarly to plotting tasks, mlr3viz provides an autoplot() method for Prediction objects.

task = tsk("penguins")
learner = lrn("classif.rpart", predict_type = "prob")
learner$train(task)
prediction = learner$predict(task)
autoplot(prediction)

Performance Assessment

Available measures are stored in mlr_measures:

mlr_measures
## <DictionaryMeasure> with 87 stored values
## Keys: aic, bic, classif.acc, classif.auc, classif.bacc, classif.bbrier,
##   classif.ce, classif.costs, classif.dor, classif.fbeta, classif.fdr,
##   classif.fn, classif.fnr, classif.fomr, classif.fp, classif.fpr,
##   classif.logloss, classif.mbrier, classif.mcc, classif.npv,
##   classif.ppv, classif.prauc, classif.precision, classif.recall,
##   classif.sensitivity, classif.specificity, classif.tn, classif.tnr,
##   classif.tp, classif.tpr, clust.ch, clust.db, clust.dunn,
##   clust.silhouette, clust.wss, debug, dens.logloss, oob_error,
##   regr.bias, regr.ktau, regr.mae, regr.mape, regr.maxae, regr.medae,
##   regr.medse, regr.mse, regr.msle, regr.pbias, regr.rae, regr.rmse,
##   regr.rmsle, regr.rrse, regr.rse, regr.rsq, regr.sae, regr.smape,
##   regr.srho, regr.sse, selected_features, sim.jaccard, sim.phi,
##   surv.brier, surv.calib_alpha, surv.calib_beta, surv.chambless_auc,
##   surv.cindex, surv.dcalib, surv.graf, surv.hung_auc, surv.intlogloss,
##   surv.logloss, surv.mae, surv.mse, surv.nagelk_r2, surv.oquigley_r2,
##   surv.rmse, surv.schmid, surv.song_auc, surv.song_tnr, surv.song_tpr,
##   surv.uno_auc, surv.uno_tnr, surv.uno_tpr, surv.xu_r2, time_both,
##   time_predict, time_train

We choose accuracy (classif.acc) as our specific performance measure here and call the method $score() of the prediction object to quantify the predictive performance of our model.

measure = msr("classif.acc")
print(measure)
## <MeasureClassifSimple:classif.acc>: Classification Accuracy
## * Packages: mlr3, mlr3measures
## * Range: [0, 1]
## * Minimize: FALSE
## * Average: macro
## * Parameters: list()
## * Properties: -
## * Predict type: response
prediction$score(measure)
## classif.acc 
##   0.9651163

If no measure is specified, classification defaults to the classification error (classif.ce, i.e. 1 - accuracy) and regression to the mean squared error (regr.mse).
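In other words, calling $score() without arguments falls back to the default measure for the task type:

# equivalent to prediction$score(msr("classif.ce")) for a classification task
prediction$score()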

2. Performance Evaluation and Comparison

Now, let us look at how mlr3 enables us to perform common machine learning evaluation steps, such as ROC analysis and resampling.

ROC Curve and Thresholds

Binary classification is a special case of classification where the target variable has only two possible values and a probability threshold distinguishes between them. ROC (Receiver Operating Characteristic) analysis applies particularly to this case and allows us to better understand the trade-offs when choosing between the two classes.

data("Sonar", package = "mlbench")
task = as_task_classif(Sonar, target = "Class", positive = "M")

learner = lrn("classif.rpart", predict_type = "prob")
pred = learner$train(task)$predict(task)
C = pred$confusion
print(C)
##         truth
## response  M  R
##        M 95 10
##        R 16 87

Fig. 2: Confusion Matrix

We can then derive the following performance metrics from a confusion matrix (computed explicitly in the sketch after this list):

  • True Positive Rate (TPR): How many of the true positives did we predict as positive?

  • True Negative Rate (TNR): How many of the true negatives did we predict as negative?

  • Positive Predictive Value (PPV): If we predict positive, how likely is it a true positive?

  • Negative Predictive Value (NPV): If we predict negative, how likely is it a true negative?
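These metrics can be computed directly from the confusion matrix C above, whose rows are the response and whose columns are the truth; a short sketch using the counts printed earlier:

TP = C["M", "M"]  # predicted M, truth M: 95
FP = C["M", "R"]  # predicted M, truth R: 10
FN = C["R", "M"]  # predicted R, truth M: 16
TN = C["R", "R"]  # predicted R, truth R: 87

TP / (TP + FN)  # TPR: 95 / 111, approx. 0.856
TN / (TN + FP)  # TNR: 87 / 97,  approx. 0.897
TP / (TP + FP)  # PPV: 95 / 105, approx. 0.905
TN / (TN + FN)  # NPV: 87 / 103, approx. 0.845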

The ROC curve plots the TPR against the FPR (False Positive Rate, i.e. 1 - TNR) for different thresholds in order to characterize the behavior of a binary classifier.

For mlr3 prediction objects, the ROC curve can easily be created with mlr3viz, which relies on the precrec package to calculate and plot ROC and precision-recall curves:

library("mlr3viz")
# TPR vs FPR / Sensitivity vs (1 - Specificity)
autoplot(pred, type = "roc")

We can also plot the precision-recall curve (PPV vs. TPR), which is often preferred over the ROC curve for imbalanced class distributions.

# Precision vs Recall
autoplot(pred, type = "prc")

Resampling

When evaluating the performance of a model, we are interested in its generalization performance, which we can estimate by evaluating the model on a test set, as we have done above.

There are different strategies, called resampling strategies, for partitioning a data set into training and test sets. mlr3 comes with the following predefined resampling strategies:

# mlr resampling strategies
as.data.table(mlr_resamplings)
##            key                         label        params iters
## 1:   bootstrap                     Bootstrap ratio,repeats    30
## 2:      custom                 Custom Splits                  NA
## 3:   custom_cv Custom Split Cross-Validation                  NA
## 4:          cv              Cross-Validation         folds    10
## 5:     holdout                       Holdout         ratio     1
## 6:    insample           Insample Resampling                   1
## 7:         loo                 Leave-One-Out                  NA
## 8: repeated_cv     Repeated Cross-Validation folds,repeats   100
## 9: subsampling                   Subsampling ratio,repeats    30
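Each strategy is retrieved with rsmp() and configured through the parameters listed in the params column, for example:

rsmp("cv", folds = 5)                        # 5-fold cross-validation
rsmp("repeated_cv", folds = 5, repeats = 2)  # 2 repetitions of 5-fold CV
rsmp("bootstrap", repeats = 30, ratio = 1)   # 30 bootstrap samples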

Settings

library("mlr3verse")

# Set the task to use
task = tsk("penguins")

# Set the learner: 
# a simple classification tree 
# from the rpart package
learner = lrn("classif.rpart")

# Set resampling strategy
resampling = rsmp("holdout")
print(resampling)
## <ResamplingHoldout>: Holdout
## * Iterations: 1
## * Instantiated: FALSE
## * Parameters: ratio=0.6667
# by default, a ratio of 2/3 of the
# observations goes into the training set

# We can change this hyperparameter as follows:
rsmp("holdout", ratio = 0.8)
## <ResamplingHoldout>: Holdout
## * Iterations: 1
## * Instantiated: FALSE
## * Parameters: ratio=0.8

Now, to actually perform the splitting and obtain indices for the training and test split, the resampling strategy needs a Task.

# Instantiation
resampling$instantiate(task)

# training set
str(resampling$train_set(1))
##  int [1:229] 3 5 8 9 10 11 12 15 16 17 ...
# test set
str(resampling$test_set(1))
##  int [1:115] 1 2 4 6 7 13 14 19 22 23 ...

Note: If we want to compare multiple Learners, we need to use the same instantiated resampling for each one, so that every learner gets exactly the same training data and is evaluated on the same test sets. For more on comparing multiple learners, refer to the section on benchmarking in the mlr3book.
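A brief sketch of this pattern, comparing the classification tree against the featureless baseline on identical splits:

cv3 = rsmp("cv", folds = 3)
cv3$instantiate(task)  # fix the splits once

rr_tree = resample(task, lrn("classif.rpart"), cv3)
rr_baseline = resample(task, lrn("classif.featureless"), cv3)

rr_tree$aggregate(msr("classif.ce"))      # both learners were trained and
rr_baseline$aggregate(msr("classif.ce"))  # evaluated on exactly the same splits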

Execution

Once we have defined a Task, a Learner, and a Resampling object, we can call resample(), which fits the learner on the training set and evaluates it on the test set of each resampling iteration.

# binary classif task
task = tsk("pima")
# select 2 features
task$select(c("glucose", "mass"))
learner = lrn("classif.rpart", predict_type = "prob")

# 10-fold-cv
resampling = rsmp("cv")

# Execute the resampling
# store_models = TRUE, to keep the fitted models
rr = resample(task, learner, resampling, store_models = TRUE)
## INFO  [22:08:45.068] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 4/10) 
## INFO  [22:08:45.111] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 6/10) 
## INFO  [22:08:45.145] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 9/10) 
## INFO  [22:08:45.180] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 2/10) 
## INFO  [22:08:45.199] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 3/10) 
## INFO  [22:08:45.217] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 10/10) 
## INFO  [22:08:45.236] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 7/10) 
## INFO  [22:08:45.256] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 5/10) 
## INFO  [22:08:45.275] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 1/10) 
## INFO  [22:08:45.295] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 8/10)
print(rr)
## <ResampleResult> of 10 iterations
## * Task: pima
## * Learner: classif.rpart
## * Warnings: 0 in 0 iterations
## * Errors: 0 in 0 iterations
# we can calculate the avg performance
# across all resampling iterations,
# in terms of classification error
rr$aggregate(msr("classif.ce"))
## classif.ce 
##  0.2644224
# or for the individual resampling iterations,
# to check how iterations differ from the avg
rr$score(msr("classif.ce"))
##                  task task_id                   learner    learner_id
##  1: <TaskClassif[50]>    pima <LearnerClassifRpart[38]> classif.rpart
##  2: <TaskClassif[50]>    pima <LearnerClassifRpart[38]> classif.rpart
##  3: <TaskClassif[50]>    pima <LearnerClassifRpart[38]> classif.rpart
##  4: <TaskClassif[50]>    pima <LearnerClassifRpart[38]> classif.rpart
##  5: <TaskClassif[50]>    pima <LearnerClassifRpart[38]> classif.rpart
##  6: <TaskClassif[50]>    pima <LearnerClassifRpart[38]> classif.rpart
##  7: <TaskClassif[50]>    pima <LearnerClassifRpart[38]> classif.rpart
##  8: <TaskClassif[50]>    pima <LearnerClassifRpart[38]> classif.rpart
##  9: <TaskClassif[50]>    pima <LearnerClassifRpart[38]> classif.rpart
## 10: <TaskClassif[50]>    pima <LearnerClassifRpart[38]> classif.rpart
##             resampling resampling_id iteration              prediction
##  1: <ResamplingCV[20]>            cv         1 <PredictionClassif[20]>
##  2: <ResamplingCV[20]>            cv         2 <PredictionClassif[20]>
##  3: <ResamplingCV[20]>            cv         3 <PredictionClassif[20]>
##  4: <ResamplingCV[20]>            cv         4 <PredictionClassif[20]>
##  5: <ResamplingCV[20]>            cv         5 <PredictionClassif[20]>
##  6: <ResamplingCV[20]>            cv         6 <PredictionClassif[20]>
##  7: <ResamplingCV[20]>            cv         7 <PredictionClassif[20]>
##  8: <ResamplingCV[20]>            cv         8 <PredictionClassif[20]>
##  9: <ResamplingCV[20]>            cv         9 <PredictionClassif[20]>
## 10: <ResamplingCV[20]>            cv        10 <PredictionClassif[20]>
##     classif.ce
##  1:  0.2727273
##  2:  0.2077922
##  3:  0.1948052
##  4:  0.2337662
##  5:  0.3376623
##  6:  0.2077922
##  7:  0.3246753
##  8:  0.2597403
##  9:  0.3552632
## 10:  0.2500000
# and many more getters stored in rr:

rr$warnings
## Empty data.table (0 rows and 2 cols): iteration,msg
rr$errors
## Empty data.table (0 rows and 2 cols): iteration,msg
rr$resampling
## <ResamplingCV>: Cross-Validation
## * Iterations: 10
## * Instantiated: TRUE
## * Parameters: folds=10
# the model trained in a specific iteration
lrn = rr$learners[[1]]
lrn$model
## n= 691 
## 
## node), split, n, loss, yval, (yprob)
##       * denotes terminal node
## 
##   1) root 691 237 neg (0.3429812 0.6570188)  
##     2) glucose>=123.5 290 121 pos (0.5827586 0.4172414)  
##       4) mass>=29.95 205  63 pos (0.6926829 0.3073171)  
##         8) glucose>=154.5 85  12 pos (0.8588235 0.1411765) *
##         9) glucose< 154.5 120  51 pos (0.5750000 0.4250000)  
##          18) mass>=41.65 22   4 pos (0.8181818 0.1818182) *
##          19) mass< 41.65 98  47 pos (0.5204082 0.4795918)  
##            38) mass< 39.95 86  37 pos (0.5697674 0.4302326)  
##              76) glucose>=129.5 56  20 pos (0.6428571 0.3571429)  
##               152) mass>=34.05 23   4 pos (0.8260870 0.1739130) *
##               153) mass< 34.05 33  16 pos (0.5151515 0.4848485)  
##                 306) mass< 31.8 12   3 pos (0.7500000 0.2500000) *
##                 307) mass>=31.8 21   8 neg (0.3809524 0.6190476) *
##              77) glucose< 129.5 30  13 neg (0.4333333 0.5666667)  
##               154) mass< 34.15 18   7 pos (0.6111111 0.3888889) *
##               155) mass>=34.15 12   2 neg (0.1666667 0.8333333) *
##            39) mass>=39.95 12   2 neg (0.1666667 0.8333333) *
##       5) mass< 29.95 85  27 neg (0.3176471 0.6823529)  
##        10) glucose>=160 18   7 pos (0.6111111 0.3888889) *
##        11) glucose< 160 67  16 neg (0.2388060 0.7611940) *
##     3) glucose< 123.5 401  68 neg (0.1695761 0.8304239) *
# all individual predictions
# merged into a single Prediction object
rr$prediction() 
## <PredictionClassif> for 768 observations:
##     row_ids truth response  prob.pos  prob.neg
##          21   neg      neg 0.1666667 0.8333333
##          40   pos      neg 0.1695761 0.8304239
##          46   pos      pos 0.8588235 0.1411765
## ---                                           
##         705   neg      neg 0.1921296 0.8078704
##         750   pos      neg 0.3174603 0.6825397
##         766   neg      neg 0.1921296 0.8078704
# or in a specific resampling iteration
rr$predictions()[[1]] 
## <PredictionClassif> for 77 observations:
##     row_ids truth response  prob.pos  prob.neg
##          21   neg      neg 0.1666667 0.8333333
##          40   pos      neg 0.1695761 0.8304239
##          46   pos      pos 0.8588235 0.1411765
## ---                                           
##         753   neg      neg 0.1695761 0.8304239
##         754   pos      pos 0.8588235 0.1411765
##         762   pos      pos 0.8588235 0.1411765

Plotting Resample Results

mlr3viz provides an autoplot() method for resampling results.

# boxplot of AUC values across the 10 folds
autoplot(rr, measure = msr("classif.auc"))

# ROC curve, averaged over 10 folds
autoplot(rr, type = "roc")

# learner predictions for an individual fold (here: fold 1)
rr$filter(1)
autoplot(rr, type = "prediction")
## Warning: Removed 2 rows containing missing values (geom_point).

3. Model Optimization

In machine learning, when you are dissatisfied with the performance of a model, you might ask yourself how best to improve it. One common answer is to tune the model's hyperparameters.

We will use the mlr3tuning package, which supports the most common tuning operations.

Hyperparameter Tuning

Hyperparameters are the parameters of the learners that control how a model is fit to the data. They are sometimes called second-level or second-order parameters of machine learning; the parameters of the models themselves are the first-order parameters, which are fit to the data during model training.

The TuningInstance* Classes

We will examine the optimization of a simple classification tree on the Pima Indian Diabetes data set as an introductory example here.

library("mlr3verse")
task = tsk("pima")
print(task)
## <TaskClassif:pima> (768 x 9): Pima Indian Diabetes
## * Target: diabetes
## * Properties: twoclass
## * Features (8):
##   - dbl (8): age, glucose, insulin, mass, pedigree, pregnant, pressure,
##     triceps

We use the rpart classification tree and choose a subset of the hyperparameters we want to tune. This is often referred to as the “tuning space”.

learner = lrn("classif.rpart")
learner$param_set
## <ParamSet>
##                 id    class lower upper nlevels        default value
##  1:             cp ParamDbl     0     1     Inf           0.01      
##  2:     keep_model ParamLgl    NA    NA       2          FALSE      
##  3:     maxcompete ParamInt     0   Inf     Inf              4      
##  4:       maxdepth ParamInt     1    30      30             30      
##  5:   maxsurrogate ParamInt     0   Inf     Inf              5      
##  6:      minbucket ParamInt     1   Inf     Inf <NoDefault[3]>      
##  7:       minsplit ParamInt     1   Inf     Inf             20      
##  8: surrogatestyle ParamInt     0     1       2              0      
##  9:   usesurrogate ParamInt     0     2       3              2      
## 10:           xval ParamInt     0   Inf     Inf             10     0

Here, we opt to tune two hyperparameters:

  • the complexity hyperparameter cp
  • the minsplit hyperparameter

Both hyperparameters need to be bounded with lower and upper limits:

search_space = ps(
  cp = p_dbl(lower = 0.001, upper = 0.1),
  minsplit = p_int(lower = 1, upper = 10)
)
search_space
## <ParamSet>
##          id    class lower upper nlevels        default value
## 1:       cp ParamDbl 0.001   0.1     Inf <NoDefault[3]>      
## 2: minsplit ParamInt 1.000  10.0      10 <NoDefault[3]>

Next, we need to specify how to evaluate the performance of a trained model. For this, we need to choose a resampling strategy and a performance measure.

hout = rsmp("holdout")
measure = msr("classif.ce")

Finally, we have to specify the budget available for tuning. mlr3 allows us to specify complex termination criteria by selecting one of the available Terminators (sketched below):

  • TerminatorClockTime: Terminate after a given time.
  • TerminatorEvals: Terminate after a given number of iterations.
  • TerminatorPerfReached: Terminate after a specific performance has been reached.
  • TerminatorStagnation: Terminate when tuning does not find a better configuration for a given number of iterations.
  • TerminatorCombo: A combination of the above in an ALL or ANY fashion.
evals20 = trm("evals", n_evals = 20)
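The other terminators are constructed analogously; a hedged sketch (argument names as assumed from the bbotk package):

trm("clock_time", stop_time = Sys.time() + 60)  # stop at a fixed point in time
trm("perf_reached", level = 0.25)               # stop once the measure reaches 0.25
TerminatorCombo$new(list(                       # stop when ANY criterion fires
  trm("evals", n_evals = 50),                   # (set any = FALSE for ALL)
  trm("perf_reached", level = 0.25)
))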

instance = TuningInstanceSingleCrit$new(
  task = task,
  learner = learner,
  resampling = hout,
  measure = measure,
  search_space = search_space,
  terminator = evals20
)
instance
## <TuningInstanceSingleCrit>
## * State:  Not optimized
## * Objective: <ObjectiveTuning:classif.rpart_on_pima>
## * Search Space:
##          id    class lower upper nlevels
## 1:       cp ParamDbl 0.001   0.1     Inf
## 2: minsplit ParamInt 1.000  10.0      10
## * Terminator: <TerminatorEvals>

In order to start the tuning, we need to select how the optimization should happen (choose the optimization algorithm via the Tuner class).

The Tuner Class

The following algorithms are currently implemented in mlr3tuning:

  • TunerGridSearch: Grid Search
  • TunerRandomSearch: Random Search
  • TunerGenSA: Generalized Simulated Annealing
  • TunerNLoptr: Non-Linear Optimization
tuner = tnr("grid_search", resolution = 5)

Triggering the Tuning

The tuner proceeds as follows:

  1. The Tuner proposes at least one hyperparameter configuration to evaluate.
  2. For each configuration, the given Learner is fitted on the Task and evaluated using the provided Resampling.
  3. The Terminator is queried whether the budget is exhausted; if not, repeat from step 1.
  4. Determine the configuration with the best observed performance from the archive.
  5. Store the best configuration as the result in the tuning instance object.

To start the tuning, we simply pass the TuningInstanceSingleCrit to the $optimize() method of the initialized Tuner:

tuner$optimize(instance)
## INFO  [22:08:46.673] [bbotk] Starting to optimize 2 parameter(s) with '<TunerGridSearch>' and '<TerminatorEvals> [n_evals=20, k=0]' 
## INFO  [22:08:46.677] [bbotk] Evaluating 1 configuration(s) 
## INFO  [22:08:46.689] [mlr3] Running benchmark with 1 resampling iterations 
## INFO  [22:08:46.694] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 1/1) 
## INFO  [22:08:46.714] [mlr3] Finished benchmark 
## INFO  [22:08:46.739] [bbotk] Result of batch 1: 
## INFO  [22:08:46.741] [bbotk]       cp minsplit classif.ce warnings errors runtime_learners 
## INFO  [22:08:46.741] [bbotk]  0.02575        8    0.28125        0      0             0.02 
## INFO  [22:08:46.741] [bbotk]                                 uhash 
## INFO  [22:08:46.741] [bbotk]  c01b7751-febd-44c0-8d39-f9135fadf574 
## INFO  [22:08:46.742] [bbotk] Evaluating 1 configuration(s) 
## INFO  [22:08:46.751] [mlr3] Running benchmark with 1 resampling iterations 
## INFO  [22:08:46.757] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 1/1) 
## INFO  [22:08:46.776] [mlr3] Finished benchmark 
## INFO  [22:08:46.803] [bbotk] Result of batch 2: 
## INFO  [22:08:46.804] [bbotk]   cp minsplit classif.ce warnings errors runtime_learners 
## INFO  [22:08:46.804] [bbotk]  0.1        8    0.28125        0      0             0.01 
## INFO  [22:08:46.804] [bbotk]                                 uhash 
## INFO  [22:08:46.804] [bbotk]  1c066822-8cf0-4603-b2e0-40bd0e95790a 
## INFO  [22:08:46.805] [bbotk] Evaluating 1 configuration(s) 
## INFO  [22:08:46.815] [mlr3] Running benchmark with 1 resampling iterations 
## INFO  [22:08:46.825] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 1/1) 
## INFO  [22:08:46.847] [mlr3] Finished benchmark 
## INFO  [22:08:46.872] [bbotk] Result of batch 3: 
## INFO  [22:08:46.873] [bbotk]      cp minsplit classif.ce warnings errors runtime_learners 
## INFO  [22:08:46.873] [bbotk]  0.0505        3    0.28125        0      0                0 
## INFO  [22:08:46.873] [bbotk]                                 uhash 
## INFO  [22:08:46.873] [bbotk]  b2de11c9-7ebb-46f1-bbde-7630850622d5 
## INFO  [22:08:46.874] [bbotk] Evaluating 1 configuration(s) 
## INFO  [22:08:46.886] [mlr3] Running benchmark with 1 resampling iterations 
## INFO  [22:08:46.891] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 1/1) 
## INFO  [22:08:46.908] [mlr3] Finished benchmark 
## INFO  [22:08:46.932] [bbotk] Result of batch 4: 
## INFO  [22:08:46.933] [bbotk]      cp minsplit classif.ce warnings errors runtime_learners 
## INFO  [22:08:46.933] [bbotk]  0.0505        1    0.28125        0      0                0 
## INFO  [22:08:46.933] [bbotk]                                 uhash 
## INFO  [22:08:46.933] [bbotk]  cf8b2097-7654-4f1e-8fa4-7d3db27aec39 
## INFO  [22:08:46.935] [bbotk] Evaluating 1 configuration(s) 
## INFO  [22:08:46.944] [mlr3] Running benchmark with 1 resampling iterations 
## INFO  [22:08:46.948] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 1/1) 
## INFO  [22:08:46.964] [mlr3] Finished benchmark 
## INFO  [22:08:46.999] [bbotk] Result of batch 5: 
## INFO  [22:08:47.001] [bbotk]       cp minsplit classif.ce warnings errors runtime_learners 
## INFO  [22:08:47.001] [bbotk]  0.02575        3    0.28125        0      0                0 
## INFO  [22:08:47.001] [bbotk]                                 uhash 
## INFO  [22:08:47.001] [bbotk]  35988734-d4b4-46a7-858e-12750e2a8092 
## INFO  [22:08:47.003] [bbotk] Evaluating 1 configuration(s) 
## INFO  [22:08:47.015] [mlr3] Running benchmark with 1 resampling iterations 
## INFO  [22:08:47.020] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 1/1) 
## INFO  [22:08:47.041] [mlr3] Finished benchmark 
## INFO  [22:08:47.079] [bbotk] Result of batch 6: 
## INFO  [22:08:47.081] [bbotk]     cp minsplit classif.ce warnings errors runtime_learners 
## INFO  [22:08:47.081] [bbotk]  0.001        1  0.2890625        0      0             0.02 
## INFO  [22:08:47.081] [bbotk]                                 uhash 
## INFO  [22:08:47.081] [bbotk]  29f2c9c1-c491-4511-86a0-2373e6fd56d9 
## INFO  [22:08:47.083] [bbotk] Evaluating 1 configuration(s) 
## INFO  [22:08:47.094] [mlr3] Running benchmark with 1 resampling iterations 
## INFO  [22:08:47.100] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 1/1) 
## INFO  [22:08:47.118] [mlr3] Finished benchmark 
## INFO  [22:08:47.142] [bbotk] Result of batch 7: 
## INFO  [22:08:47.143] [bbotk]      cp minsplit classif.ce warnings errors runtime_learners 
## INFO  [22:08:47.143] [bbotk]  0.0505       10    0.28125        0      0             0.02 
## INFO  [22:08:47.143] [bbotk]                                 uhash 
## INFO  [22:08:47.143] [bbotk]  df71699d-54b9-4fb6-86e9-05d6b4dd0811 
## INFO  [22:08:47.144] [bbotk] Evaluating 1 configuration(s) 
## INFO  [22:08:47.154] [mlr3] Running benchmark with 1 resampling iterations 
## INFO  [22:08:47.158] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 1/1) 
## INFO  [22:08:47.175] [mlr3] Finished benchmark 
## INFO  [22:08:47.200] [bbotk] Result of batch 8: 
## INFO  [22:08:47.202] [bbotk]      cp minsplit classif.ce warnings errors runtime_learners 
## INFO  [22:08:47.202] [bbotk]  0.0505        8    0.28125        0      0                0 
## INFO  [22:08:47.202] [bbotk]                                 uhash 
## INFO  [22:08:47.202] [bbotk]  866067f1-a2c4-4f11-bbd6-4267d4e1de6e 
## INFO  [22:08:47.203] [bbotk] Evaluating 1 configuration(s) 
## INFO  [22:08:47.212] [mlr3] Running benchmark with 1 resampling iterations 
## INFO  [22:08:47.217] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 1/1) 
## INFO  [22:08:47.233] [mlr3] Finished benchmark 
## INFO  [22:08:47.256] [bbotk] Result of batch 9: 
## INFO  [22:08:47.257] [bbotk]      cp minsplit classif.ce warnings errors runtime_learners 
## INFO  [22:08:47.257] [bbotk]  0.0505        5    0.28125        0      0             0.01 
## INFO  [22:08:47.257] [bbotk]                                 uhash 
## INFO  [22:08:47.257] [bbotk]  8570ce12-777a-41c4-b84d-f096c3a628e6 
## INFO  [22:08:47.258] [bbotk] Evaluating 1 configuration(s) 
## INFO  [22:08:47.267] [mlr3] Running benchmark with 1 resampling iterations 
## INFO  [22:08:47.272] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 1/1) 
## INFO  [22:08:47.291] [mlr3] Finished benchmark 
## INFO  [22:08:47.316] [bbotk] Result of batch 10: 
## INFO  [22:08:47.317] [bbotk]       cp minsplit classif.ce warnings errors runtime_learners 
## INFO  [22:08:47.317] [bbotk]  0.07525        8    0.28125        0      0             0.02 
## INFO  [22:08:47.317] [bbotk]                                 uhash 
## INFO  [22:08:47.317] [bbotk]  cde18441-f94b-4eb1-bd93-e417427c918c 
## INFO  [22:08:47.318] [bbotk] Evaluating 1 configuration(s) 
## INFO  [22:08:47.327] [mlr3] Running benchmark with 1 resampling iterations 
## INFO  [22:08:47.331] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 1/1) 
## INFO  [22:08:47.348] [mlr3] Finished benchmark 
## INFO  [22:08:47.375] [bbotk] Result of batch 11: 
## INFO  [22:08:47.376] [bbotk]   cp minsplit classif.ce warnings errors runtime_learners 
## INFO  [22:08:47.376] [bbotk]  0.1        5    0.28125        0      0                0 
## INFO  [22:08:47.376] [bbotk]                                 uhash 
## INFO  [22:08:47.376] [bbotk]  0c73b1ef-0b7a-4b10-8147-a618496950ed 
## INFO  [22:08:47.377] [bbotk] Evaluating 1 configuration(s) 
## INFO  [22:08:47.387] [mlr3] Running benchmark with 1 resampling iterations 
## INFO  [22:08:47.390] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 1/1) 
## INFO  [22:08:47.409] [mlr3] Finished benchmark 
## INFO  [22:08:47.437] [bbotk] Result of batch 12: 
## INFO  [22:08:47.438] [bbotk]     cp minsplit classif.ce warnings errors runtime_learners 
## INFO  [22:08:47.438] [bbotk]  0.001        8  0.2695312        0      0                0 
## INFO  [22:08:47.438] [bbotk]                                 uhash 
## INFO  [22:08:47.438] [bbotk]  68ef7cae-b18d-4320-a44d-70080d077b7a 
## INFO  [22:08:47.439] [bbotk] Evaluating 1 configuration(s) 
## INFO  [22:08:47.450] [mlr3] Running benchmark with 1 resampling iterations 
## INFO  [22:08:47.454] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 1/1) 
## INFO  [22:08:47.484] [mlr3] Finished benchmark 
## INFO  [22:08:47.516] [bbotk] Result of batch 13: 
## INFO  [22:08:47.517] [bbotk]       cp minsplit classif.ce warnings errors runtime_learners 
## INFO  [22:08:47.517] [bbotk]  0.07525       10    0.28125        0      0             0.03 
## INFO  [22:08:47.517] [bbotk]                                 uhash 
## INFO  [22:08:47.517] [bbotk]  29a1c066-5c74-4264-8cf3-97625f4af5d4 
## INFO  [22:08:47.518] [bbotk] Evaluating 1 configuration(s) 
## INFO  [22:08:47.527] [mlr3] Running benchmark with 1 resampling iterations 
## INFO  [22:08:47.532] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 1/1) 
## INFO  [22:08:47.547] [mlr3] Finished benchmark 
## INFO  [22:08:47.571] [bbotk] Result of batch 14: 
## INFO  [22:08:47.573] [bbotk]       cp minsplit classif.ce warnings errors runtime_learners 
## INFO  [22:08:47.573] [bbotk]  0.07525        5    0.28125        0      0             0.01 
## INFO  [22:08:47.573] [bbotk]                                 uhash 
## INFO  [22:08:47.573] [bbotk]  3703156b-b43a-48bd-84aa-88e21c73bb92 
## INFO  [22:08:47.574] [bbotk] Evaluating 1 configuration(s) 
## INFO  [22:08:47.586] [mlr3] Running benchmark with 1 resampling iterations 
## INFO  [22:08:47.593] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 1/1) 
## INFO  [22:08:47.612] [mlr3] Finished benchmark 
## INFO  [22:08:47.635] [bbotk] Result of batch 15: 
## INFO  [22:08:47.636] [bbotk]   cp minsplit classif.ce warnings errors runtime_learners 
## INFO  [22:08:47.636] [bbotk]  0.1       10    0.28125        0      0                0 
## INFO  [22:08:47.636] [bbotk]                                 uhash 
## INFO  [22:08:47.636] [bbotk]  77407ca2-47db-4c31-91d8-4e4be3e030d4 
## INFO  [22:08:47.638] [bbotk] Evaluating 1 configuration(s) 
## INFO  [22:08:47.646] [mlr3] Running benchmark with 1 resampling iterations 
## INFO  [22:08:47.652] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 1/1) 
## INFO  [22:08:47.670] [mlr3] Finished benchmark 
## INFO  [22:08:47.692] [bbotk] Result of batch 16: 
## INFO  [22:08:47.694] [bbotk]       cp minsplit classif.ce warnings errors runtime_learners 
## INFO  [22:08:47.694] [bbotk]  0.07525        3    0.28125        0      0             0.02 
## INFO  [22:08:47.694] [bbotk]                                 uhash 
## INFO  [22:08:47.694] [bbotk]  d5328765-b394-454c-adad-e8d35fa15197 
## INFO  [22:08:47.695] [bbotk] Evaluating 1 configuration(s) 
## INFO  [22:08:47.704] [mlr3] Running benchmark with 1 resampling iterations 
## INFO  [22:08:47.709] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 1/1) 
## INFO  [22:08:47.729] [mlr3] Finished benchmark 
## INFO  [22:08:47.753] [bbotk] Result of batch 17: 
## INFO  [22:08:47.755] [bbotk]     cp minsplit classif.ce warnings errors runtime_learners 
## INFO  [22:08:47.755] [bbotk]  0.001       10  0.2695312        0      0             0.02 
## INFO  [22:08:47.755] [bbotk]                                 uhash 
## INFO  [22:08:47.755] [bbotk]  48a21756-50c5-4963-b5a9-07ef39819964 
## INFO  [22:08:47.756] [bbotk] Evaluating 1 configuration(s) 
## INFO  [22:08:47.765] [mlr3] Running benchmark with 1 resampling iterations 
## INFO  [22:08:47.769] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 1/1) 
## INFO  [22:08:47.786] [mlr3] Finished benchmark 
## INFO  [22:08:47.816] [bbotk] Result of batch 18: 
## INFO  [22:08:47.818] [bbotk]     cp minsplit classif.ce warnings errors runtime_learners 
## INFO  [22:08:47.818] [bbotk]  0.001        3    0.28125        0      0                0 
## INFO  [22:08:47.818] [bbotk]                                 uhash 
## INFO  [22:08:47.818] [bbotk]  e0de4d3c-cc87-4f23-9aee-7859ee57e319 
## INFO  [22:08:47.820] [bbotk] Evaluating 1 configuration(s) 
## INFO  [22:08:47.831] [mlr3] Running benchmark with 1 resampling iterations 
## INFO  [22:08:47.836] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 1/1) 
## INFO  [22:08:47.852] [mlr3] Finished benchmark 
## INFO  [22:08:47.890] [bbotk] Result of batch 19: 
## INFO  [22:08:47.892] [bbotk]   cp minsplit classif.ce warnings errors runtime_learners 
## INFO  [22:08:47.892] [bbotk]  0.1        1    0.28125        0      0             0.01 
## INFO  [22:08:47.892] [bbotk]                                 uhash 
## INFO  [22:08:47.892] [bbotk]  e6a1322c-1140-42ad-902e-12b7d5930dfe 
## INFO  [22:08:47.893] [bbotk] Evaluating 1 configuration(s) 
## INFO  [22:08:47.902] [mlr3] Running benchmark with 1 resampling iterations 
## INFO  [22:08:47.906] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 1/1) 
## INFO  [22:08:47.923] [mlr3] Finished benchmark 
## INFO  [22:08:47.950] [bbotk] Result of batch 20: 
## INFO  [22:08:47.951] [bbotk]     cp minsplit classif.ce warnings errors runtime_learners 
## INFO  [22:08:47.951] [bbotk]  0.001        5   0.265625        0      0             0.02 
## INFO  [22:08:47.951] [bbotk]                                 uhash 
## INFO  [22:08:47.951] [bbotk]  1e88c5da-7ef2-47e3-8f25-05c420a6e896 
## INFO  [22:08:47.956] [bbotk] Finished optimizing after 20 evaluation(s) 
## INFO  [22:08:47.957] [bbotk] Result: 
## INFO  [22:08:47.958] [bbotk]     cp minsplit learner_param_vals  x_domain classif.ce 
## INFO  [22:08:47.958] [bbotk]  0.001        5          <list[3]> <list[2]>   0.265625
##       cp minsplit learner_param_vals  x_domain classif.ce
## 1: 0.001        5          <list[3]> <list[2]>   0.265625
# The best hyperparameter settings
instance$result_learner_param_vals
## $xval
## [1] 0
## 
## $cp
## [1] 0.001
## 
## $minsplit
## [1] 5
# Corresponding measured performance 
instance$result_y
## classif.ce 
##   0.265625
# You can investigate all of the evaluations that were performed
as.data.table(instance$archive)
##          cp minsplit classif.ce x_domain_cp x_domain_minsplit runtime_learners
##  1: 0.02575        8  0.2812500     0.02575                 8             0.02
##  2: 0.10000        8  0.2812500     0.10000                 8             0.01
##  3: 0.05050        3  0.2812500     0.05050                 3             0.00
##  4: 0.05050        1  0.2812500     0.05050                 1             0.00
##  5: 0.02575        3  0.2812500     0.02575                 3             0.00
##  6: 0.00100        1  0.2890625     0.00100                 1             0.02
##  7: 0.05050       10  0.2812500     0.05050                10             0.02
##  8: 0.05050        8  0.2812500     0.05050                 8             0.00
##  9: 0.05050        5  0.2812500     0.05050                 5             0.01
## 10: 0.07525        8  0.2812500     0.07525                 8             0.02
## 11: 0.10000        5  0.2812500     0.10000                 5             0.00
## 12: 0.00100        8  0.2695312     0.00100                 8             0.00
## 13: 0.07525       10  0.2812500     0.07525                10             0.03
## 14: 0.07525        5  0.2812500     0.07525                 5             0.01
## 15: 0.10000       10  0.2812500     0.10000                10             0.00
## 16: 0.07525        3  0.2812500     0.07525                 3             0.02
## 17: 0.00100       10  0.2695312     0.00100                10             0.02
## 18: 0.00100        3  0.2812500     0.00100                 3             0.00
## 19: 0.10000        1  0.2812500     0.10000                 1             0.01
## 20: 0.00100        5  0.2656250     0.00100                 5             0.02
##               timestamp batch_nr warnings errors      resample_result
##  1: 2022-04-09 22:08:46        1        0      0 <ResampleResult[22]>
##  2: 2022-04-09 22:08:46        2        0      0 <ResampleResult[22]>
##  3: 2022-04-09 22:08:46        3        0      0 <ResampleResult[22]>
##  4: 2022-04-09 22:08:46        4        0      0 <ResampleResult[22]>
##  5: 2022-04-09 22:08:46        5        0      0 <ResampleResult[22]>
##  6: 2022-04-09 22:08:47        6        0      0 <ResampleResult[22]>
##  7: 2022-04-09 22:08:47        7        0      0 <ResampleResult[22]>
##  8: 2022-04-09 22:08:47        8        0      0 <ResampleResult[22]>
##  9: 2022-04-09 22:08:47        9        0      0 <ResampleResult[22]>
## 10: 2022-04-09 22:08:47       10        0      0 <ResampleResult[22]>
## 11: 2022-04-09 22:08:47       11        0      0 <ResampleResult[22]>
## 12: 2022-04-09 22:08:47       12        0      0 <ResampleResult[22]>
## 13: 2022-04-09 22:08:47       13        0      0 <ResampleResult[22]>
## 14: 2022-04-09 22:08:47       14        0      0 <ResampleResult[22]>
## 15: 2022-04-09 22:08:47       15        0      0 <ResampleResult[22]>
## 16: 2022-04-09 22:08:47       16        0      0 <ResampleResult[22]>
## 17: 2022-04-09 22:08:47       17        0      0 <ResampleResult[22]>
## 18: 2022-04-09 22:08:47       18        0      0 <ResampleResult[22]>
## 19: 2022-04-09 22:08:47       19        0      0 <ResampleResult[22]>
## 20: 2022-04-09 22:08:47       20        0      0 <ResampleResult[22]>

Now we can take the optimized hyperparameters, set them for the previously-created Learner, and train it on the full dataset.

learner$param_set$values = instance$result_learner_param_vals
learner$train(task)

Tuning with Multiple Performance Measures

When tuning, you might want to use multiple criteria to find the best configuration of hyperparameters. The tuning process is identical to the previous example, with the exception that this time we specify two performance measures, classification error (classif.ce) and time to train the model (time_train), and instead of creating a TuningInstanceSingleCrit, we create a TuningInstanceMultiCrit with the two measures we are interested in.

measures = msrs(c("classif.ce", "time_train"))


evals20 = trm("evals", n_evals = 20)

instance = TuningInstanceMultiCrit$new(
  task = task,
  learner = learner,
  resampling = hout,
  measures = measures,
  search_space = search_space,
  terminator = evals20
)
instance
## <TuningInstanceMultiCrit>
## * State:  Not optimized
## * Objective: <ObjectiveTuning:classif.rpart_on_pima>
## * Search Space:
##          id    class lower upper nlevels
## 1:       cp ParamDbl 0.001   0.1     Inf
## 2: minsplit ParamInt 1.000  10.0      10
## * Terminator: <TerminatorEvals>
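
As a hedged sketch, the multi-criteria instance could then be optimized the same way as before by passing it to a Tuner; afterwards, instance$result holds the set of Pareto-optimal configurations rather than a single best one.

tuner = tnr("random_search")
tuner$optimize(instance)

# Pareto-optimal configurations w.r.t. classif.ce and time_train
instance$result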

Automating the Tuning

This whole process can be automated in mlr3 so that learners are tuned transparently, without the need to extract information on the best hyperparameter settings at the end. The AutoTuner wraps a learner and augments it with an automatic tuning process for a given set of hyperparameters.

learner = lrn("classif.rpart")
search_space = ps(
  cp = p_dbl(lower = 0.001, upper = 0.1),
  minsplit = p_int(lower = 1, upper = 10)
)
terminator = trm("evals", n_evals = 10)
tuner = tnr("random_search")

at = AutoTuner$new(
  learner = learner,
  resampling = rsmp("holdout"),
  measure = msr("classif.ce"),
  search_space = search_space,
  terminator = terminator,
  tuner = tuner
)
at
## <AutoTuner:classif.rpart.tuned>
## * Model: -
## * Search Space:
## <ParamSet>
##          id    class lower upper nlevels        default value
## 1:       cp ParamDbl 0.001   0.1     Inf <NoDefault[3]>      
## 2: minsplit ParamInt 1.000  10.0      10 <NoDefault[3]>      
## * Packages: mlr3, mlr3tuning, rpart
## * Predict Type: response
## * Feature Types: logical, integer, numeric, factor, ordered
## * Properties: importance, missings, multiclass, selected_features,
##   twoclass, weights

We can now use the learner like any other learner, calling the $train() and $predict() methods.

at$train(task)
## INFO  [22:08:48.175] [bbotk] Starting to optimize 2 parameter(s) with '<OptimizerRandomSearch>' and '<TerminatorEvals> [n_evals=10, k=0]' 
## INFO  [22:08:48.189] [bbotk] Evaluating 1 configuration(s) 
## INFO  [22:08:48.199] [mlr3] Running benchmark with 1 resampling iterations 
## INFO  [22:08:48.204] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 1/1) 
## INFO  [22:08:48.221] [mlr3] Finished benchmark 
## INFO  [22:08:48.246] [bbotk] Result of batch 1: 
## INFO  [22:08:48.248] [bbotk]          cp minsplit classif.ce warnings errors runtime_learners 
## INFO  [22:08:48.248] [bbotk]  0.05141753       10  0.2734375        0      0                0 
## INFO  [22:08:48.248] [bbotk]                                 uhash 
## INFO  [22:08:48.248] [bbotk]  0613da77-0978-4cf9-b3cd-506a65659632 
## INFO  [22:08:48.251] [bbotk] Evaluating 1 configuration(s) 
## INFO  [22:08:48.261] [mlr3] Running benchmark with 1 resampling iterations 
## INFO  [22:08:48.266] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 1/1) 
## INFO  [22:08:48.282] [mlr3] Finished benchmark 
## INFO  [22:08:48.308] [bbotk] Result of batch 2: 
## INFO  [22:08:48.310] [bbotk]        cp minsplit classif.ce warnings errors runtime_learners 
## INFO  [22:08:48.310] [bbotk]  0.032403        9  0.2734375        0      0                0 
## INFO  [22:08:48.310] [bbotk]                                 uhash 
## INFO  [22:08:48.310] [bbotk]  3421329e-b977-4580-a440-d3b137b8ced0 
## INFO  [22:08:48.313] [bbotk] Evaluating 1 configuration(s) 
## INFO  [22:08:48.325] [mlr3] Running benchmark with 1 resampling iterations 
## INFO  [22:08:48.329] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 1/1) 
## INFO  [22:08:48.348] [mlr3] Finished benchmark 
## INFO  [22:08:48.373] [bbotk] Result of batch 3: 
## INFO  [22:08:48.374] [bbotk]          cp minsplit classif.ce warnings errors runtime_learners 
## INFO  [22:08:48.374] [bbotk]  0.03644441        6  0.2734375        0      0                0 
## INFO  [22:08:48.374] [bbotk]                                 uhash 
## INFO  [22:08:48.374] [bbotk]  1ee461ea-f475-4883-8498-1fd50234e642 
## INFO  [22:08:48.377] [bbotk] Evaluating 1 configuration(s) 
## INFO  [22:08:48.395] [mlr3] Running benchmark with 1 resampling iterations 
## INFO  [22:08:48.402] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 1/1) 
## INFO  [22:08:48.421] [mlr3] Finished benchmark 
## INFO  [22:08:48.444] [bbotk] Result of batch 4: 
## INFO  [22:08:48.445] [bbotk]          cp minsplit classif.ce warnings errors runtime_learners 
## INFO  [22:08:48.445] [bbotk]  0.08363938       10  0.2734375        0      0             0.02 
## INFO  [22:08:48.445] [bbotk]                                 uhash 
## INFO  [22:08:48.445] [bbotk]  ec9bc184-d8c9-4d0e-b933-08a7c1deb3c3 
## INFO  [22:08:48.448] [bbotk] Evaluating 1 configuration(s) 
## INFO  [22:08:48.458] [mlr3] Running benchmark with 1 resampling iterations 
## INFO  [22:08:48.462] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 1/1) 
## INFO  [22:08:48.480] [mlr3] Finished benchmark 
## INFO  [22:08:48.504] [bbotk] Result of batch 5: 
## INFO  [22:08:48.505] [bbotk]           cp minsplit classif.ce warnings errors runtime_learners 
## INFO  [22:08:48.505] [bbotk]  0.003366969        4  0.3046875        0      0             0.02 
## INFO  [22:08:48.505] [bbotk]                                 uhash 
## INFO  [22:08:48.505] [bbotk]  ce726596-11d0-4460-8f36-d0f278fe7a08 
## INFO  [22:08:48.507] [bbotk] Evaluating 1 configuration(s) 
## INFO  [22:08:48.517] [mlr3] Running benchmark with 1 resampling iterations 
## INFO  [22:08:48.521] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 1/1) 
## INFO  [22:08:48.536] [mlr3] Finished benchmark 
## INFO  [22:08:48.560] [bbotk] Result of batch 6: 
## INFO  [22:08:48.562] [bbotk]          cp minsplit classif.ce warnings errors runtime_learners 
## INFO  [22:08:48.562] [bbotk]  0.09247955        3  0.3046875        0      0                0 
## INFO  [22:08:48.562] [bbotk]                                 uhash 
## INFO  [22:08:48.562] [bbotk]  f442c1e1-7dd3-428f-8b25-74cde09deaeb 
## INFO  [22:08:48.564] [bbotk] Evaluating 1 configuration(s) 
## INFO  [22:08:48.574] [mlr3] Running benchmark with 1 resampling iterations 
## INFO  [22:08:48.581] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 1/1) 
## INFO  [22:08:48.597] [mlr3] Finished benchmark 
## INFO  [22:08:48.624] [bbotk] Result of batch 7: 
## INFO  [22:08:48.626] [bbotk]          cp minsplit classif.ce warnings errors runtime_learners 
## INFO  [22:08:48.626] [bbotk]  0.04844731        1  0.2734375        0      0                0 
## INFO  [22:08:48.626] [bbotk]                                 uhash 
## INFO  [22:08:48.626] [bbotk]  f5b104c0-96a5-448c-8622-a669cbbe8d2c 
## INFO  [22:08:48.628] [bbotk] Evaluating 1 configuration(s) 
## INFO  [22:08:48.639] [mlr3] Running benchmark with 1 resampling iterations 
## INFO  [22:08:48.643] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 1/1) 
## INFO  [22:08:48.662] [mlr3] Finished benchmark 
## INFO  [22:08:48.686] [bbotk] Result of batch 8: 
## INFO  [22:08:48.687] [bbotk]           cp minsplit classif.ce warnings errors runtime_learners 
## INFO  [22:08:48.687] [bbotk]  0.001663224       10  0.2578125        0      0                0 
## INFO  [22:08:48.687] [bbotk]                                 uhash 
## INFO  [22:08:48.687] [bbotk]  eefc555d-f204-4592-b0fb-c02db2641ba2 
## INFO  [22:08:48.690] [bbotk] Evaluating 1 configuration(s) 
## INFO  [22:08:48.703] [mlr3] Running benchmark with 1 resampling iterations 
## INFO  [22:08:48.711] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 1/1) 
## INFO  [22:08:48.732] [mlr3] Finished benchmark 
## INFO  [22:08:48.760] [bbotk] Result of batch 9: 
## INFO  [22:08:48.762] [bbotk]           cp minsplit classif.ce warnings errors runtime_learners 
## INFO  [22:08:48.762] [bbotk]  0.005261892        5  0.2695312        0      0             0.02 
## INFO  [22:08:48.762] [bbotk]                                 uhash 
## INFO  [22:08:48.762] [bbotk]  fa9d2d9c-efa7-400d-93f3-4f16e1759dc7 
## INFO  [22:08:48.764] [bbotk] Evaluating 1 configuration(s) 
## INFO  [22:08:48.773] [mlr3] Running benchmark with 1 resampling iterations 
## INFO  [22:08:48.778] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 1/1) 
## INFO  [22:08:48.797] [mlr3] Finished benchmark 
## INFO  [22:08:48.821] [bbotk] Result of batch 10: 
## INFO  [22:08:48.823] [bbotk]         cp minsplit classif.ce warnings errors runtime_learners 
## INFO  [22:08:48.823] [bbotk]  0.0542137       10  0.2734375        0      0             0.01 
## INFO  [22:08:48.823] [bbotk]                                 uhash 
## INFO  [22:08:48.823] [bbotk]  4bc71a24-dbeb-4428-86e9-279aa3cdf345 
## INFO  [22:08:48.828] [bbotk] Finished optimizing after 10 evaluation(s) 
## INFO  [22:08:48.829] [bbotk] Result: 
## INFO  [22:08:48.830] [bbotk]           cp minsplit learner_param_vals  x_domain classif.ce 
## INFO  [22:08:48.830] [bbotk]  0.001663224       10          <list[3]> <list[2]>  0.2578125
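
After training, the result of the internal tuning can be inspected, e.g. through the tuning_result field (a sketch; the concrete values depend on the random search):

# best hyperparameter configuration found during $train()
at$tuning_result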

Tuning Search Spaces

When running an optimization, it is important to inform the tuning algorithm about what hyperparameters are valid.

Creating ParamSets

Note that ParamSet objects exist in two contexts. First, ParamSet objects are used to define the space of valid parameter settings for a learner (and other objects). Second, they are used to define a search space for tuning.
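
The first context can be inspected directly, since every Learner ships with a ParamSet describing its valid hyperparameter settings:

# ParamSet of valid settings for the rpart learner (first context)
lrn("classif.rpart")$param_set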

ps() takes named Domain arguments that are turned into parameters. A possible search space for the “classif.svm” learner could, for example, be:

search_space = ps(
  cost = p_dbl(lower = 0.1, upper = 10),
  kernel = p_fct(levels = c("polynomial", "radial"))
)
print(search_space)
## <ParamSet>
##        id    class lower upper nlevels        default value
## 1:   cost ParamDbl   0.1    10     Inf <NoDefault[3]>      
## 2: kernel ParamFct    NA    NA       2 <NoDefault[3]>

There are five domain constructors that produce parameters when given to ps():

  • p_dbl: Real valued parameter (“double”)
  • p_int: Integer parameter
  • p_fct: Discrete valued parameter (“factor”)
  • p_lgl: Logical / Boolean parameter
  • p_uty: Untyped parameter

These domain constructors each take some of the following arguments:

  • lower, upper: lower and upper bound of numerical parameters.
  • levels: Allowed categorical values for p_fct parameters.
  • trafo: transformation function.
  • depends: dependencies.
  • tags: Further information about a parameter.
  • default: Value corresponding to default behavior when the parameter is not given.
  • special_vals: Valid values besides the normally accepted values for a parameter.
  • custom_check: Function that checks whether a value given to p_uty is valid.
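# shorthand: the same search space, with lower/upper and levels passed positionally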
search_space = ps(cost = p_dbl(0.1, 10), kernel = p_fct(c("polynomial", "radial")))

Transformations (trafo)

library("data.table")
rbindlist(generate_design_grid(search_space, 3)$transpose())
##     cost     kernel
## 1:  0.10 polynomial
## 2:  0.10     radial
## 3:  5.05 polynomial
## 4:  5.05     radial
## 5: 10.00 polynomial
## 6: 10.00     radial

We see that the cost parameter is taken on a linear scale. However, we assume that a change of cost between 0.1 and 1 should have a similar effect as a change between 1 and 10, so it makes more sense to tune it on a logarithmic scale. This can be done by using a transformation (trafo).

search_space = ps(
  cost = p_dbl(-1, 1, trafo = function(x) 10^x),
  kernel = p_fct(c("polynomial", "radial"))
)
rbindlist(generate_design_grid(search_space, 3)$transpose())
##    cost     kernel
## 1:  0.1 polynomial
## 2:  0.1     radial
## 3:  1.0 polynomial
## 4:  1.0     radial
## 5: 10.0 polynomial
## 6: 10.0     radial

It is even possible to attach another transformation to the ParamSet as a whole that gets executed after the individual parameters' transformations have been performed (.extra_trafo).

search_space = ps(
  cost = p_dbl(-1, 1, trafo = function(x) 10^x),
  kernel = p_fct(c("polynomial", "radial")),
  .extra_trafo = function(x, param_set) {
    if (x$kernel == "polynomial") {
      x$cost = x$cost * 2
    }
    x
  }
)
rbindlist(generate_design_grid(search_space, 3)$transpose())
##    cost     kernel
## 1:  0.2 polynomial
## 2:  0.1     radial
## 3:  2.0 polynomial
## 4:  1.0     radial
## 5: 20.0 polynomial
## 6: 10.0     radial

Automatic Factor Level Transformation

A common use case is the need to specify a list of values that should all be tried (or sampled from). For this, the p_fct constructor's levels argument may be a vector that is not a character vector. If, for example, only the values 0.1, 3, and 10 should be tried for the cost parameter (even when doing random search), p_fct handles the conversion to and from character levels automatically:

search_space = ps(
  cost = p_fct(c(0.1, 3, 10)),
  kernel = p_fct(c("polynomial", "radial"))
)
rbindlist(generate_design_grid(search_space, 3)$transpose())
##    cost     kernel
## 1:  0.1 polynomial
## 2:  0.1     radial
## 3:  3.0 polynomial
## 4:  3.0     radial
## 5: 10.0 polynomial
## 6: 10.0     radial

Parameter Dependencies (depends)

Some parameters are only relevant when another parameter has a certain value, or one of several values. The Support Vector Machine (SVM), for example, has a degree parameter that is only valid when kernel is “polynomial”. This can be specified as follows:

search_space = ps(
  cost = p_dbl(-1, 1, trafo = function(x) 10^x),
  kernel = p_fct(c("polynomial", "radial")),
  degree = p_int(1, 3, depends = kernel == "polynomial")
)
rbindlist(generate_design_grid(search_space, 3)$transpose(), fill = TRUE)
##     cost     kernel degree
##  1:  0.1 polynomial      1
##  2:  0.1 polynomial      2
##  3:  0.1 polynomial      3
##  4:  0.1     radial     NA
##  5:  1.0 polynomial      1
##  6:  1.0 polynomial      2
##  7:  1.0 polynomial      3
##  8:  1.0     radial     NA
##  9: 10.0 polynomial      1
## 10: 10.0 polynomial      2
## 11: 10.0 polynomial      3
## 12: 10.0     radial     NA

Creating Tuning ParamSets from other ParamSets

There is a way to define a tuning ParamSet for a Learner that already has parameter set information. This is done by setting values of the Learner's ParamSet to so-called TuneTokens, in the same way that other hyperparameters are set to specific values. It can be understood as tagging the hyperparameters for later tuning.
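
A minimal sketch, assuming the to_tune() helper from paradox: assigning a TuneToken instead of a concrete value tags the hyperparameter for tuning, and the tuning ParamSet can then be derived from the learner itself.

learner = lrn("classif.rpart")
# tag cp for tuning over [0.001, 0.1] instead of fixing it to a value
learner$param_set$values$cp = to_tune(0.001, 0.1)
# extract the resulting search space from the tagged ParamSet
search_space = learner$param_set$search_space()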

Nested Resampling

It is crucial to add an additional layer of resampling when we perform hyperparameter tuning and/or feature selection for our model. Using the same data for both model selection and evaluation can severely bias the performance estimate, for instance because the test data leaks information about its structure into the model.

Nested resampling is a statistical procedure that separates the model selection steps from the process of estimating the performance of the model.

Fig. 3: Nested resampling scheme with 3-fold cross-validation in the outer resampling and 4-fold cross-validation in the inner resampling

The nested resampling process shown above works as follows:

  1. Use a 3-fold cross-validation to get different testing and training data sets (outer resampling).
  2. Within the training data use a 4-fold cross-validation to get different inner testing and training data sets (inner resampling).
  3. Tune the hyperparameters using the inner data splits.
  4. Fit the learner on the outer training data set using the tuned hyperparameter configuration obtained with the inner resampling.
  5. Evaluate the performance of the learner on the outer testing data.
  6. Steps 2-5 are repeated for each of the three outer folds (outer resampling).
  7. The three performance values are aggregated for an unbiased performance estimate.

Execution

To execute the above algorithm, we will use the AutoTuner to wrap a learner and augment it with an automatic tuning process for a given set of hyperparameters.

learner = lrn("classif.rpart")

# 4-fold cross-validation in the inner resampling loop
resampling = rsmp("cv", folds = 4)

measure = msr("classif.ce")

# hyperparameter configurations are proposed by grid search
search_space = ps(cp = p_dbl(lower = 0.001, upper = 0.1))

# terminator triggered after 5 evaluations
terminator = trm("evals", n_evals = 5)

tuner = tnr("grid_search", resolution = 10)

at = AutoTuner$new(learner, resampling, measure, terminator, tuner, search_space)

Hyperparameter tuning is performed on each of the three outer training sets, yielding three optimized hyperparameter configurations. To execute the nested resampling, we pass the AutoTuner to the resample() function.

task = tsk("pima")

# 3-fold cross-validation in the outer resampling loop
outer_resampling = rsmp("cv", folds = 3)

# store_models = TRUE because we need the AutoTuner models to investigate the inner tuning
rr = resample(task, at, outer_resampling, store_models = TRUE)
## INFO  [22:08:49.108] [mlr3] Applying learner 'classif.rpart.tuned' on task 'pima' (iter 1/3) 
## INFO  [22:08:49.151] [bbotk] Starting to optimize 1 parameter(s) with '<TunerGridSearch>' and '<TerminatorEvals> [n_evals=5, k=0]' 
## INFO  [22:08:49.154] [bbotk] Evaluating 1 configuration(s) 
## INFO  [22:08:49.165] [mlr3] Running benchmark with 4 resampling iterations 
## INFO  [22:08:49.170] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 4/4) 
## INFO  [22:08:49.188] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 1/4) 
## INFO  [22:08:49.205] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 3/4) 
## INFO  [22:08:49.223] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 2/4) 
## INFO  [22:08:49.242] [mlr3] Finished benchmark 
## INFO  [22:08:49.267] [bbotk] Result of batch 1: 
## INFO  [22:08:49.268] [bbotk]     cp classif.ce warnings errors runtime_learners 
## INFO  [22:08:49.268] [bbotk]  0.089  0.2617188        0      0             0.02 
## INFO  [22:08:49.268] [bbotk]                                 uhash 
## INFO  [22:08:49.268] [bbotk]  16649172-e4dc-4708-b395-67b0c37ebf27 
## INFO  [22:08:49.269] [bbotk] Evaluating 1 configuration(s) 
## INFO  [22:08:49.278] [mlr3] Running benchmark with 4 resampling iterations 
## INFO  [22:08:49.283] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 3/4) 
## INFO  [22:08:49.300] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 1/4) 
## INFO  [22:08:49.318] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 4/4) 
## INFO  [22:08:49.335] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 2/4) 
## INFO  [22:08:49.352] [mlr3] Finished benchmark 
## INFO  [22:08:49.390] [bbotk] Result of batch 2: 
## INFO  [22:08:49.392] [bbotk]     cp classif.ce warnings errors runtime_learners 
## INFO  [22:08:49.392] [bbotk]  0.034  0.2539062        0      0             0.03 
## INFO  [22:08:49.392] [bbotk]                                 uhash 
## INFO  [22:08:49.392] [bbotk]  e38091de-c504-4961-98c4-f81f587b0389 
## INFO  [22:08:49.393] [bbotk] Evaluating 1 configuration(s) 
## INFO  [22:08:49.402] [mlr3] Running benchmark with 4 resampling iterations 
## INFO  [22:08:49.406] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 4/4) 
## INFO  [22:08:49.423] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 1/4) 
## INFO  [22:08:49.440] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 3/4) 
## INFO  [22:08:49.458] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 2/4) 
## INFO  [22:08:49.478] [mlr3] Finished benchmark 
## INFO  [22:08:49.505] [bbotk] Result of batch 3: 
## INFO  [22:08:49.507] [bbotk]     cp classif.ce warnings errors runtime_learners 
## INFO  [22:08:49.507] [bbotk]  0.045  0.2636719        0      0             0.02 
## INFO  [22:08:49.507] [bbotk]                                 uhash 
## INFO  [22:08:49.507] [bbotk]  c22c5a56-02a8-423d-bdfb-e20e9268cddf 
## INFO  [22:08:49.508] [bbotk] Evaluating 1 configuration(s) 
## INFO  [22:08:49.517] [mlr3] Running benchmark with 4 resampling iterations 
## INFO  [22:08:49.522] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 1/4) 
## INFO  [22:08:49.539] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 4/4) 
## INFO  [22:08:49.555] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 2/4) 
## INFO  [22:08:49.572] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 3/4) 
## INFO  [22:08:49.590] [mlr3] Finished benchmark 
## INFO  [22:08:49.621] [bbotk] Result of batch 4: 
## INFO  [22:08:49.622] [bbotk]     cp classif.ce warnings errors runtime_learners 
## INFO  [22:08:49.622] [bbotk]  0.056  0.2617188        0      0             0.03 
## INFO  [22:08:49.622] [bbotk]                                 uhash 
## INFO  [22:08:49.622] [bbotk]  f0ecc73f-1a3e-4610-8649-727539097ad2 
## INFO  [22:08:49.624] [bbotk] Evaluating 1 configuration(s) 
## INFO  [22:08:49.633] [mlr3] Running benchmark with 4 resampling iterations 
## INFO  [22:08:49.638] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 4/4) 
## INFO  [22:08:49.656] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 1/4) 
## INFO  [22:08:49.676] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 2/4) 
## INFO  [22:08:49.696] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 3/4) 
## INFO  [22:08:49.714] [mlr3] Finished benchmark 
## INFO  [22:08:49.741] [bbotk] Result of batch 5: 
## INFO  [22:08:49.742] [bbotk]     cp classif.ce warnings errors runtime_learners 
## INFO  [22:08:49.742] [bbotk]  0.001  0.2617188        0      0             0.02 
## INFO  [22:08:49.742] [bbotk]                                 uhash 
## INFO  [22:08:49.742] [bbotk]  c256749c-c916-4519-87e4-0eb98c8d7ad0 
## INFO  [22:08:49.745] [bbotk] Finished optimizing after 5 evaluation(s) 
## INFO  [22:08:49.746] [bbotk] Result: 
## INFO  [22:08:49.747] [bbotk]     cp learner_param_vals  x_domain classif.ce 
## INFO  [22:08:49.747] [bbotk]  0.034          <list[2]> <list[1]>  0.2539062 
## INFO  [22:08:49.780] [mlr3] Applying learner 'classif.rpart.tuned' on task 'pima' (iter 3/3) 
## INFO  [22:08:49.811] [bbotk] Starting to optimize 1 parameter(s) with '<TunerGridSearch>' and '<TerminatorEvals> [n_evals=5, k=0]' 
## INFO  [22:08:49.813] [bbotk] Evaluating 1 configuration(s) 
## INFO  [22:08:49.822] [mlr3] Running benchmark with 4 resampling iterations 
## INFO  [22:08:49.826] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 1/4) 
## INFO  [22:08:49.846] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 3/4) 
## INFO  [22:08:49.868] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 2/4) 
## INFO  [22:08:49.885] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 4/4) 
## INFO  [22:08:49.906] [mlr3] Finished benchmark 
## INFO  [22:08:49.930] [bbotk] Result of batch 1: 
## INFO  [22:08:49.931] [bbotk]     cp classif.ce warnings errors runtime_learners 
## INFO  [22:08:49.931] [bbotk]  0.034  0.2265625        0      0             0.01 
## INFO  [22:08:49.931] [bbotk]                                 uhash 
## INFO  [22:08:49.931] [bbotk]  6c4cd602-d4f1-4cf8-a680-5ef7596d818a 
## INFO  [22:08:49.932] [bbotk] Evaluating 1 configuration(s) 
## INFO  [22:08:49.941] [mlr3] Running benchmark with 4 resampling iterations 
## INFO  [22:08:49.945] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 3/4) 
## INFO  [22:08:49.962] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 2/4) 
## INFO  [22:08:49.980] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 4/4) 
## INFO  [22:08:49.997] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 1/4) 
## INFO  [22:08:50.019] [mlr3] Finished benchmark 
## INFO  [22:08:50.049] [bbotk] Result of batch 2: 
## INFO  [22:08:50.050] [bbotk]     cp classif.ce warnings errors runtime_learners 
## INFO  [22:08:50.050] [bbotk]  0.001  0.2304688        0      0             0.06 
## INFO  [22:08:50.050] [bbotk]                                 uhash 
## INFO  [22:08:50.050] [bbotk]  563187f9-d7ab-4a47-b52a-b80a3dfd53aa 
## INFO  [22:08:50.051] [bbotk] Evaluating 1 configuration(s) 
## INFO  [22:08:50.059] [mlr3] Running benchmark with 4 resampling iterations 
## INFO  [22:08:50.064] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 4/4) 
## INFO  [22:08:50.081] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 2/4) 
## INFO  [22:08:50.098] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 3/4) 
## INFO  [22:08:50.115] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 1/4) 
## INFO  [22:08:50.134] [mlr3] Finished benchmark 
## INFO  [22:08:50.162] [bbotk] Result of batch 3: 
## INFO  [22:08:50.163] [bbotk]     cp classif.ce warnings errors runtime_learners 
## INFO  [22:08:50.163] [bbotk]  0.012  0.2226562        0      0             0.03 
## INFO  [22:08:50.163] [bbotk]                                 uhash 
## INFO  [22:08:50.163] [bbotk]  b78f46bf-33c4-450c-8339-8186de327e50 
## INFO  [22:08:50.164] [bbotk] Evaluating 1 configuration(s) 
## INFO  [22:08:50.174] [mlr3] Running benchmark with 4 resampling iterations 
## INFO  [22:08:50.179] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 4/4) 
## INFO  [22:08:50.202] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 3/4) 
## INFO  [22:08:50.220] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 2/4) 
## INFO  [22:08:50.237] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 1/4) 
## INFO  [22:08:50.255] [mlr3] Finished benchmark 
## INFO  [22:08:50.283] [bbotk] Result of batch 4: 
## INFO  [22:08:50.285] [bbotk]     cp classif.ce warnings errors runtime_learners 
## INFO  [22:08:50.285] [bbotk]  0.067  0.2324219        0      0             0.03 
## INFO  [22:08:50.285] [bbotk]                                 uhash 
## INFO  [22:08:50.285] [bbotk]  226ac13a-be0e-44ca-9896-6eee9fabc3c8 
## INFO  [22:08:50.286] [bbotk] Evaluating 1 configuration(s) 
## INFO  [22:08:50.297] [mlr3] Running benchmark with 4 resampling iterations 
## INFO  [22:08:50.305] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 4/4) 
## INFO  [22:08:50.328] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 3/4) 
## INFO  [22:08:50.345] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 2/4) 
## INFO  [22:08:50.361] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 1/4) 
## INFO  [22:08:50.378] [mlr3] Finished benchmark 
## INFO  [22:08:50.405] [bbotk] Result of batch 5: 
## INFO  [22:08:50.407] [bbotk]     cp classif.ce warnings errors runtime_learners 
## INFO  [22:08:50.407] [bbotk]  0.078  0.2324219        0      0             0.03 
## INFO  [22:08:50.407] [bbotk]                                 uhash 
## INFO  [22:08:50.407] [bbotk]  e216461b-1fb6-4644-a6b9-9291007fca55 
## INFO  [22:08:50.411] [bbotk] Finished optimizing after 5 evaluation(s) 
## INFO  [22:08:50.411] [bbotk] Result: 
## INFO  [22:08:50.412] [bbotk]     cp learner_param_vals  x_domain classif.ce 
## INFO  [22:08:50.412] [bbotk]  0.012          <list[2]> <list[1]>  0.2226562 
## INFO  [22:08:50.444] [mlr3] Applying learner 'classif.rpart.tuned' on task 'pima' (iter 2/3) 
## INFO  [22:08:50.475] [bbotk] Starting to optimize 1 parameter(s) with '<TunerGridSearch>' and '<TerminatorEvals> [n_evals=5, k=0]' 
## INFO  [22:08:50.477] [bbotk] Evaluating 1 configuration(s) 
## INFO  [22:08:50.486] [mlr3] Running benchmark with 4 resampling iterations 
## INFO  [22:08:50.490] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 2/4) 
## INFO  [22:08:50.509] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 4/4) 
## INFO  [22:08:50.533] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 1/4) 
## INFO  [22:08:50.550] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 3/4) 
## INFO  [22:08:50.568] [mlr3] Finished benchmark 
## INFO  [22:08:50.593] [bbotk] Result of batch 1: 
## INFO  [22:08:50.594] [bbotk]     cp classif.ce warnings errors runtime_learners 
## INFO  [22:08:50.594] [bbotk]  0.045  0.2675781        0      0             0.02 
## INFO  [22:08:50.594] [bbotk]                                 uhash 
## INFO  [22:08:50.594] [bbotk]  ecd9dd73-6daa-4fda-9e40-e2632ea323f0 
## INFO  [22:08:50.595] [bbotk] Evaluating 1 configuration(s) 
## INFO  [22:08:50.604] [mlr3] Running benchmark with 4 resampling iterations 
## INFO  [22:08:50.609] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 4/4) 
## INFO  [22:08:50.634] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 2/4) 
## INFO  [22:08:50.654] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 1/4) 
## INFO  [22:08:50.673] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 3/4) 
## INFO  [22:08:50.691] [mlr3] Finished benchmark 
## INFO  [22:08:50.717] [bbotk] Result of batch 2: 
## INFO  [22:08:50.718] [bbotk]     cp classif.ce warnings errors runtime_learners 
## INFO  [22:08:50.718] [bbotk]  0.056  0.2792969        0      0             0.02 
## INFO  [22:08:50.718] [bbotk]                                 uhash 
## INFO  [22:08:50.718] [bbotk]  c2866cc7-6373-4de8-844e-4f760917e25e 
## INFO  [22:08:50.719] [bbotk] Evaluating 1 configuration(s) 
## INFO  [22:08:50.728] [mlr3] Running benchmark with 4 resampling iterations 
## INFO  [22:08:50.734] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 1/4) 
## INFO  [22:08:50.756] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 4/4) 
## INFO  [22:08:50.775] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 2/4) 
## INFO  [22:08:50.793] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 3/4) 
## INFO  [22:08:50.812] [mlr3] Finished benchmark 
## INFO  [22:08:50.839] [bbotk] Result of batch 3: 
## INFO  [22:08:50.840] [bbotk]   cp classif.ce warnings errors runtime_learners 
## INFO  [22:08:50.840] [bbotk]  0.1  0.2578125        0      0             0.03 
## INFO  [22:08:50.840] [bbotk]                                 uhash 
## INFO  [22:08:50.840] [bbotk]  e82a9fcd-ca94-4e5a-b8a7-a83b1f9ae207 
## INFO  [22:08:50.841] [bbotk] Evaluating 1 configuration(s) 
## INFO  [22:08:50.850] [mlr3] Running benchmark with 4 resampling iterations 
## INFO  [22:08:50.855] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 2/4) 
## INFO  [22:08:50.873] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 4/4) 
## INFO  [22:08:50.895] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 1/4) 
## INFO  [22:08:50.913] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 3/4) 
## INFO  [22:08:50.932] [mlr3] Finished benchmark 
## INFO  [22:08:50.964] [bbotk] Result of batch 4: 
## INFO  [22:08:50.965] [bbotk]     cp classif.ce warnings errors runtime_learners 
## INFO  [22:08:50.965] [bbotk]  0.012  0.2539062        0      0             0.04 
## INFO  [22:08:50.965] [bbotk]                                 uhash 
## INFO  [22:08:50.965] [bbotk]  3a85d117-16b2-4f03-9c1f-26a2adece9f0 
## INFO  [22:08:50.966] [bbotk] Evaluating 1 configuration(s) 
## INFO  [22:08:50.975] [mlr3] Running benchmark with 4 resampling iterations 
## INFO  [22:08:50.980] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 3/4) 
## INFO  [22:08:51.004] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 2/4) 
## INFO  [22:08:51.021] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 4/4) 
## INFO  [22:08:51.039] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 1/4) 
## INFO  [22:08:51.059] [mlr3] Finished benchmark 
## INFO  [22:08:51.083] [bbotk] Result of batch 5: 
## INFO  [22:08:51.084] [bbotk]     cp classif.ce warnings errors runtime_learners 
## INFO  [22:08:51.084] [bbotk]  0.089  0.2578125        0      0             0.04 
## INFO  [22:08:51.084] [bbotk]                                 uhash 
## INFO  [22:08:51.084] [bbotk]  cb4403b2-774b-4101-a896-c56ac43fc7ec 
## INFO  [22:08:51.088] [bbotk] Finished optimizing after 5 evaluation(s) 
## INFO  [22:08:51.088] [bbotk] Result: 
## INFO  [22:08:51.089] [bbotk]     cp learner_param_vals  x_domain classif.ce 
## INFO  [22:08:51.089] [bbotk]  0.012          <list[2]> <list[1]>  0.2539062

Nested resampling is not restricted to hyperparameter tuning. One can swap the AutoTuner for an AutoFSelector and estimate the performance of a model which is fitted on an optimized feature subset.
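
A hedged sketch of that swap (assuming mlr3fselect is loaded; the budget of 5 evaluations is illustrative):

af = AutoFSelector$new(
  learner = lrn("classif.rpart"),
  resampling = rsmp("cv", folds = 4),
  measure = msr("classif.ce"),
  terminator = trm("evals", n_evals = 5),
  fselector = fs("random_search")
)
# estimate performance of the feature-selecting learner with nested resampling
rr_fs = resample(task, af, outer_resampling, store_models = TRUE)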

Evaluation

With the created ResampleResult we can now inspect the executed resampling iterations.

We check the inner tuning results for stable hyperparameters. Unstable hyperparameters may be observed here because of the small data set and/or the low number of resampling iterations.

extract_inner_tuning_results(rr)
##    iteration    cp classif.ce learner_param_vals  x_domain task_id
## 1:         1 0.034  0.2539062          <list[2]> <list[1]>    pima
## 2:         2 0.012  0.2539062          <list[2]> <list[1]>    pima
## 3:         3 0.012  0.2226562          <list[2]> <list[1]>    pima
##             learner_id resampling_id
## 1: classif.rpart.tuned            cv
## 2: classif.rpart.tuned            cv
## 3: classif.rpart.tuned            cv

Next, compare the predictive performances estimated on the outer resampling to those estimated on the inner resampling. A significantly lower predictive performance on the outer resampling indicates that the models with the optimized hyperparameters overfit the data.

rr$score()
##                 task task_id         learner          learner_id
## 1: <TaskClassif[50]>    pima <AutoTuner[42]> classif.rpart.tuned
## 2: <TaskClassif[50]>    pima <AutoTuner[42]> classif.rpart.tuned
## 3: <TaskClassif[50]>    pima <AutoTuner[42]> classif.rpart.tuned
##            resampling resampling_id iteration              prediction
## 1: <ResamplingCV[20]>            cv         1 <PredictionClassif[20]>
## 2: <ResamplingCV[20]>            cv         2 <PredictionClassif[20]>
## 3: <ResamplingCV[20]>            cv         3 <PredictionClassif[20]>
##    classif.ce
## 1:  0.2773438
## 2:  0.2304688
## 3:  0.2304688

The aggregated performance of all outer resampling iterations is essentially the unbiased performance of the model with the optimal hyperparameters found by grid search.

rr$aggregate()
## classif.ce 
##  0.2460938

Note: nested resampling is computationally expensive.
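
Since mlr3 natively parallelizes via the future framework, the computational cost can be reduced by registering a parallel backend before calling resample(); a sketch, assuming the future package is installed:

# run resampling iterations in parallel R sessions
future::plan("multisession")
rr = resample(task, at, outer_resampling, store_models = TRUE)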

Final Model

at$train(task)
## INFO  [22:08:51.274] [bbotk] Starting to optimize 1 parameter(s) with '<TunerGridSearch>' and '<TerminatorEvals> [n_evals=5, k=0]' 
## INFO  [22:08:51.276] [bbotk] Evaluating 1 configuration(s) 
## INFO  [22:08:51.286] [mlr3] Running benchmark with 4 resampling iterations 
## INFO  [22:08:51.294] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 1/4) 
## INFO  [22:08:51.313] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 3/4) 
## INFO  [22:08:51.335] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 4/4) 
## INFO  [22:08:51.356] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 2/4) 
## INFO  [22:08:51.377] [mlr3] Finished benchmark 
## INFO  [22:08:51.409] [bbotk] Result of batch 1: 
## INFO  [22:08:51.411] [bbotk]     cp classif.ce warnings errors runtime_learners 
## INFO  [22:08:51.411] [bbotk]  0.045  0.2604167        0      0             0.04 
## INFO  [22:08:51.411] [bbotk]                                 uhash 
## INFO  [22:08:51.411] [bbotk]  ef4b8462-df37-432e-a023-ca9461420990 
## INFO  [22:08:51.412] [bbotk] Evaluating 1 configuration(s) 
## INFO  [22:08:51.424] [mlr3] Running benchmark with 4 resampling iterations 
## INFO  [22:08:51.429] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 1/4) 
## INFO  [22:08:51.447] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 2/4) 
## INFO  [22:08:51.464] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 3/4) 
## INFO  [22:08:51.482] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 4/4) 
## INFO  [22:08:51.501] [mlr3] Finished benchmark 
## INFO  [22:08:51.526] [bbotk] Result of batch 2: 
## INFO  [22:08:51.528] [bbotk]     cp classif.ce warnings errors runtime_learners 
## INFO  [22:08:51.528] [bbotk]  0.034  0.2604167        0      0             0.04 
## INFO  [22:08:51.528] [bbotk]                                 uhash 
## INFO  [22:08:51.528] [bbotk]  ff3f22ec-75f8-47da-a1bf-2fd493178c45 
## INFO  [22:08:51.529] [bbotk] Evaluating 1 configuration(s) 
## INFO  [22:08:51.538] [mlr3] Running benchmark with 4 resampling iterations 
## INFO  [22:08:51.542] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 4/4) 
## INFO  [22:08:51.561] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 2/4) 
## INFO  [22:08:51.579] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 1/4) 
## INFO  [22:08:51.598] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 3/4) 
## INFO  [22:08:51.616] [mlr3] Finished benchmark 
## INFO  [22:08:51.647] [bbotk] Result of batch 3: 
## INFO  [22:08:51.649] [bbotk]     cp classif.ce warnings errors runtime_learners 
## INFO  [22:08:51.649] [bbotk]  0.001  0.2786458        0      0             0.04 
## INFO  [22:08:51.649] [bbotk]                                 uhash 
## INFO  [22:08:51.649] [bbotk]  ba037a87-e20f-4e74-8949-ab7d6b67329c 
## INFO  [22:08:51.650] [bbotk] Evaluating 1 configuration(s) 
## INFO  [22:08:51.659] [mlr3] Running benchmark with 4 resampling iterations 
## INFO  [22:08:51.663] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 1/4) 
## INFO  [22:08:51.681] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 2/4) 
## INFO  [22:08:51.700] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 3/4) 
## INFO  [22:08:51.718] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 4/4) 
## INFO  [22:08:51.736] [mlr3] Finished benchmark 
## INFO  [22:08:51.769] [bbotk] Result of batch 4: 
## INFO  [22:08:51.770] [bbotk]     cp classif.ce warnings errors runtime_learners 
## INFO  [22:08:51.770] [bbotk]  0.023  0.2617188        0      0             0.06 
## INFO  [22:08:51.770] [bbotk]                                 uhash 
## INFO  [22:08:51.770] [bbotk]  e72b4b73-bcb2-40df-8a38-f63a1f9bb849 
## INFO  [22:08:51.771] [bbotk] Evaluating 1 configuration(s) 
## INFO  [22:08:51.782] [mlr3] Running benchmark with 4 resampling iterations 
## INFO  [22:08:51.786] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 4/4) 
## INFO  [22:08:51.809] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 3/4) 
## INFO  [22:08:51.828] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 1/4) 
## INFO  [22:08:51.845] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 2/4) 
## INFO  [22:08:51.863] [mlr3] Finished benchmark 
## INFO  [22:08:51.889] [bbotk] Result of batch 5: 
## INFO  [22:08:51.891] [bbotk]     cp classif.ce warnings errors runtime_learners 
## INFO  [22:08:51.891] [bbotk]  0.056  0.2604167        0      0             0.03 
## INFO  [22:08:51.891] [bbotk]                                 uhash 
## INFO  [22:08:51.891] [bbotk]  ed3d52de-c3ba-4743-9f0b-692b34f5d515 
## INFO  [22:08:51.894] [bbotk] Finished optimizing after 5 evaluation(s) 
## INFO  [22:08:51.895] [bbotk] Result: 
## INFO  [22:08:51.896] [bbotk]     cp learner_param_vals  x_domain classif.ce 
## INFO  [22:08:51.896] [bbotk]  0.045          <list[2]> <list[1]>  0.2604167

The trained model can now be used to make predictions on new data.
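
For example, a sketch that reuses a few rows of the training data as a stand-in for new observations:

newdata = task$data(rows = 1:5)
at$predict_newdata(newdata)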

Tuning with Hyperband

Besides the more traditional tuning methods, the ecosystem around mlr3 offers another procedure for hyperparameter optimization called Hyperband, implemented in the mlr3hyperband package.

Fig. 4: Visualization of how the training process may look. The left plot corresponds to the non-selective trainer, the right one to the selective trainer.

Hyperband is a budget-oriented procedure that weeds out poorly performing configurations early during a partially sequential training process, increasing tuning efficiency as a consequence. For this, a combination of incremental resource allocation and early stopping is used: as optimization progresses, computational resources are increased for more promising configurations, while less promising ones are terminated early.

Hyperband:

  • Consists of several brackets.
  • Each bracket occupies a unique spot on the trade-off between being fully explorative (many configurations started with a small budget) and being highly selective (few configurations, each trained with a larger budget).

Thanks to the broad ecosystem of the mlr3verse, a learner does not even need a natural budget parameter: below, the fraction of the training data used for subsampling serves as the budget.

set.seed(52)

# extend "classif.rpart" with "subsampling" as preprocessing step
ll = po("subsample") %>>% lrn("classif.rpart")

# extend hyperparameters of "classif.rpart" with subsampling fraction as budget
search_space = ps(
  classif.rpart.cp = p_dbl(lower = 0.001, upper = 0.1),
  classif.rpart.minsplit = p_int(lower = 1, upper = 10),
  subsample.frac = p_dbl(lower = 0.1, upper = 1, tags = "budget")
)

Now plug the new learner with the extended hyperparameter set into a TuningInstanceSingleCrit the same way as usual. Hyperband terminates once all of its brackets are evaluated, so a Terminator in the tuning instance acts as an upper bound and should only be set to a low value if one is unsure how long Hyperband will take to finish.

instance = TuningInstanceSingleCrit$new(
  task = tsk("iris"),
  learner = ll,
  resampling = rsmp("holdout"),
  measure = msr("classif.ce"),
  terminator = trm("none"), # hyperband terminates itself
  search_space = search_space
)

Initialize a new instance of the mlr3hyperband::mlr_tuners_hyperband class and start tuning. At every stage, the budget per configuration is increased by a factor of eta, and only the best 1/eta of the configurations are carried over to the next stage.

library("mlr3hyperband")
## Loading required package: mlr3tuning
## Loading required package: paradox
tuner = tnr("hyperband", eta = 3)

# reduce logging output
lgr::get_logger("bbotk")$set_threshold("warn")

tuner$optimize(instance)
## INFO  [22:08:52.323] [mlr3] Running benchmark with 9 resampling iterations 
## INFO  [22:08:52.328] [mlr3] Applying learner 'subsample.classif.rpart' on task 'iris' (iter 1/1) 
## INFO  [22:08:52.385] [mlr3] Applying learner 'subsample.classif.rpart' on task 'iris' (iter 1/1) 
## INFO  [22:08:52.443] [mlr3] Applying learner 'subsample.classif.rpart' on task 'iris' (iter 1/1) 
## INFO  [22:08:52.504] [mlr3] Applying learner 'subsample.classif.rpart' on task 'iris' (iter 1/1) 
## INFO  [22:08:52.561] [mlr3] Applying learner 'subsample.classif.rpart' on task 'iris' (iter 1/1) 
## INFO  [22:08:52.622] [mlr3] Applying learner 'subsample.classif.rpart' on task 'iris' (iter 1/1) 
## INFO  [22:08:52.676] [mlr3] Applying learner 'subsample.classif.rpart' on task 'iris' (iter 1/1) 
## INFO  [22:08:52.725] [mlr3] Applying learner 'subsample.classif.rpart' on task 'iris' (iter 1/1) 
## INFO  [22:08:52.784] [mlr3] Applying learner 'subsample.classif.rpart' on task 'iris' (iter 1/1) 
## INFO  [22:08:52.885] [mlr3] Finished benchmark 
## INFO  [22:08:53.098] [mlr3] Running benchmark with 8 resampling iterations 
## INFO  [22:08:53.102] [mlr3] Applying learner 'subsample.classif.rpart' on task 'iris' (iter 1/1) 
## INFO  [22:08:53.156] [mlr3] Applying learner 'subsample.classif.rpart' on task 'iris' (iter 1/1) 
## INFO  [22:08:53.227] [mlr3] Applying learner 'subsample.classif.rpart' on task 'iris' (iter 1/1) 
## INFO  [22:08:53.279] [mlr3] Applying learner 'subsample.classif.rpart' on task 'iris' (iter 1/1) 
## INFO  [22:08:53.330] [mlr3] Applying learner 'subsample.classif.rpart' on task 'iris' (iter 1/1) 
## INFO  [22:08:53.386] [mlr3] Applying learner 'subsample.classif.rpart' on task 'iris' (iter 1/1) 
## INFO  [22:08:53.438] [mlr3] Applying learner 'subsample.classif.rpart' on task 'iris' (iter 1/1) 
## INFO  [22:08:53.493] [mlr3] Applying learner 'subsample.classif.rpart' on task 'iris' (iter 1/1) 
## INFO  [22:08:53.555] [mlr3] Finished benchmark 
## INFO  [22:08:53.721] [mlr3] Running benchmark with 5 resampling iterations 
## INFO  [22:08:53.726] [mlr3] Applying learner 'subsample.classif.rpart' on task 'iris' (iter 1/1) 
## INFO  [22:08:53.781] [mlr3] Applying learner 'subsample.classif.rpart' on task 'iris' (iter 1/1) 
## INFO  [22:08:53.845] [mlr3] Applying learner 'subsample.classif.rpart' on task 'iris' (iter 1/1) 
## INFO  [22:08:53.896] [mlr3] Applying learner 'subsample.classif.rpart' on task 'iris' (iter 1/1) 
## INFO  [22:08:53.946] [mlr3] Applying learner 'subsample.classif.rpart' on task 'iris' (iter 1/1) 
## INFO  [22:08:54.004] [mlr3] Finished benchmark
##    classif.rpart.cp classif.rpart.minsplit subsample.frac learner_param_vals
## 1:       0.08171096                      3      0.3333333          <list[6]>
##     x_domain classif.ce
## 1: <list[3]>          0

The best-found configuration can be accessed through the instance object:

instance$result
##    classif.rpart.cp classif.rpart.minsplit subsample.frac learner_param_vals
## 1:       0.08171096                      3      0.3333333          <list[6]>
##     x_domain classif.ce
## 1: <list[3]>          0
instance$result_learner_param_vals
## $subsample.frac
## [1] 0.3333333
## 
## $subsample.stratify
## [1] FALSE
## 
## $subsample.replace
## [1] FALSE
## 
## $classif.rpart.xval
## [1] 0
## 
## $classif.rpart.cp
## [1] 0.08171096
## 
## $classif.rpart.minsplit
## [1] 3
instance$result_y
## classif.ce 
##          0

Feature Selection / Filtering

Two different approaches are emphasized in the literature: one is called filtering, and the other is often referred to as feature subset selection or wrapper methods.

Filtering:

  1. Computes a rank of the features (e.g. based on the correlation to the response).
  2. Features are subsetted by a certain criteria (e.g. an absolute number or a percentage of the number of variables).
  3. The selected features are then used to fit a model (with optional hyperparameters selected by tuning).

This calculation is usually cheaper than “feature subset selection” in terms of computation time. All filters are connected via the mlr3filters package.

Wrapper Methods:

  1. Selects a subset of the features.
  2. Evaluates the set by calculating the resampled predictive performance.
  3. Proposes a new set of features (or terminates).

A simple example is sequential forward selection. Wrapper methods are usually computationally very intensive: a lot of models are fitted, and all of them need to be tuned before the performance is estimated.

Embedded Methods: Many learners internally select a subset of the features which they find helpful for prediction. These subsets can usually be queried:

library("mlr3verse")

task = tsk("iris")
learner = lrn("classif.rpart")

# ensure that the learner selects features
stopifnot("selected_features" %in% learner$properties)

# fit a simple classification tree
learner = learner$train(task)

# extract all features used in the classification tree:
learner$selected_features()
## [1] "Petal.Length" "Petal.Width"

Filters

Currently, only classification and regression tasks are supported.

The first step is to create a new R object using the class of the desired filter method. Each object of class Filter has a $calculate() method which computes the filter values and ranks them in descending order.

filter = flt("jmim")

task = tsk("iris")
filter$calculate(task)

as.data.table(filter)
##         feature     score
## 1:  Petal.Width 1.0000000
## 2: Sepal.Length 0.6666667
## 3: Petal.Length 0.3333333
## 4:  Sepal.Width 0.0000000
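
The computed scores can then be used to subset a task, e.g. keeping a fixed number of top-ranked features (a sketch; the cutoff of 2 is arbitrary, and we work on a clone to leave the original task untouched):

# filter$scores is a named vector, already sorted in decreasing order
task2 = task$clone()
task2$select(names(head(filter$scores, 2)))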

We can also change the hyperparameters of the filter.

filter_cor = flt("correlation")
filter_cor$param_set
## <ParamSet>
##        id    class lower upper nlevels    default value
## 1:    use ParamFct    NA    NA       5 everything      
## 2: method ParamFct    NA    NA       3    pearson
# change parameter 'method'
filter_cor$param_set$values = list(method = "spearman")
filter_cor$param_set
## <ParamSet>
##        id    class lower upper nlevels    default    value
## 1:    use ParamFct    NA    NA       5 everything         
## 2: method ParamFct    NA    NA       3    pearson spearman
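
Since the correlation filter supports regression tasks, a sketch of recomputing scores with the new method could use the mtcars task:

filter_cor$calculate(tsk("mtcars"))
as.data.table(filter_cor)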

Wrapper Methods

Wrapper feature selection is supported via the mlr3fselect extension package. At the heart of mlr3fselect are the R6 classes:

  • FSelectInstanceSingleCrit: describes the feature selection problem and stores the results.
  • FSelector: base class for implementations of feature selection algorithms.
task = tsk("pima")
print(task)
## <TaskClassif:pima> (768 x 9): Pima Indian Diabetes
## * Target: diabetes
## * Properties: twoclass
## * Features (8):
##   - dbl (8): age, glucose, insulin, mass, pedigree, pregnant, pressure,
##     triceps
learner = lrn("classif.rpart")

Choose a resampling strategy and a performance measure.

hout = rsmp("holdout")
measure = msr("classif.ce")

Choose the available budget for the feature selection. This is done by selecting one of the available Terminators:

  • Terminate after a given time (TerminatorClockTime)
  • Terminate after a given amount of iterations (TerminatorEvals)
  • Terminate after a specific performance is reached (TerminatorPerfReached)
  • Terminate when feature selection does not improve (TerminatorStagnation)
  • A combination of the above in an ALL or ANY fashion (TerminatorCombo)
evals20 = trm("evals", n_evals = 20)

instance = FSelectInstanceSingleCrit$new(
  task = task,
  learner = learner,
  resampling = hout,
  measure = measure,
  terminator = evals20
)

instance
## <FSelectInstanceSingleCrit>
## * State:  Not optimized
## * Objective: <ObjectiveFSelect:classif.rpart_on_pima>
## * Search Space:
##          id    class lower upper nlevels
## 1:      age ParamLgl    NA    NA       2
## 2:  glucose ParamLgl    NA    NA       2
## 3:  insulin ParamLgl    NA    NA       2
## 4:     mass ParamLgl    NA    NA       2
## 5: pedigree ParamLgl    NA    NA       2
## 6: pregnant ParamLgl    NA    NA       2
## 7: pressure ParamLgl    NA    NA       2
## 8:  triceps ParamLgl    NA    NA       2
## * Terminator: <TerminatorEvals>

To start the feature selection, we still need to select an algorithm; these are defined via the FSelector class.

The following algorithms are currently implemented in mlr3fselect:

  • Random Search (FSelectorRandomSearch)
  • Exhaustive Search (FSelectorExhaustiveSearch)
  • Sequential Search (FSelectorSequential)
  • Recursive Feature Elimination (FSelectorRFE)
  • Design Points (FSelectorDesignPoints)
fselector = fs("random_search")

To start the feature selection, we simply pass the FSelectInstanceSingleCrit to the $optimize() method of the initialized FSelector.

# reduce logging output
lgr::get_logger("bbotk")$set_threshold("warn")

fselector$optimize(instance)
## INFO  [22:08:54.424] [mlr3] Running benchmark with 1 resampling iterations 
## INFO  [22:08:54.429] [mlr3] Applying learner 'select.classif.rpart' on task 'pima' (iter 1/1) 
## INFO  [22:08:54.492] [mlr3] Finished benchmark 
## INFO  [22:08:54.587] [mlr3] Running benchmark with 1 resampling iterations 
## INFO  [22:08:54.594] [mlr3] Applying learner 'select.classif.rpart' on task 'pima' (iter 1/1) 
## INFO  [22:08:54.656] [mlr3] Finished benchmark 
## INFO  [22:08:54.742] [mlr3] Running benchmark with 1 resampling iterations 
## INFO  [22:08:54.746] [mlr3] Applying learner 'select.classif.rpart' on task 'pima' (iter 1/1) 
## INFO  [22:08:54.809] [mlr3] Finished benchmark 
## INFO  [22:08:54.907] [mlr3] Running benchmark with 1 resampling iterations 
## INFO  [22:08:54.911] [mlr3] Applying learner 'select.classif.rpart' on task 'pima' (iter 1/1) 
## INFO  [22:08:54.972] [mlr3] Finished benchmark 
## INFO  [22:08:55.063] [mlr3] Running benchmark with 1 resampling iterations 
## INFO  [22:08:55.067] [mlr3] Applying learner 'select.classif.rpart' on task 'pima' (iter 1/1) 
## INFO  [22:08:55.133] [mlr3] Finished benchmark 
## INFO  [22:08:55.255] [mlr3] Running benchmark with 1 resampling iterations 
## INFO  [22:08:55.260] [mlr3] Applying learner 'select.classif.rpart' on task 'pima' (iter 1/1) 
## INFO  [22:08:55.323] [mlr3] Finished benchmark 
## INFO  [22:08:55.411] [mlr3] Running benchmark with 1 resampling iterations 
## INFO  [22:08:55.415] [mlr3] Applying learner 'select.classif.rpart' on task 'pima' (iter 1/1) 
## INFO  [22:08:55.475] [mlr3] Finished benchmark 
## INFO  [22:08:55.575] [mlr3] Running benchmark with 1 resampling iterations 
## INFO  [22:08:55.579] [mlr3] Applying learner 'select.classif.rpart' on task 'pima' (iter 1/1) 
## INFO  [22:08:55.640] [mlr3] Finished benchmark 
## INFO  [22:08:55.728] [mlr3] Running benchmark with 1 resampling iterations 
## INFO  [22:08:55.732] [mlr3] Applying learner 'select.classif.rpart' on task 'pima' (iter 1/1) 
## INFO  [22:08:55.794] [mlr3] Finished benchmark 
## INFO  [22:08:55.894] [mlr3] Running benchmark with 1 resampling iterations 
## INFO  [22:08:55.899] [mlr3] Applying learner 'select.classif.rpart' on task 'pima' (iter 1/1) 
## INFO  [22:08:55.960] [mlr3] Finished benchmark 
## INFO  [22:08:56.048] [mlr3] Running benchmark with 1 resampling iterations 
## INFO  [22:08:56.053] [mlr3] Applying learner 'select.classif.rpart' on task 'pima' (iter 1/1) 
## INFO  [22:08:56.115] [mlr3] Finished benchmark 
## INFO  [22:08:56.229] [mlr3] Running benchmark with 1 resampling iterations 
## INFO  [22:08:56.233] [mlr3] Applying learner 'select.classif.rpart' on task 'pima' (iter 1/1) 
## INFO  [22:08:56.294] [mlr3] Finished benchmark 
## INFO  [22:08:56.381] [mlr3] Running benchmark with 1 resampling iterations 
## INFO  [22:08:56.385] [mlr3] Applying learner 'select.classif.rpart' on task 'pima' (iter 1/1) 
## INFO  [22:08:56.459] [mlr3] Finished benchmark 
## INFO  [22:08:56.553] [mlr3] Running benchmark with 1 resampling iterations 
## INFO  [22:08:56.559] [mlr3] Applying learner 'select.classif.rpart' on task 'pima' (iter 1/1) 
## INFO  [22:08:56.624] [mlr3] Finished benchmark 
## INFO  [22:08:56.713] [mlr3] Running benchmark with 1 resampling iterations 
## INFO  [22:08:56.718] [mlr3] Applying learner 'select.classif.rpart' on task 'pima' (iter 1/1) 
## INFO  [22:08:56.777] [mlr3] Finished benchmark 
## INFO  [22:08:56.878] [mlr3] Running benchmark with 1 resampling iterations 
## INFO  [22:08:56.883] [mlr3] Applying learner 'select.classif.rpart' on task 'pima' (iter 1/1) 
## INFO  [22:08:56.942] [mlr3] Finished benchmark 
## INFO  [22:08:57.034] [mlr3] Running benchmark with 1 resampling iterations 
## INFO  [22:08:57.038] [mlr3] Applying learner 'select.classif.rpart' on task 'pima' (iter 1/1) 
## INFO  [22:08:57.117] [mlr3] Finished benchmark 
## INFO  [22:08:57.211] [mlr3] Running benchmark with 1 resampling iterations 
## INFO  [22:08:57.215] [mlr3] Applying learner 'select.classif.rpart' on task 'pima' (iter 1/1) 
## INFO  [22:08:57.275] [mlr3] Finished benchmark 
## INFO  [22:08:57.367] [mlr3] Running benchmark with 1 resampling iterations 
## INFO  [22:08:57.371] [mlr3] Applying learner 'select.classif.rpart' on task 'pima' (iter 1/1) 
## INFO  [22:08:57.447] [mlr3] Finished benchmark 
## INFO  [22:08:57.530] [mlr3] Running benchmark with 1 resampling iterations 
## INFO  [22:08:57.535] [mlr3] Applying learner 'select.classif.rpart' on task 'pima' (iter 1/1) 
## INFO  [22:08:57.598] [mlr3] Finished benchmark
##     age glucose insulin mass pedigree pregnant pressure triceps
## 1: TRUE    TRUE    TRUE TRUE     TRUE     TRUE     TRUE   FALSE
##                                          features classif.ce
## 1: age,glucose,insulin,mass,pedigree,pregnant,...   0.265625
instance$result_feature_set
## [1] "age"      "glucose"  "insulin"  "mass"     "pedigree" "pregnant" "pressure"
instance$result_y
## classif.ce 
##   0.265625
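
The winning subset can then be used to reduce the task and fit a final model on all observations. A minimal sketch (illustrative, not part of the original run; task_fs and final_tree are names introduced here):

# Reduce a fresh copy of the task to the selected features
task_fs = tsk("pima")
task_fs$select(instance$result_feature_set)

# Fit a final tree on all observations using only those features
final_tree = lrn("classif.rpart")
final_tree$train(task_fs)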

Automating the Feature Selection

The AutoFSelector wraps a learner and augments it with automatic feature selection for a given task. Analogously to the previous subsection, we create a new classification tree learner that, when trained, first runs a feature selection on the given task using an inner resampling (here: holdout) and then fits itself on the selected subset.

learner = lrn("classif.rpart")
terminator = trm("evals", n_evals = 10)
fselector = fs("random_search")

at = AutoFSelector$new(
  learner = learner,
  resampling = rsmp("holdout"),
  measure = msr("classif.ce"),
  terminator = terminator,
  fselector = fselector
)
at
## <AutoFSelector:classif.rpart.fselector>
## * Model: -
## * Parameters: xval=0
## * Packages: mlr3, mlr3fselect, rpart
## * Predict Type: response
## * Feature types: logical, integer, numeric, factor, ordered
## * Properties: importance, missings, multiclass, selected_features,
##   twoclass, weights

We can now use this learner like any other learner, calling its $train() and $predict() methods. This time, however, we pass it to benchmark() to compare the optimized feature subset against the complete feature set; this way, the AutoFSelector performs nested resampling for the feature selection.

grid = benchmark_grid(
  task = tsk("pima"),
  learner = list(at, lrn("classif.rpart")),
  resampling = rsmp("cv", folds = 3)
)

bmr = benchmark(grid, store_models = TRUE)
## INFO  [22:08:57.797] [mlr3] Running benchmark with 6 resampling iterations 
## INFO  [22:08:57.803] [mlr3] Applying learner 'classif.rpart.fselector' on task 'pima' (iter 1/3) 
## ... (inner feature-selection log lines for the remaining outer folds omitted)
## INFO  [22:09:03.207] [mlr3] Finished benchmark
bmr$aggregate(msrs(c("classif.ce", "time_train")))
##    nr      resample_result task_id              learner_id resampling_id iters
## 1:  1 <ResampleResult[22]>    pima classif.rpart.fselector            cv     3
## 2:  2 <ResampleResult[22]>    pima           classif.rpart            cv     3
##    classif.ce time_train
## 1:  0.2721354          0
## 2:  0.2760417          0

Note that we do not expect any significant differences since we only evaluated a small fraction of the possible feature subsets.
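
Because the models were stored (store_models = TRUE), the inner feature-selection results can also be inspected per outer fold. A sketch, assuming the extract_inner_fselect_results() helper from mlr3fselect is available in the installed version:

library(mlr3fselect)

# One row per outer resampling iteration of the AutoFSelector,
# showing the selected feature subset and its inner score
extract_inner_fselect_results(bmr)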

Good ol' German credit dataset example

set.seed(47)
library(patchwork)
load(url("https://statmath.wu.ac.at/~hornik/DTM/german.rda"))

# Creating a task for the german credit dataset
task <- as_task_classif(german, target = "Class", positive = "good")
task$feature_names
##  [1] "Age"                         "Amount"                     
##  [3] "Duration"                    "Employment_since"           
##  [5] "Foreign"                     "History"                    
##  [7] "Housing"                     "Installment_rate"           
##  [9] "Job"                         "N_of_credits"               
## [11] "N_of_liables"                "Other_debtors_or_guarantors"
## [13] "Other_installment_plans"     "Phone"                      
## [15] "Property"                    "Purpose"                    
## [17] "Residence_since"             "Savings"                    
## [19] "Status_and_sex"              "Status_of_checking_account"
# Subset the task to desired features
task$select(c("Status_of_checking_account", "Age", "Amount"))

# Plots provided with mlr3viz
autoplot(task)

autoplot(task, type = "pairs")
## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.
## (message repeated once per histogram panel)

library(mlr3learners)
# Logistic Regression Learner
lr_learner <- lrn("classif.log_reg")
# Random Forest Learner
rf_learner <- lrn("classif.ranger")

# train/test rows
train <- sample(task$nrow, 0.7 * task$nrow)
test <- setdiff(seq_len(task$nrow), train)

# Training the learners
lr_learner$train(task, row_ids = train)
print(lr_learner)
## <LearnerClassifLogReg:classif.log_reg>
## * Model: glm
## * Parameters: list()
## * Packages: mlr3, mlr3learners, stats
## * Predict Type: response
## * Feature types: logical, integer, numeric, character, factor, ordered
## * Properties: loglik, twoclass, weights
rf_learner$train(task, row_ids = train)
print(rf_learner)
## <LearnerClassifRanger:classif.ranger>
## * Model: ranger
## * Parameters: num.threads=1
## * Packages: mlr3, mlr3learners, ranger
## * Predict Type: response
## * Feature types: logical, integer, numeric, character, factor, ordered
## * Properties: hotstart_backward, importance, multiclass, oob_error,
##   twoclass, weights
# Prediction on the test set
lr_pred <- lr_learner$predict(task, row_ids = test)
head(as.data.table(lr_pred))
##    row_ids truth response
## 1:       3  good     good
## 2:       6  good     good
## 3:       7  good     good
## 4:      11   bad     good
## 5:      14   bad     good
## 6:      19   bad      bad
rf_pred <- rf_learner$predict(task, row_ids = test)
head(as.data.table(rf_pred))
##    row_ids truth response
## 1:       3  good     good
## 2:       6  good     good
## 3:       7  good     good
## 4:      11   bad     good
## 5:      14   bad     good
## 6:      19   bad      bad
# Confusion Matrices
lr_pred$confusion
##         truth
## response good bad
##     good  191  76
##     bad    11  22
rf_pred$confusion
##         truth
## response good bad
##     good  190  84
##     bad    12  14
# Again autoplots from mlr3viz
p1 <- autoplot(lr_pred) + labs(title = "Logit")
p2 <- autoplot(rf_pred) + labs(title = "RFC")
p1 + p2

# Assessing performance
accuracy <- msr("classif.acc")
auc <- msr("classif.auc")

lr_pred$score(accuracy)
## classif.acc 
##        0.71
rf_pred$score(accuracy)
## classif.acc 
##        0.68
# Default is Classification Error
lr_pred$score()
## classif.ce 
##       0.29
rf_pred$score()
## classif.ce 
##       0.32
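# (Sketch) Several measures can also be scored at once by passing a list
# built with msrs(), e.g. accuracy and classification error together:
rf_pred$score(msrs(c("classif.acc", "classif.ce")))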
# We can also get the prediction probabilities
lr_learner$predict_type = "prob"
rf_learner$predict_type = "prob"

# We re-fit the models
lr_learner$train(task, row_ids = train)
rf_learner$train(task, row_ids = train)

# We now get probabilities
lr_pred <- lr_learner$predict(task, row_ids = test)
head(as.data.table(lr_pred))
##    row_ids truth response prob.good   prob.bad
## 1:       3  good     good 0.9206041 0.07939590
## 2:       6  good     good 0.7912406 0.20875939
## 3:       7  good     good 0.9196340 0.08036598
## 4:      11   bad     good 0.6440619 0.35593812
## 5:      14   bad     good 0.7088747 0.29112525
## 6:      19   bad      bad 0.3970204 0.60297958
rf_pred <- rf_learner$predict(task, row_ids = test)
head(as.data.table(rf_pred))
##    row_ids truth response prob.good   prob.bad
## 1:       3  good     good 0.9494838 0.05051616
## 2:       6  good     good 0.8857721 0.11422791
## 3:       7  good     good 0.9049654 0.09503457
## 4:      11   bad     good 0.5834960 0.41650397
## 5:      14   bad     good 0.5973205 0.40267950
## 6:      19   bad      bad 0.3441812 0.65581882
p1 <- autoplot(lr_pred, type = "roc") + labs(title = "Logit")
p2 <- autoplot(rf_pred, type = "roc") + labs(title = "RFC")
p1 + p2
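
# (Sketch) With probability predictions available, the AUC measure defined
# earlier can now be scored as well:
lr_pred$score(auc)
rf_pred$score(auc)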

# Autotuning the hyperparameters for the RFC
rf_learner$param_set
## <ParamSet>
##                               id    class lower upper nlevels        default
##  1:                        alpha ParamDbl  -Inf   Inf     Inf            0.5
##  2:       always.split.variables ParamUty    NA    NA     Inf <NoDefault[3]>
##  3:                class.weights ParamUty    NA    NA     Inf               
##  4:                      holdout ParamLgl    NA    NA       2          FALSE
##  5:                   importance ParamFct    NA    NA       4 <NoDefault[3]>
##  6:                   keep.inbag ParamLgl    NA    NA       2          FALSE
##  7:                    max.depth ParamInt     0   Inf     Inf               
##  8:                min.node.size ParamInt     1   Inf     Inf              1
##  9:                     min.prop ParamDbl  -Inf   Inf     Inf            0.1
## 10:                      minprop ParamDbl  -Inf   Inf     Inf            0.1
## 11:                         mtry ParamInt     1   Inf     Inf <NoDefault[3]>
## 12:                   mtry.ratio ParamDbl     0     1     Inf <NoDefault[3]>
## 13:            num.random.splits ParamInt     1   Inf     Inf              1
## 14:                  num.threads ParamInt     1   Inf     Inf              1
## 15:                    num.trees ParamInt     1   Inf     Inf            500
## 16:                    oob.error ParamLgl    NA    NA       2           TRUE
## 17:        regularization.factor ParamUty    NA    NA     Inf              1
## 18:      regularization.usedepth ParamLgl    NA    NA       2          FALSE
## 19:                      replace ParamLgl    NA    NA       2           TRUE
## 20:    respect.unordered.factors ParamFct    NA    NA       3         ignore
## 21:              sample.fraction ParamDbl     0     1     Inf <NoDefault[3]>
## 22:                  save.memory ParamLgl    NA    NA       2          FALSE
## 23: scale.permutation.importance ParamLgl    NA    NA       2          FALSE
## 24:                    se.method ParamFct    NA    NA       2        infjack
## 25:                         seed ParamInt  -Inf   Inf     Inf               
## 26:         split.select.weights ParamUty    NA    NA     Inf               
## 27:                    splitrule ParamFct    NA    NA       2           gini
## 28:                      verbose ParamLgl    NA    NA       2           TRUE
## 29:                 write.forest ParamLgl    NA    NA       2           TRUE
##        parents value
## 13:  splitrule
## 14:                1
## 23: importance
## (all other parents/value entries are empty)
# Hyperparameter search space configuration
search_space <- ps(
  max.depth = p_int(lower = 0, upper = 5),
  min.node.size = p_int(lower = 5, upper = 100)
)
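
# (Sketch, assuming a recent mlr3tuning with to_tune():) the same search
# space can be declared more compactly on the learner itself:
rf_learner2 = lrn("classif.ranger",
  max.depth = to_tune(0, 5),
  min.node.size = to_tune(5, 100)
)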

terminator <- trm("evals", n_evals = 10)
tuner <- tnr("random_search")

rf_learner$predict_type = "prob"
# Autotune
at <- AutoTuner$new(
  learner = rf_learner,
  resampling = rsmp("holdout"),
  measure = msr("classif.ce"),
  search_space = search_space,
  terminator = terminator,
  tuner = tuner
)

# Training (note: the AutoTuner is trained on the full task here, so the test
# rows used below were also seen during training; for a comparison on equal
# footing with the earlier learners, use at$train(task, row_ids = train))
at$train(task)
## INFO  [22:09:11.050] [mlr3] Running benchmark with 1 resampling iterations 
## INFO  [22:09:11.055] [mlr3] Applying learner 'classif.ranger' on task 'german' (iter 1/1) 
## ... (log lines for the remaining tuning evaluations omitted)
## INFO  [22:09:12.351] [mlr3] Finished benchmark
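# (Sketch, assuming the standard mlr3tuning AutoTuner interface) the best
# hyperparameter configuration found during tuning can be inspected via
at$tuning_result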
# New Prediction
new_pred <- at$predict(task, row_ids = test)

new_pred$confusion
##         truth
## response good bad
##     good  186  48
##     bad    16  50
# New Classification Error
new_pred$score()
## classif.ce 
##  0.2133333
# ROC curve comparison
p3 <- autoplot(new_pred, type = "roc") + labs(title = "at_RFC")
p2 + p3

# --- More advanced ---
design = benchmark_grid(
  tasks = tsks(c("german_credit")),
  learners = lrns(c("classif.ranger", "classif.rpart", "classif.log_reg"),
    predict_type = "prob", predict_sets = c("train", "test")),
  resamplings = rsmps("cv", folds = 10)
)
print(design)
##                 task                    learner         resampling
## 1: <TaskClassif[50]> <LearnerClassifRanger[38]> <ResamplingCV[20]>
## 2: <TaskClassif[50]>  <LearnerClassifRpart[38]> <ResamplingCV[20]>
## 3: <TaskClassif[50]> <LearnerClassifLogReg[37]> <ResamplingCV[20]>
# Run the benchmark on our design
bmr = benchmark(design)
## INFO  [22:09:12.910] [mlr3] Running benchmark with 30 resampling iterations 
## INFO  [22:09:12.914] [mlr3] Applying learner 'classif.ranger' on task 'german_credit' (iter 6/10) 
## ... (per-iteration log lines omitted)
## INFO  [22:09:17.628] [mlr3] Finished benchmark
# We can choose many performance measures
measures = list(
  msr("classif.auc", predict_sets = "train", id = "auc_train"),
  msr("classif.auc", id = "auc_test")
)

tab = bmr$aggregate(measures)
print(tab)
##    nr      resample_result       task_id      learner_id resampling_id iters
## 1:  1 <ResampleResult[22]> german_credit  classif.ranger            cv    10
## 2:  2 <ResampleResult[22]> german_credit   classif.rpart            cv    10
## 3:  3 <ResampleResult[22]> german_credit classif.log_reg            cv    10
##    auc_train  auc_test
## 1: 0.9984376 0.8018888
## 2: 0.7862836 0.7299559
## 3: 0.8396874 0.7877817
# Plotting the benchmark results (wrt classification error)
autoplot(bmr)
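
# (Sketch) Other measures can be passed explicitly, e.g. boxplots of the AUC:
autoplot(bmr, measure = msr("classif.auc"))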

# Lastly, the good-looking ROC curves
autoplot(bmr, type = "roc")