Note: This HTML file is based on the mlr3 book, the official documentation of the package. For a more detailed introduction to the package, check the link.
Introduction to mlr3
What is it?
The mlr3 (Lang et al. 2019) package and ecosystem provide a generic, object-oriented, and extensible framework for classification, regression, survival analysis, and other machine learning tasks for the R language.
What is it used for?
Its interface provides functionality to extend and combine existing learners, intelligently select and tune the most appropriate technique for a task, and perform large-scale comparisons that enable meta-learning.
What makes it special?
Advanced functionalities include hyperparameter tuning and feature selection. Parallelization of many operations is natively supported. But more on that later.
library(mlr3verse) # loads most packages in the mlr3 world
Contents:
- Basics
- Performance Evaluation and Comparison
- Model Optimization
- Technical and Special Tasks
1. Basics
Below you see a typical ML workflow. mlr3 can take care of every step.
Fig. 1: Typical ML Workflow
Overview
In this section we will focus on:
R6
Tasks/Data
Learners
Training and Predicting
R6
R6 is one of R's more recent systems for object-oriented programming (OOP)
Objects are created by calling the constructor of an R6::R6Class() object, specifically the initialization method $new():
e.g. foo = Foo$new(bar = 1) creates a new object of class Foo, setting the bar argument of the constructor to the value 1. Most objects in mlr3 are created through special shortcut functions, e.g. lrn("regr.rpart").
Objects have mutable state that is encapsulated in their fields, which can be accessed through the dollar operator.
Objects expose methods that allow you to inspect the object's state, retrieve information, or perform an action that changes the internal state of the object.
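As a minimal sketch of these mechanics (the class Foo, its field bar, and the method double_bar() are hypothetical, made up for illustration):
library(R6)
Foo = R6Class("Foo",
  public = list(
    bar = NULL,                   # mutable state, encapsulated in a field
    initialize = function(bar) {  # constructor, invoked via Foo$new()
      self$bar = bar
    },
    double_bar = function() {     # method that changes the internal state
      self$bar = self$bar * 2
      invisible(self)
    }
  )
)
foo = Foo$new(bar = 1)  # create the object
foo$bar                 # access the field with the dollar operator
foo$double_bar()        # perform an action that changes the state
foo$bar                 # now 2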
Tasks
Tasks encapsulate the data with meta-information, such as the name of the prediction target column.
Task Types
To create a task from a data.frame(), data.table(), or Matrix(), you first need to select the right task type:
Classification Task
Regression Task
Survival Task
Density, Cluster, Spatial, Ordinal Regression Tasks
Task Creation and Predefined Tasks
mlr_tasks
## <DictionaryTask> with 28 stored values
## Keys: actg, bike_sharing, boston_housing, breast_cancer, faithful,
## gbcs, german_credit, grace, ilpd, iris, kc_housing, lung, moneyball,
## mtcars, optdigits, penguins, penguins_simple, pima, precip, rats,
## sonar, spam, titanic, unemployment, usarrests, whas, wine, zoo
as.data.table(mlr_tasks)
## key label task_type nrow
## 1: actg <NA> surv 1151
## 2: bike_sharing Bike Sharing Demand regr 17379
## 3: boston_housing Boston Housing Prices regr 506
## 4: breast_cancer Wisconsin Breast Cancer classif 683
## 5: faithful <NA> dens 272
## 6: gbcs <NA> surv 686
## 7: german_credit German Credit classif 1000
## 8: grace <NA> surv 1000
## 9: ilpd Indian Liver Patient Data classif 583
## 10: iris Iris Flowers classif 150
## 11: kc_housing King County House Sales regr 21613
## 12: lung <NA> surv 228
## 13: moneyball Major League Baseball Statistics regr 1232
## 14: mtcars Motor Trends regr 32
## 15: optdigits Optical Recognition of Handwritten Digits classif 5620
## 16: penguins Palmer Penguins classif 344
## 17: penguins_simple Simplified Palmer Penguins classif 333
## 18: pima Pima Indian Diabetes classif 768
## 19: precip <NA> dens 70
## 20: rats <NA> surv 300
## 21: sonar Sonar: Mines vs. Rocks classif 208
## 22: spam HP Spam Detection classif 4601
## 23: titanic Titanic classif 1309
## 24: unemployment <NA> surv 3343
## 25: usarrests US Arrests clust 50
## 26: whas <NA> surv 481
## 27: wine Wine Regions classif 178
## 28: zoo Zoo Animals classif 101
## key label task_type nrow
## ncol properties lgl int dbl chr fct ord pxc
## 1: 13 0 3 4 0 4 0 0
## 2: 14 2 4 4 1 2 0 0
## 3: 19 0 3 13 0 2 0 0
## 4: 10 twoclass 0 0 0 0 0 9 0
## 5: 1 0 0 1 0 0 0 0
## 6: 10 0 4 4 0 0 0 0
## 7: 21 twoclass 0 3 0 0 14 3 0
## 8: 8 0 2 4 0 0 0 0
## 9: 11 twoclass 0 4 5 0 1 0 0
## 10: 5 multiclass 0 0 4 0 0 0 0
## 11: 20 1 13 4 0 0 0 1
## 12: 10 0 7 0 0 1 0 0
## 13: 15 0 3 5 0 6 0 0
## 14: 11 0 0 10 0 0 0 0
## 15: 65 twoclass 0 64 0 0 0 0 0
## 16: 8 multiclass 0 3 2 0 2 0 0
## 17: 11 multiclass 0 3 7 0 0 0 0
## 18: 9 twoclass 0 0 8 0 0 0 0
## 19: 1 0 0 1 0 0 0 0
## 20: 5 0 2 0 0 1 0 0
## 21: 61 twoclass 0 0 60 0 0 0 0
## 22: 58 twoclass 0 0 57 0 0 0 0
## 23: 11 twoclass 0 2 2 3 2 1 0
## 24: 6 0 1 2 0 1 0 0
## 25: 4 0 2 2 0 0 0 0
## 26: 11 0 4 3 0 2 0 0
## 27: 14 multiclass 0 2 11 0 0 0 0
## 28: 17 multiclass 15 1 0 0 0 0 0
## ncol properties lgl int dbl chr fct ord pxc
mlr3 provides the shortcut function tsk(). Here, we retrieve the palmer penguins task, which is provided by the package palmerpenguins:
task_penguins = tsk("penguins")
print(task_penguins)
## <TaskClassif:penguins> (344 x 8): Palmer Penguins
## * Target: species
## * Properties: multiclass
## * Features (7):
## - int (3): body_mass, flipper_length, year
## - dbl (2): bill_depth, bill_length
## - fct (2): island, sex
Creating a task from a dataset:
data("mtcars", package = "datasets")
data = mtcars[, 1:3]
str(data)
## 'data.frame': 32 obs. of 3 variables:
## $ mpg : num 21 21 22.8 21.4 18.7 18.1 14.3 24.4 22.8 19.2 ...
## $ cyl : num 6 6 4 6 8 6 8 4 4 6 ...
## $ disp: num 160 160 108 258 360 ...
task_mtcars = as_task_regr(data, target = "mpg", id = "cars")
print(task_mtcars)
## <TaskRegr:cars> (32 x 3)
## * Target: mpg
## * Properties: -
## * Features (2):
## - dbl (2): cyl, disp
Retrieving Tasks
The data stored in a task can be retrieved directly from fields:
task_mtcars
## <TaskRegr:cars> (32 x 3)
## * Target: mpg
## * Properties: -
## * Features (2):
## - dbl (2): cyl, disp
More information can be obtained through methods of the object, for example:
task_mtcars$data()
## mpg cyl disp
## 1: 21.0 6 160.0
## 2: 21.0 6 160.0
## 3: 22.8 4 108.0
## 4: 21.4 6 258.0
## 5: 18.7 8 360.0
## 6: 18.1 6 225.0
## 7: 14.3 8 360.0
## 8: 24.4 4 146.7
## 9: 22.8 4 140.8
## 10: 19.2 6 167.6
## 11: 17.8 6 167.6
## 12: 16.4 8 275.8
## 13: 17.3 8 275.8
## 14: 15.2 8 275.8
## 15: 10.4 8 472.0
## 16: 10.4 8 460.0
## 17: 14.7 8 440.0
## 18: 32.4 4 78.7
## 19: 30.4 4 75.7
## 20: 33.9 4 71.1
## 21: 21.5 4 120.1
## 22: 15.5 8 318.0
## 23: 15.2 8 304.0
## 24: 13.3 8 350.0
## 25: 19.2 8 400.0
## 26: 27.3 4 79.0
## 27: 26.0 4 120.3
## 28: 30.4 4 95.1
## 29: 15.8 8 351.0
## 30: 19.7 6 145.0
## 31: 15.0 8 301.0
## 32: 21.4 4 121.0
## mpg cyl disp
In mlr3, each row (observation) has a unique identifier, stored as an integer(). These can be passed as arguments to the $data() method to select specific rows:
head(task_mtcars$row_ids)
## [1] 1 2 3 4 5 6
# retrieve data for rows with IDs 1, 5, and 10
task_mtcars$data(rows = c(1, 5, 10))
## mpg cyl disp
## 1: 21.0 6 160.0
## 2: 18.7 8 360.0
## 3: 19.2 6 167.6
Similarly to row IDs, target and feature columns also have unique identifiers, i.e. names (stored as character()). Their names can be accessed via the public slots $feature_names and $target_names.
task_mtcars$feature_names
## [1] "cyl" "disp"
task_mtcars$target_names
## [1] "mpg"
To extract the complete data from the task, one can also simply convert it to a data.table:
# show summary of entire data
summary(as.data.table(task_mtcars))
## mpg cyl disp
## Min. :10.40 Min. :4.000 Min. : 71.1
## 1st Qu.:15.43 1st Qu.:4.000 1st Qu.:120.8
## Median :19.20 Median :6.000 Median :196.3
## Mean :20.09 Mean :6.188 Mean :230.7
## 3rd Qu.:22.80 3rd Qu.:8.000 3rd Qu.:326.0
## Max. :33.90 Max. :8.000 Max. :472.0
Task Mutators
filter() subsets the current view based on row IDs and select() subsets the view based on feature names.
task_penguins = tsk("penguins")
task_penguins$select(c("body_mass", "flipper_length")) # keep only these features
task_penguins$filter(1:3) # keep only these rows
task_penguins$head()
## species body_mass flipper_length
## 1: Adelie 3750 181
## 2: Adelie 3800 186
## 3: Adelie 3250 195
The methods rbind() and cbind() allow you to add extra rows and columns to a task. Again, the original data is not changed.
task_penguins$cbind(data.frame(letters = letters[1:3])) # add column letters
task_penguins$head()
## species body_mass flipper_length letters
## 1: Adelie 3750 181 a
## 2: Adelie 3800 186 b
## 3: Adelie 3250 195 c
Plotting Tasks
The mlr3viz package provides plotting facilities for many classes implemented in mlr3. The available plot types depend on the class, but all plots are returned as ggplot2 objects which can be easily customized.
Some examples:
# get the pima indians task
task = tsk("pima")

# subset task to only use the 3 first features
task$select(head(task$feature_names, 3))
# default plot: class frequencies
autoplot(task)
# pairs plot (requires package GGally)
autoplot(task, type = "pairs")
# duo plot (requires package GGally)
autoplot(task, type = "duo")
## Warning: Removed 5 rows containing non-finite values (stat_boxplot).
## Warning: Removed 374 rows containing non-finite values (stat_boxplot).
Of course, you can do the same for regression tasks.
# get the complete mtcars task
task = tsk("mtcars")

# subset task to only use the 3 first features
task$select(head(task$feature_names, 3))
# default plot: boxplot of target variable
autoplot(task)
# pairs plot (requires package GGally)
autoplot(task, type = "pairs")
Learners
Learners encapsulate machine learning algorithms to train models and make predictions for a task. The underlying algorithms are provided by other packages.
Objects of class Learner provide a unified interface to many popular machine learning algorithms in R. They consist of methods to train and predict a model for a Task and provide meta-information about the learners, such as the hyperparameters (which control the behavior of the learner) you can set.
A learner operates in a two-stage process: a training step that produces a model, and a prediction step that applies this model to new data.
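A minimal sketch of the two stages, using the predefined penguins task and the rpart learner (output omitted):
learner = lrn("classif.rpart")

# stage 1: training ingests a task and stores the fitted model inside the learner
learner$train(tsk("penguins"))

# stage 2: predicting applies the stored model to (new) data
prediction = learner$predict(tsk("penguins"))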
Predefined Learners
Basic:
- mlr_learners_classif.featureless
- mlr_learners_regr.featureless
- mlr_learners_classif.rpart
- mlr_learners_regr.rpart
In the mlr3learners package:
- Linear and logistic regression
- Penalized Generalized Linear Models
- k-Nearest Neighbors regression and classification
- Kriging
- Linear and Quadratic Discriminant Analysis
- Naive Bayes
- Support-Vector machines
- Gradient Boosting
- Random Forests for regression, classification and survival
mlr_learners
## <DictionaryLearner> with 136 stored values
## Keys: classif.AdaBoostM1, classif.bart, classif.C50, classif.catboost,
## classif.cforest, classif.ctree, classif.cv_glmnet, classif.debug,
## classif.earth, classif.extratrees, classif.featureless, classif.fnn,
## classif.gam, classif.gamboost, classif.gausspr, classif.gbm,
## classif.glmboost, classif.glmnet, classif.IBk, classif.J48,
## classif.JRip, classif.kknn, classif.ksvm, classif.lda,
## classif.liblinear, classif.lightgbm, classif.LMT, classif.log_reg,
## classif.lssvm, classif.mob, classif.multinom, classif.naive_bayes,
## classif.nnet, classif.OneR, classif.PART, classif.qda,
## classif.randomForest, classif.ranger, classif.rfsrc, classif.rpart,
## classif.svm, classif.xgboost, clust.agnes, clust.ap, clust.cmeans,
## clust.cobweb, clust.dbscan, clust.diana, clust.em, clust.fanny,
## clust.featureless, clust.ff, clust.hclust, clust.kkmeans,
## clust.kmeans, clust.MBatchKMeans, clust.meanshift, clust.pam,
## clust.SimpleKMeans, clust.xmeans, dens.hist, dens.kde, dens.kde_kd,
## dens.kde_ks, dens.locfit, dens.logspline, dens.mixed, dens.nonpar,
## dens.pen, dens.plug, dens.spline, regr.bart, regr.catboost,
## regr.cforest, regr.ctree, regr.cubist, regr.cv_glmnet, regr.debug,
## regr.earth, regr.extratrees, regr.featureless, regr.fnn, regr.gam,
## regr.gamboost, regr.gausspr, regr.gbm, regr.glm, regr.glmboost,
## regr.glmnet, regr.IBk, regr.kknn, regr.km, regr.ksvm, regr.liblinear,
## regr.lightgbm, regr.lm, regr.M5Rules, regr.mars, regr.mob,
## regr.randomForest, regr.ranger, regr.rfsrc, regr.rpart, regr.rvm,
## regr.svm, regr.xgboost, surv.akritas, surv.blackboost, surv.cforest,
## surv.coxboost, surv.coxph, surv.coxtime, surv.ctree,
## surv.cv_coxboost, surv.cv_glmnet, surv.deephit, surv.deepsurv,
## surv.dnnsurv, surv.flexible, surv.gamboost, surv.gbm, surv.glmboost,
## surv.glmnet, surv.kaplan, surv.loghaz, surv.mboost, surv.nelson,
## surv.obliqueRSF, surv.parametric, surv.pchazard, surv.penalized,
## surv.ranger, surv.rfsrc, surv.rpart, surv.svm, surv.xgboost
Access Learner Information
Each learner provides the following meta-information:
- feature_types
- packages
- properties
- predict_types
You can retrieve a specific learner using its ID:
learner = lrn("classif.rpart")
print(learner)
## <LearnerClassifRpart:classif.rpart>: Classification Tree
## * Model: -
## * Parameters: xval=0
## * Packages: mlr3, rpart
## * Predict Type: response
## * Feature types: logical, integer, numeric, factor, ordered
## * Properties: importance, missings, multiclass, selected_features,
## twoclass, weights
Each learner has hyperparameters that control its behavior. The field param_set stores a description of the hyperparameters the learner has, their ranges, defaults, and current values:
learner$param_set
## <ParamSet>
## id class lower upper nlevels default value
## 1: cp ParamDbl 0 1 Inf 0.01
## 2: keep_model ParamLgl NA NA 2 FALSE
## 3: maxcompete ParamInt 0 Inf Inf 4
## 4: maxdepth ParamInt 1 30 30 30
## 5: maxsurrogate ParamInt 0 Inf Inf 5
## 6: minbucket ParamInt 1 Inf Inf <NoDefault[3]>
## 7: minsplit ParamInt 1 Inf Inf 20
## 8: surrogatestyle ParamInt 0 1 2 0
## 9: usesurrogate ParamInt 0 2 3 2
## 10: xval ParamInt 0 Inf Inf 10 0
You can easily change the current hyperparameter values like this:
pv = learner$param_set$values
pv$cp = 0.02
learner$param_set$values = pv
or like this:
learner = lrn("classif.rpart", id = "rp", cp = 0.001)
learner$id
## [1] "rp"
learner$param_set$values
## $xval
## [1] 0
##
## $cp
## [1] 0.001
Thresholding
A classification learner that predicts probabilities assigns an observation to the positive class if its predicted probability exceeds a threshold, which is 0.5 by default. In the example below, we lower the threshold to 0.2, which improves the True Positive Rate (TPR) at the cost of the True Negative Rate (TNR).
data("Sonar", package = "mlbench")
= as_task_classif(Sonar, target = "Class", positive = "M")
task = lrn("classif.rpart", predict_type = "prob")
learner = learner$train(task)$predict(task)
pred
= msrs(c("classif.tpr", "classif.tnr")) # use msrs() to get a list of multiple measures
measures $confusion pred
## truth
## response M R
## M 95 10
## R 16 87
pred$score(measures)
## classif.tpr classif.tnr
## 0.8558559 0.8969072
pred$set_threshold(0.2)
pred$confusion
## truth
## response M R
## M 104 25
## R 7 72
pred$score(measures)
## classif.tpr classif.tnr
## 0.9369369 0.7422680
Training and Predicting
This section illustrates how to use tasks and learners to train a model and make predictions on a new data set. The concept is demonstrated on a supervised classification task using the penguins dataset and the rpart learner, which builds a single classification tree.
Creating Task and Learner Objects
task = tsk("penguins")
learner = lrn("classif.rpart")
Training and Test Splits
# index vectors
train_set = sample(task$nrow, 0.8 * task$nrow)
test_set = setdiff(seq_len(task$nrow), train_set)
Later on, we will see that mlr3 can automatically create training and test sets based on different resampling strategies.
Training the learner
The field model stores the model that is produced in the training step.
# fit the classification tree using the training set of the task by calling the $train() method of learner:
learner$train(task, row_ids = train_set)
print(learner$model)
## n= 275
##
## node), split, n, loss, yval, (yprob)
## * denotes terminal node
##
## 1) root 275 157 Adelie (0.429090909 0.192727273 0.378181818)
## 2) flipper_length< 206.5 166 49 Adelie (0.704819277 0.289156627 0.006024096)
## 4) bill_length< 43.35 117 4 Adelie (0.965811966 0.034188034 0.000000000) *
## 5) bill_length>=43.35 49 5 Chinstrap (0.081632653 0.897959184 0.020408163) *
## 3) flipper_length>=206.5 109 6 Gentoo (0.009174312 0.045871560 0.944954128)
## 6) bill_depth>=17.2 7 2 Chinstrap (0.142857143 0.714285714 0.142857143) *
## 7) bill_depth< 17.2 102 0 Gentoo (0.000000000 0.000000000 1.000000000) *
Predicting
After the model has been fitted to the training data, we use the test set for prediction:
prediction = learner$predict(task, row_ids = test_set)
print(prediction)
## <PredictionClassif> for 69 observations:
## row_ids truth response
## 5 Adelie Adelie
## 7 Adelie Adelie
## 9 Adelie Adelie
## ---
## 326 Chinstrap Chinstrap
## 332 Chinstrap Chinstrap
## 335 Chinstrap Chinstrap
A prediction object holds the row IDs of the test data, the corresponding true labels of the target column, and the predictions. The simplest way to extract this information is by converting the Prediction object to a data.table():
head(as.data.table(prediction)) # show first six predictions
## row_ids truth response
## 1: 5 Adelie Adelie
## 2: 7 Adelie Adelie
## 3: 9 Adelie Adelie
## 4: 11 Adelie Adelie
## 5: 15 Adelie Adelie
## 6: 18 Adelie Adelie
prediction$confusion
## truth
## response Adelie Chinstrap Gentoo
## Adelie 33 1 0
## Chinstrap 1 14 1
## Gentoo 0 0 19
Changing the Predict Type
learner$predict_type = "prob"

# re-fit the model
learner$train(task, row_ids = train_set)

# rebuild prediction object
prediction = learner$predict(task, row_ids = test_set)
The prediction object now contains probabilities for all class labels in addition to the predicted label (the one with the highest probability):
# data.table conversion
head(as.data.table(prediction)) # show first six
## row_ids truth response prob.Adelie prob.Chinstrap prob.Gentoo
## 1: 5 Adelie Adelie 0.965812 0.03418803 0
## 2: 7 Adelie Adelie 0.965812 0.03418803 0
## 3: 9 Adelie Adelie 0.965812 0.03418803 0
## 4: 11 Adelie Adelie 0.965812 0.03418803 0
## 5: 15 Adelie Adelie 0.965812 0.03418803 0
## 6: 18 Adelie Adelie 0.965812 0.03418803 0
# directly access the predicted labels:
head(prediction$response)
## [1] Adelie Adelie Adelie Adelie Adelie Adelie
## Levels: Adelie Chinstrap Gentoo
# directly access the matrix of probabilities:
head(prediction$prob)
## Adelie Chinstrap Gentoo
## [1,] 0.965812 0.03418803 0
## [2,] 0.965812 0.03418803 0
## [3,] 0.965812 0.03418803 0
## [4,] 0.965812 0.03418803 0
## [5,] 0.965812 0.03418803 0
## [6,] 0.965812 0.03418803 0
Similarly to predicting probabilities for classification, many regression learners support the extraction of standard error estimates for predictions by setting the predict type to “se”.
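A brief sketch of the regression analogue, assuming the regr.lm learner from mlr3learners, which supports the "se" predict type:
library("mlr3learners") # provides regr.lm

task_se = tsk("mtcars")
learner_se = lrn("regr.lm", predict_type = "se")
learner_se$train(task_se)
prediction_se = learner_se$predict(task_se)
head(as.data.table(prediction_se)) # contains response and se columns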
Plotting Predictions
Similarly to plotting tasks, mlr3viz provides an autoplot() method for Prediction objects.
task = tsk("penguins")
learner = lrn("classif.rpart", predict_type = "prob")
learner$train(task)
prediction = learner$predict(task)
autoplot(prediction)
Performance Assessment
Available measures are stored in mlr_measures:
mlr_measures
## <DictionaryMeasure> with 87 stored values
## Keys: aic, bic, classif.acc, classif.auc, classif.bacc, classif.bbrier,
## classif.ce, classif.costs, classif.dor, classif.fbeta, classif.fdr,
## classif.fn, classif.fnr, classif.fomr, classif.fp, classif.fpr,
## classif.logloss, classif.mbrier, classif.mcc, classif.npv,
## classif.ppv, classif.prauc, classif.precision, classif.recall,
## classif.sensitivity, classif.specificity, classif.tn, classif.tnr,
## classif.tp, classif.tpr, clust.ch, clust.db, clust.dunn,
## clust.silhouette, clust.wss, debug, dens.logloss, oob_error,
## regr.bias, regr.ktau, regr.mae, regr.mape, regr.maxae, regr.medae,
## regr.medse, regr.mse, regr.msle, regr.pbias, regr.rae, regr.rmse,
## regr.rmsle, regr.rrse, regr.rse, regr.rsq, regr.sae, regr.smape,
## regr.srho, regr.sse, selected_features, sim.jaccard, sim.phi,
## surv.brier, surv.calib_alpha, surv.calib_beta, surv.chambless_auc,
## surv.cindex, surv.dcalib, surv.graf, surv.hung_auc, surv.intlogloss,
## surv.logloss, surv.mae, surv.mse, surv.nagelk_r2, surv.oquigley_r2,
## surv.rmse, surv.schmid, surv.song_auc, surv.song_tnr, surv.song_tpr,
## surv.uno_auc, surv.uno_tnr, surv.uno_tpr, surv.xu_r2, time_both,
## time_predict, time_train
We choose accuracy (classif.acc) as our specific performance measure here and call the method $score() of the prediction object to quantify the predictive performance of our model.
measure = msr("classif.acc")
print(measure)
## <MeasureClassifSimple:classif.acc>: Classification Accuracy
## * Packages: mlr3, mlr3measures
## * Range: [0, 1]
## * Minimize: FALSE
## * Average: macro
## * Parameters: list()
## * Properties: -
## * Predict type: response
prediction$score(measure)
## classif.acc
## 0.9651163
If no measure is specified, classification defaults to the classification error (classif.ce, the complement of accuracy) and regression to the mean squared error (regr.mse).
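For example, calling $score() without an explicit measure applies this default:
# equivalent to prediction$score(msr("classif.ce")) for this classification task
prediction$score()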
2. Performance Evaluation and Comparison
Now, let us look at how mlr3 enables us to perform many common machine learning steps, such as:
Binary classification and ROC curves
Resampling
Benchmarking
ROC curve and Thresholds
Binary classification is a special case of classification where the target variable has only two possible values, and a threshold probability is used to distinguish between them. ROC (Receiver Operating Characteristic) analysis applies particularly to this case and allows us to better understand the trade-offs when choosing between the two classes.
data("Sonar", package = "mlbench")
= as_task_classif(Sonar, target = "Class", positive = "M")
task
= lrn("classif.rpart", predict_type = "prob")
learner = learner$train(task)$predict(task)
pred = pred$confusion
C print(C)
## truth
## response M R
## M 95 10
## R 16 87
Fig. 2: Confusion matrix
We can then derive the following performance metrics from a confusion matrix (computed by hand in the sketch after this list):
True Positive Rate (TPR): How many of the true positives did we predict as positive?
True Negative Rate (TNR): How many of the true negatives did we predict as negative?
Positive Predictive Value (PPV): If we predict positive, how likely is it a true positive?
Negative Predictive Value (NPV): If we predict negative, how likely is it a true negative?
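As a quick sanity check, these four rates can be computed directly from the confusion matrix C above (rows are the predicted response, columns the truth):
TP = C["M", "M"] # truth M, predicted M
FN = C["R", "M"] # truth M, predicted R
FP = C["M", "R"] # truth R, predicted M
TN = C["R", "R"] # truth R, predicted R

TP / (TP + FN) # TPR: 95 / (95 + 16)
TN / (TN + FP) # TNR: 87 / (87 + 10)
TP / (TP + FP) # PPV: 95 / (95 + 10)
TN / (TN + FN) # NPV: 87 / (87 + 16)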
The ROC curve plots the TPR and FPR values for different thresholds in order to characterize the behavior of a binary classifier.
For mlr3 prediction objects, the ROC curve can easily be created with mlr3viz, which relies on the precrec package to calculate and plot ROC and precision-recall curves:
library("mlr3viz")
# TPR vs FPR / Sensitivity vs (1 - Specificity)
autoplot(pred, type = "roc")
We can also plot the precision-recall curve (PPV vs. TPR), which is often preferred over the ROC curve for imbalanced class distributions.
# Precision vs Recall
autoplot(pred, type = "prc")
Resampling
When evaluating the performance of a model, we are interested in its generalization performance, which we can estimate by evaluating the model on an independent test set, as we have done above.
There are different strategies (resampling) for partitioning a data set into training and test sets. In mlr3 there are the following predefined resampling strategies:
- cross-validation ("cv"),
- leave-one-out cross-validation ("loo"),
- repeated cross-validation ("repeated_cv"),
- bootstrapping ("bootstrap"),
- subsampling ("subsampling"),
- holdout ("holdout"),
- in-sample resampling ("insample"), and
- custom resampling ("custom").
Additional resampling methods for special use cases are available via extension packages, such as mlr3spatiotemporal for spatial data.
# mlr resampling strategies
as.data.table(mlr_resamplings)
## key label params iters
## 1: bootstrap Bootstrap ratio,repeats 30
## 2: custom Custom Splits NA
## 3: custom_cv Custom Split Cross-Validation NA
## 4: cv Cross-Validation folds 10
## 5: holdout Holdout ratio 1
## 6: insample Insample Resampling 1
## 7: loo Leave-One-Out NA
## 8: repeated_cv Repeated Cross-Validation folds,repeats 100
## 9: subsampling Subsampling ratio,repeats 30
Settings
library("mlr3verse")
# Set the task to use
task = tsk("penguins")

# Set the learner:
# a simple classification tree
# from the rpart package
learner = lrn("classif.rpart")

# Set resampling strategy
resampling = rsmp("holdout")
print(resampling)
## <ResamplingHoldout>: Holdout
## * Iterations: 1
## * Instantiated: FALSE
## * Parameters: ratio=0.6667
# by default we have a ratio of 2/3 obs
# going into the training set
# We can change the hyperparameter as:
rsmp("holdout", ratio = 0.8)
## <ResamplingHoldout>: Holdout
## * Iterations: 1
## * Instantiated: FALSE
## * Parameters: ratio=0.8
Now, to actually perform the splitting and obtain indices for the training and test split, the resampling strategy needs a Task.
# Instantiation
resampling$instantiate(task)
# training set
str(resampling$train_set(1))
## int [1:229] 3 5 8 9 10 11 12 15 16 17 ...
# test set
str(resampling$test_set(1))
## int [1:115] 1 2 4 6 7 13 14 19 22 23 ...
Note: If we want to compare multiple learners, we need to use the same instantiated resampling for each learner, so that each learner gets exactly the same training data and each model is evaluated on the same test sets. For more on the comparison of multiple learners, refer to the section on benchmarking of the mlr3book.
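A minimal sketch of this pattern, comparing the classification tree against a featureless baseline on the same splits:
resampling = rsmp("cv", folds = 5)
resampling$instantiate(task)

# both learners are evaluated on exactly the same train/test splits
rr_rpart = resample(task, lrn("classif.rpart"), resampling)
rr_featureless = resample(task, lrn("classif.featureless"), resampling)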
Execution
Once we have defined a Task, a Learner, and a Resampling object, we can call resample(), which fits the learner on the training set and evaluates it on the test set of each resampling iteration.
# binary classif task
task = tsk("pima")
# select 2 features
task$select(c("glucose", "mass"))
learner = lrn("classif.rpart", predict_type = "prob")

# 10-fold-cv
resampling = rsmp("cv")

# Execute the resampling
# store_models = TRUE, to keep the fitted models
rr = resample(task, learner, resampling, store_models = TRUE)
## INFO [22:08:45.068] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 4/10)
## INFO [22:08:45.111] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 6/10)
## INFO [22:08:45.145] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 9/10)
## INFO [22:08:45.180] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 2/10)
## INFO [22:08:45.199] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 3/10)
## INFO [22:08:45.217] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 10/10)
## INFO [22:08:45.236] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 7/10)
## INFO [22:08:45.256] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 5/10)
## INFO [22:08:45.275] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 1/10)
## INFO [22:08:45.295] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 8/10)
print(rr)
## <ResampleResult> of 10 iterations
## * Task: pima
## * Learner: classif.rpart
## * Warnings: 0 in 0 iterations
## * Errors: 0 in 0 iterations
# we can calculate the avg performance
# across all resampling iterations,
# in terms of classification error
rr$aggregate(msr("classif.ce"))
## classif.ce
## 0.2644224
# or for the individual resampling iterations
# to check if iterations are diff from the avg
rr$score(msr("classif.ce"))
## task task_id learner learner_id
## 1: <TaskClassif[50]> pima <LearnerClassifRpart[38]> classif.rpart
## 2: <TaskClassif[50]> pima <LearnerClassifRpart[38]> classif.rpart
## 3: <TaskClassif[50]> pima <LearnerClassifRpart[38]> classif.rpart
## 4: <TaskClassif[50]> pima <LearnerClassifRpart[38]> classif.rpart
## 5: <TaskClassif[50]> pima <LearnerClassifRpart[38]> classif.rpart
## 6: <TaskClassif[50]> pima <LearnerClassifRpart[38]> classif.rpart
## 7: <TaskClassif[50]> pima <LearnerClassifRpart[38]> classif.rpart
## 8: <TaskClassif[50]> pima <LearnerClassifRpart[38]> classif.rpart
## 9: <TaskClassif[50]> pima <LearnerClassifRpart[38]> classif.rpart
## 10: <TaskClassif[50]> pima <LearnerClassifRpart[38]> classif.rpart
## resampling resampling_id iteration prediction
## 1: <ResamplingCV[20]> cv 1 <PredictionClassif[20]>
## 2: <ResamplingCV[20]> cv 2 <PredictionClassif[20]>
## 3: <ResamplingCV[20]> cv 3 <PredictionClassif[20]>
## 4: <ResamplingCV[20]> cv 4 <PredictionClassif[20]>
## 5: <ResamplingCV[20]> cv 5 <PredictionClassif[20]>
## 6: <ResamplingCV[20]> cv 6 <PredictionClassif[20]>
## 7: <ResamplingCV[20]> cv 7 <PredictionClassif[20]>
## 8: <ResamplingCV[20]> cv 8 <PredictionClassif[20]>
## 9: <ResamplingCV[20]> cv 9 <PredictionClassif[20]>
## 10: <ResamplingCV[20]> cv 10 <PredictionClassif[20]>
## classif.ce
## 1: 0.2727273
## 2: 0.2077922
## 3: 0.1948052
## 4: 0.2337662
## 5: 0.3376623
## 6: 0.2077922
## 7: 0.3246753
## 8: 0.2597403
## 9: 0.3552632
## 10: 0.2500000
# and many more getters stored in rr:
rr$warnings
## Empty data.table (0 rows and 2 cols): iteration,msg
rr$errors
## Empty data.table (0 rows and 2 cols): iteration,msg
rr$resampling
## <ResamplingCV>: Cross-Validation
## * Iterations: 10
## * Instantiated: TRUE
## * Parameters: folds=10
# the model trained in a specific iteration
lrn = rr$learners[[1]]
lrn$model
## n= 691
##
## node), split, n, loss, yval, (yprob)
## * denotes terminal node
##
## 1) root 691 237 neg (0.3429812 0.6570188)
## 2) glucose>=123.5 290 121 pos (0.5827586 0.4172414)
## 4) mass>=29.95 205 63 pos (0.6926829 0.3073171)
## 8) glucose>=154.5 85 12 pos (0.8588235 0.1411765) *
## 9) glucose< 154.5 120 51 pos (0.5750000 0.4250000)
## 18) mass>=41.65 22 4 pos (0.8181818 0.1818182) *
## 19) mass< 41.65 98 47 pos (0.5204082 0.4795918)
## 38) mass< 39.95 86 37 pos (0.5697674 0.4302326)
## 76) glucose>=129.5 56 20 pos (0.6428571 0.3571429)
## 152) mass>=34.05 23 4 pos (0.8260870 0.1739130) *
## 153) mass< 34.05 33 16 pos (0.5151515 0.4848485)
## 306) mass< 31.8 12 3 pos (0.7500000 0.2500000) *
## 307) mass>=31.8 21 8 neg (0.3809524 0.6190476) *
## 77) glucose< 129.5 30 13 neg (0.4333333 0.5666667)
## 154) mass< 34.15 18 7 pos (0.6111111 0.3888889) *
## 155) mass>=34.15 12 2 neg (0.1666667 0.8333333) *
## 39) mass>=39.95 12 2 neg (0.1666667 0.8333333) *
## 5) mass< 29.95 85 27 neg (0.3176471 0.6823529)
## 10) glucose>=160 18 7 pos (0.6111111 0.3888889) *
## 11) glucose< 160 67 16 neg (0.2388060 0.7611940) *
## 3) glucose< 123.5 401 68 neg (0.1695761 0.8304239) *
# all individual predictions
# merged into a single Prediction object
rr$prediction()
## <PredictionClassif> for 768 observations:
## row_ids truth response prob.pos prob.neg
## 21 neg neg 0.1666667 0.8333333
## 40 pos neg 0.1695761 0.8304239
## 46 pos pos 0.8588235 0.1411765
## ---
## 705 neg neg 0.1921296 0.8078704
## 750 pos neg 0.3174603 0.6825397
## 766 neg neg 0.1921296 0.8078704
# or in a specific resampling iteration
rr$predictions()[[1]]
## <PredictionClassif> for 77 observations:
## row_ids truth response prob.pos prob.neg
## 21 neg neg 0.1666667 0.8333333
## 40 pos neg 0.1695761 0.8304239
## 46 pos pos 0.8588235 0.1411765
## ---
## 753 neg neg 0.1695761 0.8304239
## 754 pos pos 0.8588235 0.1411765
## 762 pos pos 0.8588235 0.1411765
Plotting Resample Results
The mlr3viz package provides an autoplot() method for resampling results.
# boxplot of AUC values across the 10 folds
autoplot(rr, measure = msr("classif.auc"))
# ROC curve, averaged over 10 folds
autoplot(rr, type = "roc")
# learner predictions for an individual fold
rr$filter(1)
autoplot(rr, type = "prediction")
## Warning: Removed 2 rows containing missing values (geom_point).
3. Model Optimization
In machine learning, when you are dissatisfied with the performance of a model, you might ask yourself how to best improve the model:
Can it be improved by changing the hyperparameters of the learner?
Should you use a completely different learner for this particular task?
We will use the mlr3tuning package, which supports common tuning operations.
Hyperparameter Tuning
Hyperparameters are the parameters of the learners that control how a model is fit to the data. They are sometimes called second-level or second-order parameters of machine learning, since the parameters of the models themselves are the first-order parameters, which are fit to the data during model training.
The TuningInstance* Classes
We will examine the optimization of a simple classification tree on the Pima Indian Diabetes data set as an introductory example here.
library("mlr3verse")
task = tsk("pima")
print(task)
## <TaskClassif:pima> (768 x 9): Pima Indian Diabetes
## * Target: diabetes
## * Properties: twoclass
## * Features (8):
## - dbl (8): age, glucose, insulin, mass, pedigree, pregnant, pressure,
## triceps
We use the rpart classification tree and choose a subset of the hyperparameters we want to tune. This is often referred to as the “tuning space”.
learner = lrn("classif.rpart")
learner$param_set
## <ParamSet>
## id class lower upper nlevels default value
## 1: cp ParamDbl 0 1 Inf 0.01
## 2: keep_model ParamLgl NA NA 2 FALSE
## 3: maxcompete ParamInt 0 Inf Inf 4
## 4: maxdepth ParamInt 1 30 30 30
## 5: maxsurrogate ParamInt 0 Inf Inf 5
## 6: minbucket ParamInt 1 Inf Inf <NoDefault[3]>
## 7: minsplit ParamInt 1 Inf Inf 20
## 8: surrogatestyle ParamInt 0 1 2 0
## 9: usesurrogate ParamInt 0 2 3 2
## 10: xval ParamInt 0 Inf Inf 10 0
Here, we opt to tune two hyperparameters:
- The complexity hyperparameter cp
- The minsplit hyperparameter
Each of these hyperparameters needs a lower and an upper bound:
search_space = ps(
  cp = p_dbl(lower = 0.001, upper = 0.1),
  minsplit = p_int(lower = 1, upper = 10)
)
search_space
## <ParamSet>
## id class lower upper nlevels default value
## 1: cp ParamDbl 0.001 0.1 Inf <NoDefault[3]>
## 2: minsplit ParamInt 1.000 10.0 10 <NoDefault[3]>
Next, we need to specify how to evaluate the performance of a trained model. For this, we need to choose a resampling strategy and a performance measure.
hout = rsmp("holdout")
measure = msr("classif.ce")
Finally, we have to specify the budget available for tuning. mlr3 allows us to specify complex termination criteria by selecting one of the available Terminators:
- TerminatorClockTime: Terminate after a given time.
- TerminatorEvals: Terminate after a given number of iterations.
- TerminatorPerfReached: Terminate after a specific performance has been reached.
- TerminatorStagnation: Terminate when tuning does not find a better configuration for a given number of iterations.
- TerminatorCombo: A combination of the above in an ALL or ANY fashion (see the sketch after this list).
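For illustration, a combined budget could be built like this (a sketch; it assumes the terminator keys "evals" and "run_time" from the dictionary mlr_terminators):
# stop when either 100 evaluations are done or 60 seconds have elapsed
terminator_combo = trm("combo",
  terminators = list(
    trm("evals", n_evals = 100),
    trm("run_time", secs = 60)
  ),
  any = TRUE # ANY fashion; any = FALSE would require ALL criteria to be met
)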
evals20 = trm("evals", n_evals = 20)

instance = TuningInstanceSingleCrit$new(
  task = task,
  learner = learner,
  resampling = hout,
  measure = measure,
  search_space = search_space,
  terminator = evals20
)
instance
## <TuningInstanceSingleCrit>
## * State: Not optimized
## * Objective: <ObjectiveTuning:classif.rpart_on_pima>
## * Search Space:
## id class lower upper nlevels
## 1: cp ParamDbl 0.001 0.1 Inf
## 2: minsplit ParamInt 1.000 10.0 10
## * Terminator: <TerminatorEvals>
In order to start the tuning, we need to select how the optimization should happen (choose the optimization algorithm via the Tuner class).
The Tuner Class
The following algorithms are currently implemented in mlr3tuning:
- TunerGridSearch: Grid Search
- TunerRandomSearch: Random Search
- TunerGenSA: Generalized Simulated Annealing
- TunerNLoptr: Non-Linear Optimization
tuner = tnr("grid_search", resolution = 5)
Triggering the Tuning
The tuner proceeds as follows:
- The Tuner proposes at least one hyperparameter configuration to evaluate.
- For each configuration, the given Learner is fitted on the Task and evaluated using the provided Resampling.
- The Terminator is queried if the budget is exhausted.
- Determine the configurations with the best observed performance from the archive.
- Store the best configurations as result in the tuning instance object.
To start the tuning, we simply pass the TuningInstanceSingleCrit to the $optimize() method of the initialized Tuner:
tuner$optimize(instance)
## INFO [22:08:46.673] [bbotk] Starting to optimize 2 parameter(s) with '<TunerGridSearch>' and '<TerminatorEvals> [n_evals=20, k=0]'
## INFO [22:08:46.677] [bbotk] Evaluating 1 configuration(s)
## INFO [22:08:46.689] [mlr3] Running benchmark with 1 resampling iterations
## INFO [22:08:46.694] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 1/1)
## INFO [22:08:46.714] [mlr3] Finished benchmark
## INFO [22:08:46.739] [bbotk] Result of batch 1:
## INFO [22:08:46.741] [bbotk] cp minsplit classif.ce warnings errors runtime_learners
## INFO [22:08:46.741] [bbotk] 0.02575 8 0.28125 0 0 0.02
## INFO [22:08:46.741] [bbotk] uhash
## INFO [22:08:46.741] [bbotk] c01b7751-febd-44c0-8d39-f9135fadf574
## INFO [22:08:46.742] [bbotk] Evaluating 1 configuration(s)
## INFO [22:08:46.751] [mlr3] Running benchmark with 1 resampling iterations
## INFO [22:08:46.757] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 1/1)
## INFO [22:08:46.776] [mlr3] Finished benchmark
## INFO [22:08:46.803] [bbotk] Result of batch 2:
## INFO [22:08:46.804] [bbotk] cp minsplit classif.ce warnings errors runtime_learners
## INFO [22:08:46.804] [bbotk] 0.1 8 0.28125 0 0 0.01
## INFO [22:08:46.804] [bbotk] uhash
## INFO [22:08:46.804] [bbotk] 1c066822-8cf0-4603-b2e0-40bd0e95790a
## INFO [22:08:46.805] [bbotk] Evaluating 1 configuration(s)
## INFO [22:08:46.815] [mlr3] Running benchmark with 1 resampling iterations
## INFO [22:08:46.825] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 1/1)
## INFO [22:08:46.847] [mlr3] Finished benchmark
## INFO [22:08:46.872] [bbotk] Result of batch 3:
## INFO [22:08:46.873] [bbotk] cp minsplit classif.ce warnings errors runtime_learners
## INFO [22:08:46.873] [bbotk] 0.0505 3 0.28125 0 0 0
## INFO [22:08:46.873] [bbotk] uhash
## INFO [22:08:46.873] [bbotk] b2de11c9-7ebb-46f1-bbde-7630850622d5
## INFO [22:08:46.874] [bbotk] Evaluating 1 configuration(s)
## INFO [22:08:46.886] [mlr3] Running benchmark with 1 resampling iterations
## INFO [22:08:46.891] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 1/1)
## INFO [22:08:46.908] [mlr3] Finished benchmark
## INFO [22:08:46.932] [bbotk] Result of batch 4:
## INFO [22:08:46.933] [bbotk] cp minsplit classif.ce warnings errors runtime_learners
## INFO [22:08:46.933] [bbotk] 0.0505 1 0.28125 0 0 0
## INFO [22:08:46.933] [bbotk] uhash
## INFO [22:08:46.933] [bbotk] cf8b2097-7654-4f1e-8fa4-7d3db27aec39
## INFO [22:08:46.935] [bbotk] Evaluating 1 configuration(s)
## INFO [22:08:46.944] [mlr3] Running benchmark with 1 resampling iterations
## INFO [22:08:46.948] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 1/1)
## INFO [22:08:46.964] [mlr3] Finished benchmark
## INFO [22:08:46.999] [bbotk] Result of batch 5:
## INFO [22:08:47.001] [bbotk] cp minsplit classif.ce warnings errors runtime_learners
## INFO [22:08:47.001] [bbotk] 0.02575 3 0.28125 0 0 0
## INFO [22:08:47.001] [bbotk] uhash
## INFO [22:08:47.001] [bbotk] 35988734-d4b4-46a7-858e-12750e2a8092
## INFO [22:08:47.003] [bbotk] Evaluating 1 configuration(s)
## INFO [22:08:47.015] [mlr3] Running benchmark with 1 resampling iterations
## INFO [22:08:47.020] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 1/1)
## INFO [22:08:47.041] [mlr3] Finished benchmark
## INFO [22:08:47.079] [bbotk] Result of batch 6:
## INFO [22:08:47.081] [bbotk] cp minsplit classif.ce warnings errors runtime_learners
## INFO [22:08:47.081] [bbotk] 0.001 1 0.2890625 0 0 0.02
## INFO [22:08:47.081] [bbotk] uhash
## INFO [22:08:47.081] [bbotk] 29f2c9c1-c491-4511-86a0-2373e6fd56d9
## INFO [22:08:47.083] [bbotk] Evaluating 1 configuration(s)
## INFO [22:08:47.094] [mlr3] Running benchmark with 1 resampling iterations
## INFO [22:08:47.100] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 1/1)
## INFO [22:08:47.118] [mlr3] Finished benchmark
## INFO [22:08:47.142] [bbotk] Result of batch 7:
## INFO [22:08:47.143] [bbotk] cp minsplit classif.ce warnings errors runtime_learners
## INFO [22:08:47.143] [bbotk] 0.0505 10 0.28125 0 0 0.02
## INFO [22:08:47.143] [bbotk] uhash
## INFO [22:08:47.143] [bbotk] df71699d-54b9-4fb6-86e9-05d6b4dd0811
## INFO [22:08:47.144] [bbotk] Evaluating 1 configuration(s)
## INFO [22:08:47.154] [mlr3] Running benchmark with 1 resampling iterations
## INFO [22:08:47.158] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 1/1)
## INFO [22:08:47.175] [mlr3] Finished benchmark
## INFO [22:08:47.200] [bbotk] Result of batch 8:
## INFO [22:08:47.202] [bbotk] cp minsplit classif.ce warnings errors runtime_learners
## INFO [22:08:47.202] [bbotk] 0.0505 8 0.28125 0 0 0
## INFO [22:08:47.202] [bbotk] uhash
## INFO [22:08:47.202] [bbotk] 866067f1-a2c4-4f11-bbd6-4267d4e1de6e
## INFO [22:08:47.203] [bbotk] Evaluating 1 configuration(s)
## INFO [22:08:47.212] [mlr3] Running benchmark with 1 resampling iterations
## INFO [22:08:47.217] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 1/1)
## INFO [22:08:47.233] [mlr3] Finished benchmark
## INFO [22:08:47.256] [bbotk] Result of batch 9:
## INFO [22:08:47.257] [bbotk] cp minsplit classif.ce warnings errors runtime_learners
## INFO [22:08:47.257] [bbotk] 0.0505 5 0.28125 0 0 0.01
## INFO [22:08:47.257] [bbotk] uhash
## INFO [22:08:47.257] [bbotk] 8570ce12-777a-41c4-b84d-f096c3a628e6
## INFO [22:08:47.258] [bbotk] Evaluating 1 configuration(s)
## INFO [22:08:47.267] [mlr3] Running benchmark with 1 resampling iterations
## INFO [22:08:47.272] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 1/1)
## INFO [22:08:47.291] [mlr3] Finished benchmark
## INFO [22:08:47.316] [bbotk] Result of batch 10:
## INFO [22:08:47.317] [bbotk] cp minsplit classif.ce warnings errors runtime_learners
## INFO [22:08:47.317] [bbotk] 0.07525 8 0.28125 0 0 0.02
## INFO [22:08:47.317] [bbotk] uhash
## INFO [22:08:47.317] [bbotk] cde18441-f94b-4eb1-bd93-e417427c918c
## INFO [22:08:47.318] [bbotk] Evaluating 1 configuration(s)
## INFO [22:08:47.327] [mlr3] Running benchmark with 1 resampling iterations
## INFO [22:08:47.331] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 1/1)
## INFO [22:08:47.348] [mlr3] Finished benchmark
## INFO [22:08:47.375] [bbotk] Result of batch 11:
## INFO [22:08:47.376] [bbotk] cp minsplit classif.ce warnings errors runtime_learners
## INFO [22:08:47.376] [bbotk] 0.1 5 0.28125 0 0 0
## INFO [22:08:47.376] [bbotk] uhash
## INFO [22:08:47.376] [bbotk] 0c73b1ef-0b7a-4b10-8147-a618496950ed
## INFO [22:08:47.377] [bbotk] Evaluating 1 configuration(s)
## INFO [22:08:47.387] [mlr3] Running benchmark with 1 resampling iterations
## INFO [22:08:47.390] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 1/1)
## INFO [22:08:47.409] [mlr3] Finished benchmark
## INFO [22:08:47.437] [bbotk] Result of batch 12:
## INFO [22:08:47.438] [bbotk] cp minsplit classif.ce warnings errors runtime_learners
## INFO [22:08:47.438] [bbotk] 0.001 8 0.2695312 0 0 0
## INFO [22:08:47.438] [bbotk] uhash
## INFO [22:08:47.438] [bbotk] 68ef7cae-b18d-4320-a44d-70080d077b7a
## INFO [22:08:47.439] [bbotk] Evaluating 1 configuration(s)
## INFO [22:08:47.450] [mlr3] Running benchmark with 1 resampling iterations
## INFO [22:08:47.454] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 1/1)
## INFO [22:08:47.484] [mlr3] Finished benchmark
## INFO [22:08:47.516] [bbotk] Result of batch 13:
## INFO [22:08:47.517] [bbotk] cp minsplit classif.ce warnings errors runtime_learners
## INFO [22:08:47.517] [bbotk] 0.07525 10 0.28125 0 0 0.03
## INFO [22:08:47.517] [bbotk] uhash
## INFO [22:08:47.517] [bbotk] 29a1c066-5c74-4264-8cf3-97625f4af5d4
## INFO [22:08:47.518] [bbotk] Evaluating 1 configuration(s)
## INFO [22:08:47.527] [mlr3] Running benchmark with 1 resampling iterations
## INFO [22:08:47.532] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 1/1)
## INFO [22:08:47.547] [mlr3] Finished benchmark
## INFO [22:08:47.571] [bbotk] Result of batch 14:
## INFO [22:08:47.573] [bbotk] cp minsplit classif.ce warnings errors runtime_learners
## INFO [22:08:47.573] [bbotk] 0.07525 5 0.28125 0 0 0.01
## INFO [22:08:47.573] [bbotk] uhash
## INFO [22:08:47.573] [bbotk] 3703156b-b43a-48bd-84aa-88e21c73bb92
## INFO [22:08:47.574] [bbotk] Evaluating 1 configuration(s)
## INFO [22:08:47.586] [mlr3] Running benchmark with 1 resampling iterations
## INFO [22:08:47.593] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 1/1)
## INFO [22:08:47.612] [mlr3] Finished benchmark
## INFO [22:08:47.635] [bbotk] Result of batch 15:
## INFO [22:08:47.636] [bbotk] cp minsplit classif.ce warnings errors runtime_learners
## INFO [22:08:47.636] [bbotk] 0.1 10 0.28125 0 0 0
## INFO [22:08:47.636] [bbotk] uhash
## INFO [22:08:47.636] [bbotk] 77407ca2-47db-4c31-91d8-4e4be3e030d4
## INFO [22:08:47.638] [bbotk] Evaluating 1 configuration(s)
## INFO [22:08:47.646] [mlr3] Running benchmark with 1 resampling iterations
## INFO [22:08:47.652] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 1/1)
## INFO [22:08:47.670] [mlr3] Finished benchmark
## INFO [22:08:47.692] [bbotk] Result of batch 16:
## INFO [22:08:47.694] [bbotk] cp minsplit classif.ce warnings errors runtime_learners
## INFO [22:08:47.694] [bbotk] 0.07525 3 0.28125 0 0 0.02
## INFO [22:08:47.694] [bbotk] uhash
## INFO [22:08:47.694] [bbotk] d5328765-b394-454c-adad-e8d35fa15197
## INFO [22:08:47.695] [bbotk] Evaluating 1 configuration(s)
## INFO [22:08:47.704] [mlr3] Running benchmark with 1 resampling iterations
## INFO [22:08:47.709] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 1/1)
## INFO [22:08:47.729] [mlr3] Finished benchmark
## INFO [22:08:47.753] [bbotk] Result of batch 17:
## INFO [22:08:47.755] [bbotk] cp minsplit classif.ce warnings errors runtime_learners
## INFO [22:08:47.755] [bbotk] 0.001 10 0.2695312 0 0 0.02
## INFO [22:08:47.755] [bbotk] uhash
## INFO [22:08:47.755] [bbotk] 48a21756-50c5-4963-b5a9-07ef39819964
## INFO [22:08:47.756] [bbotk] Evaluating 1 configuration(s)
## INFO [22:08:47.765] [mlr3] Running benchmark with 1 resampling iterations
## INFO [22:08:47.769] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 1/1)
## INFO [22:08:47.786] [mlr3] Finished benchmark
## INFO [22:08:47.816] [bbotk] Result of batch 18:
## INFO [22:08:47.818] [bbotk] cp minsplit classif.ce warnings errors runtime_learners
## INFO [22:08:47.818] [bbotk] 0.001 3 0.28125 0 0 0
## INFO [22:08:47.818] [bbotk] uhash
## INFO [22:08:47.818] [bbotk] e0de4d3c-cc87-4f23-9aee-7859ee57e319
## INFO [22:08:47.820] [bbotk] Evaluating 1 configuration(s)
## INFO [22:08:47.831] [mlr3] Running benchmark with 1 resampling iterations
## INFO [22:08:47.836] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 1/1)
## INFO [22:08:47.852] [mlr3] Finished benchmark
## INFO [22:08:47.890] [bbotk] Result of batch 19:
## INFO [22:08:47.892] [bbotk] cp minsplit classif.ce warnings errors runtime_learners
## INFO [22:08:47.892] [bbotk] 0.1 1 0.28125 0 0 0.01
## INFO [22:08:47.892] [bbotk] uhash
## INFO [22:08:47.892] [bbotk] e6a1322c-1140-42ad-902e-12b7d5930dfe
## INFO [22:08:47.893] [bbotk] Evaluating 1 configuration(s)
## INFO [22:08:47.902] [mlr3] Running benchmark with 1 resampling iterations
## INFO [22:08:47.906] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 1/1)
## INFO [22:08:47.923] [mlr3] Finished benchmark
## INFO [22:08:47.950] [bbotk] Result of batch 20:
## INFO [22:08:47.951] [bbotk] cp minsplit classif.ce warnings errors runtime_learners
## INFO [22:08:47.951] [bbotk] 0.001 5 0.265625 0 0 0.02
## INFO [22:08:47.951] [bbotk] uhash
## INFO [22:08:47.951] [bbotk] 1e88c5da-7ef2-47e3-8f25-05c420a6e896
## INFO [22:08:47.956] [bbotk] Finished optimizing after 20 evaluation(s)
## INFO [22:08:47.957] [bbotk] Result:
## INFO [22:08:47.958] [bbotk] cp minsplit learner_param_vals x_domain classif.ce
## INFO [22:08:47.958] [bbotk] 0.001 5 <list[3]> <list[2]> 0.265625
## cp minsplit learner_param_vals x_domain classif.ce
## 1: 0.001 5 <list[3]> <list[2]> 0.265625
# The best hyperparameter settings
instance$result_learner_param_vals
## $xval
## [1] 0
##
## $cp
## [1] 0.001
##
## $minsplit
## [1] 5
# Corresponding measured performance
instance$result_y
## classif.ce
## 0.265625
# You can investigate all of the evaluations that were performed
as.data.table(instance$archive)
## cp minsplit classif.ce x_domain_cp x_domain_minsplit runtime_learners
## 1: 0.02575 8 0.2812500 0.02575 8 0.02
## 2: 0.10000 8 0.2812500 0.10000 8 0.01
## 3: 0.05050 3 0.2812500 0.05050 3 0.00
## 4: 0.05050 1 0.2812500 0.05050 1 0.00
## 5: 0.02575 3 0.2812500 0.02575 3 0.00
## 6: 0.00100 1 0.2890625 0.00100 1 0.02
## 7: 0.05050 10 0.2812500 0.05050 10 0.02
## 8: 0.05050 8 0.2812500 0.05050 8 0.00
## 9: 0.05050 5 0.2812500 0.05050 5 0.01
## 10: 0.07525 8 0.2812500 0.07525 8 0.02
## 11: 0.10000 5 0.2812500 0.10000 5 0.00
## 12: 0.00100 8 0.2695312 0.00100 8 0.00
## 13: 0.07525 10 0.2812500 0.07525 10 0.03
## 14: 0.07525 5 0.2812500 0.07525 5 0.01
## 15: 0.10000 10 0.2812500 0.10000 10 0.00
## 16: 0.07525 3 0.2812500 0.07525 3 0.02
## 17: 0.00100 10 0.2695312 0.00100 10 0.02
## 18: 0.00100 3 0.2812500 0.00100 3 0.00
## 19: 0.10000 1 0.2812500 0.10000 1 0.01
## 20: 0.00100 5 0.2656250 0.00100 5 0.02
## timestamp batch_nr warnings errors resample_result
## 1: 2022-04-09 22:08:46 1 0 0 <ResampleResult[22]>
## 2: 2022-04-09 22:08:46 2 0 0 <ResampleResult[22]>
## 3: 2022-04-09 22:08:46 3 0 0 <ResampleResult[22]>
## 4: 2022-04-09 22:08:46 4 0 0 <ResampleResult[22]>
## 5: 2022-04-09 22:08:46 5 0 0 <ResampleResult[22]>
## 6: 2022-04-09 22:08:47 6 0 0 <ResampleResult[22]>
## 7: 2022-04-09 22:08:47 7 0 0 <ResampleResult[22]>
## 8: 2022-04-09 22:08:47 8 0 0 <ResampleResult[22]>
## 9: 2022-04-09 22:08:47 9 0 0 <ResampleResult[22]>
## 10: 2022-04-09 22:08:47 10 0 0 <ResampleResult[22]>
## 11: 2022-04-09 22:08:47 11 0 0 <ResampleResult[22]>
## 12: 2022-04-09 22:08:47 12 0 0 <ResampleResult[22]>
## 13: 2022-04-09 22:08:47 13 0 0 <ResampleResult[22]>
## 14: 2022-04-09 22:08:47 14 0 0 <ResampleResult[22]>
## 15: 2022-04-09 22:08:47 15 0 0 <ResampleResult[22]>
## 16: 2022-04-09 22:08:47 16 0 0 <ResampleResult[22]>
## 17: 2022-04-09 22:08:47 17 0 0 <ResampleResult[22]>
## 18: 2022-04-09 22:08:47 18 0 0 <ResampleResult[22]>
## 19: 2022-04-09 22:08:47 19 0 0 <ResampleResult[22]>
## 20: 2022-04-09 22:08:47 20 0 0 <ResampleResult[22]>
Now we can take the optimized hyperparameters, set them for the previously-created Learner, and train it on the full dataset.
learner$param_set$values = instance$result_learner_param_vals
learner$train(task)
Tuning with Multiple Performance Measures
When tuning, you might want to use multiple criteria to find the best configuration of hyperparameters. The tuning process is identical to the previous example, with the exception that this time we specify two performance measures, classification error and the time to train the model (time_train), and create a TuningInstanceMultiCrit instead of a TuningInstanceSingleCrit.
measures = msrs(c("classif.ce", "time_train"))

evals20 = trm("evals", n_evals = 20)

instance = TuningInstanceMultiCrit$new(
  task = task,
  learner = learner,
  resampling = hout,
  measures = measures,
  search_space = search_space,
  terminator = evals20
)
instance
## <TuningInstanceMultiCrit>
## * State: Not optimized
## * Objective: <ObjectiveTuning:classif.rpart_on_pima>
## * Search Space:
## id class lower upper nlevels
## 1: cp ParamDbl 0.001 0.1 Inf
## 2: minsplit ParamInt 1.000 10.0 10
## * Terminator: <TerminatorEvals>
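Triggering the tuning works exactly as before; with multiple measures the instance stores the Pareto-optimal (non-dominated) configurations instead of a single best one. A sketch, log output omitted:
tuner = tnr("grid_search", resolution = 5)
tuner$optimize(instance)

# all non-dominated hyperparameter configurations and their performances
instance$result_learner_param_vals
instance$result_y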
Automating the Tuning
This whole process can be automated in mlr3 so that learners are tuned transparently, without the need to extract information on the best hyperparameter settings at the end. The AutoTuner wraps a learner and augments it with an automatic tuning process for a given set of hyperparameters.
learner = lrn("classif.rpart")
search_space = ps(
  cp = p_dbl(lower = 0.001, upper = 0.1),
  minsplit = p_int(lower = 1, upper = 10)
)
terminator = trm("evals", n_evals = 10)
tuner = tnr("random_search")

at = AutoTuner$new(
  learner = learner,
  resampling = rsmp("holdout"),
  measure = msr("classif.ce"),
  search_space = search_space,
  terminator = terminator,
  tuner = tuner
)
at
## <AutoTuner:classif.rpart.tuned>
## * Model: -
## * Search Space:
## <ParamSet>
## id class lower upper nlevels default value
## 1: cp ParamDbl 0.001 0.1 Inf <NoDefault[3]>
## 2: minsplit ParamInt 1.000 10.0 10 <NoDefault[3]>
## * Packages: mlr3, mlr3tuning, rpart
## * Predict Type: response
## * Feature Types: logical, integer, numeric, factor, ordered
## * Properties: importance, missings, multiclass, selected_features,
## twoclass, weights
We can now use the learner like any other learner, calling the $train() and $predict() methods.
at$train(task)
## INFO [22:08:48.175] [bbotk] Starting to optimize 2 parameter(s) with '<OptimizerRandomSearch>' and '<TerminatorEvals> [n_evals=10, k=0]'
## INFO [22:08:48.189] [bbotk] Evaluating 1 configuration(s)
## INFO [22:08:48.199] [mlr3] Running benchmark with 1 resampling iterations
## INFO [22:08:48.204] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 1/1)
## INFO [22:08:48.221] [mlr3] Finished benchmark
## INFO [22:08:48.246] [bbotk] Result of batch 1:
## INFO [22:08:48.248] [bbotk] cp minsplit classif.ce warnings errors runtime_learners
## INFO [22:08:48.248] [bbotk] 0.05141753 10 0.2734375 0 0 0
## INFO [22:08:48.248] [bbotk] uhash
## INFO [22:08:48.248] [bbotk] 0613da77-0978-4cf9-b3cd-506a65659632
## INFO [22:08:48.251] [bbotk] Evaluating 1 configuration(s)
## INFO [22:08:48.261] [mlr3] Running benchmark with 1 resampling iterations
## INFO [22:08:48.266] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 1/1)
## INFO [22:08:48.282] [mlr3] Finished benchmark
## INFO [22:08:48.308] [bbotk] Result of batch 2:
## INFO [22:08:48.310] [bbotk] cp minsplit classif.ce warnings errors runtime_learners
## INFO [22:08:48.310] [bbotk] 0.032403 9 0.2734375 0 0 0
## INFO [22:08:48.310] [bbotk] uhash
## INFO [22:08:48.310] [bbotk] 3421329e-b977-4580-a440-d3b137b8ced0
## INFO [22:08:48.313] [bbotk] Evaluating 1 configuration(s)
## INFO [22:08:48.325] [mlr3] Running benchmark with 1 resampling iterations
## INFO [22:08:48.329] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 1/1)
## INFO [22:08:48.348] [mlr3] Finished benchmark
## INFO [22:08:48.373] [bbotk] Result of batch 3:
## INFO [22:08:48.374] [bbotk] cp minsplit classif.ce warnings errors runtime_learners
## INFO [22:08:48.374] [bbotk] 0.03644441 6 0.2734375 0 0 0
## INFO [22:08:48.374] [bbotk] uhash
## INFO [22:08:48.374] [bbotk] 1ee461ea-f475-4883-8498-1fd50234e642
## INFO [22:08:48.377] [bbotk] Evaluating 1 configuration(s)
## INFO [22:08:48.395] [mlr3] Running benchmark with 1 resampling iterations
## INFO [22:08:48.402] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 1/1)
## INFO [22:08:48.421] [mlr3] Finished benchmark
## INFO [22:08:48.444] [bbotk] Result of batch 4:
## INFO [22:08:48.445] [bbotk] cp minsplit classif.ce warnings errors runtime_learners
## INFO [22:08:48.445] [bbotk] 0.08363938 10 0.2734375 0 0 0.02
## INFO [22:08:48.445] [bbotk] uhash
## INFO [22:08:48.445] [bbotk] ec9bc184-d8c9-4d0e-b933-08a7c1deb3c3
## INFO [22:08:48.448] [bbotk] Evaluating 1 configuration(s)
## INFO [22:08:48.458] [mlr3] Running benchmark with 1 resampling iterations
## INFO [22:08:48.462] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 1/1)
## INFO [22:08:48.480] [mlr3] Finished benchmark
## INFO [22:08:48.504] [bbotk] Result of batch 5:
## INFO [22:08:48.505] [bbotk] cp minsplit classif.ce warnings errors runtime_learners
## INFO [22:08:48.505] [bbotk] 0.003366969 4 0.3046875 0 0 0.02
## INFO [22:08:48.505] [bbotk] uhash
## INFO [22:08:48.505] [bbotk] ce726596-11d0-4460-8f36-d0f278fe7a08
## INFO [22:08:48.507] [bbotk] Evaluating 1 configuration(s)
## INFO [22:08:48.517] [mlr3] Running benchmark with 1 resampling iterations
## INFO [22:08:48.521] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 1/1)
## INFO [22:08:48.536] [mlr3] Finished benchmark
## INFO [22:08:48.560] [bbotk] Result of batch 6:
## INFO [22:08:48.562] [bbotk] cp minsplit classif.ce warnings errors runtime_learners
## INFO [22:08:48.562] [bbotk] 0.09247955 3 0.3046875 0 0 0
## INFO [22:08:48.562] [bbotk] uhash
## INFO [22:08:48.562] [bbotk] f442c1e1-7dd3-428f-8b25-74cde09deaeb
## INFO [22:08:48.564] [bbotk] Evaluating 1 configuration(s)
## INFO [22:08:48.574] [mlr3] Running benchmark with 1 resampling iterations
## INFO [22:08:48.581] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 1/1)
## INFO [22:08:48.597] [mlr3] Finished benchmark
## INFO [22:08:48.624] [bbotk] Result of batch 7:
## INFO [22:08:48.626] [bbotk] cp minsplit classif.ce warnings errors runtime_learners
## INFO [22:08:48.626] [bbotk] 0.04844731 1 0.2734375 0 0 0
## INFO [22:08:48.626] [bbotk] uhash
## INFO [22:08:48.626] [bbotk] f5b104c0-96a5-448c-8622-a669cbbe8d2c
## INFO [22:08:48.628] [bbotk] Evaluating 1 configuration(s)
## INFO [22:08:48.639] [mlr3] Running benchmark with 1 resampling iterations
## INFO [22:08:48.643] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 1/1)
## INFO [22:08:48.662] [mlr3] Finished benchmark
## INFO [22:08:48.686] [bbotk] Result of batch 8:
## INFO [22:08:48.687] [bbotk] cp minsplit classif.ce warnings errors runtime_learners
## INFO [22:08:48.687] [bbotk] 0.001663224 10 0.2578125 0 0 0
## INFO [22:08:48.687] [bbotk] uhash
## INFO [22:08:48.687] [bbotk] eefc555d-f204-4592-b0fb-c02db2641ba2
## INFO [22:08:48.690] [bbotk] Evaluating 1 configuration(s)
## INFO [22:08:48.703] [mlr3] Running benchmark with 1 resampling iterations
## INFO [22:08:48.711] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 1/1)
## INFO [22:08:48.732] [mlr3] Finished benchmark
## INFO [22:08:48.760] [bbotk] Result of batch 9:
## INFO [22:08:48.762] [bbotk] cp minsplit classif.ce warnings errors runtime_learners
## INFO [22:08:48.762] [bbotk] 0.005261892 5 0.2695312 0 0 0.02
## INFO [22:08:48.762] [bbotk] uhash
## INFO [22:08:48.762] [bbotk] fa9d2d9c-efa7-400d-93f3-4f16e1759dc7
## INFO [22:08:48.764] [bbotk] Evaluating 1 configuration(s)
## INFO [22:08:48.773] [mlr3] Running benchmark with 1 resampling iterations
## INFO [22:08:48.778] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 1/1)
## INFO [22:08:48.797] [mlr3] Finished benchmark
## INFO [22:08:48.821] [bbotk] Result of batch 10:
## INFO [22:08:48.823] [bbotk] cp minsplit classif.ce warnings errors runtime_learners
## INFO [22:08:48.823] [bbotk] 0.0542137 10 0.2734375 0 0 0.01
## INFO [22:08:48.823] [bbotk] uhash
## INFO [22:08:48.823] [bbotk] 4bc71a24-dbeb-4428-86e9-279aa3cdf345
## INFO [22:08:48.828] [bbotk] Finished optimizing after 10 evaluation(s)
## INFO [22:08:48.829] [bbotk] Result:
## INFO [22:08:48.830] [bbotk] cp minsplit learner_param_vals x_domain classif.ce
## INFO [22:08:48.830] [bbotk] 0.001663224 10 <list[3]> <list[2]> 0.2578125
Tuning Search Spaces
When running an optimization, it is important to inform the tuning algorithm which hyperparameters are valid and within which bounds they may be searched.
Creating ParamSets
Note that ParamSet objects exist in two contexts. First, ParamSet objects are used to define the space of valid parameter settings for a learner (and other objects). Second, they are used to define a search space for tuning.
ps() takes named Domain arguments that are turned into parameters. A possible search space for the "classif.svm" learner could, for example, be:
search_space = ps(
  cost = p_dbl(lower = 0.1, upper = 10),
  kernel = p_fct(levels = c("polynomial", "radial"))
)
print(search_space)
## <ParamSet>
## id class lower upper nlevels default value
## 1: cost ParamDbl 0.1 10 Inf <NoDefault[3]>
## 2: kernel ParamFct NA NA 2 <NoDefault[3]>
There are five domain constructors that produce parameters when given to ps():
- p_dbl: Real valued parameter ("double")
- p_int: Integer parameter
- p_fct: Discrete valued parameter ("factor")
- p_lgl: Logical / Boolean parameter
- p_uty: Untyped parameter
These domain constructors each take some of the following arguments:
- lower, upper: lower and upper bound of numerical parameters.
- levels: Allowed categorical values for p_fct parameters.
- trafo: transformation function.
- depends: dependencies.
- tags: Further information about a parameter.
- default: Value corresponding to default behavior when the parameter is not given.
- special_vals: Valid values besides the normally accepted values for a parameter.
- custom_check: Function that checks whether a value given to p_uty is valid.
search_space = ps(cost = p_dbl(0.1, 10), kernel = p_fct(c("polynomial", "radial")))
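To illustrate how several of these arguments combine, here is a small sketch of a hypothetical search space; the parameter names and values are made up for demonstration:
# illustrative only: bounds, levels, defaults and a dependency in one ParamSet
demo_space = ps(
  cost      = p_dbl(lower = 0.1, upper = 10, default = 1),
  kernel    = p_fct(levels = c("polynomial", "radial")),
  degree    = p_int(lower = 1, upper = 5, depends = kernel == "polynomial"),
  shrinking = p_lgl(default = TRUE)
)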
Transformations (trafo)
library("data.table")
rbindlist(generate_design_grid(search_space, 3)$transpose())
## cost kernel
## 1: 0.10 polynomial
## 2: 0.10 radial
## 3: 5.05 polynomial
## 4: 5.05 radial
## 5: 10.00 polynomial
## 6: 10.00 radial
We see that the cost parameter is taken on a linear scale. We assume that the difference in cost between 0.1 and 1 should have a similar effect as the difference between 1 and 10, so it makes more sense to tune it on a logarithmic scale. This can be done by using a transformation (trafo).
search_space = ps(
  cost = p_dbl(-1, 1, trafo = function(x) 10^x),
  kernel = p_fct(c("polynomial", "radial"))
)
rbindlist(generate_design_grid(search_space, 3)$transpose())
## cost kernel
## 1: 0.1 polynomial
## 2: 0.1 radial
## 3: 1.0 polynomial
## 4: 1.0 radial
## 5: 10.0 polynomial
## 6: 10.0 radial
It is even possible to attach another transformation to the ParamSet as a whole that gets executed after the individual parameters' transformations were performed (.extra_trafo).
search_space = ps(
  cost = p_dbl(-1, 1, trafo = function(x) 10^x),
  kernel = p_fct(c("polynomial", "radial")),
  .extra_trafo = function(x, param_set) {
    if (x$kernel == "polynomial") {
      x$cost = x$cost * 2
    }
    x
  }
)
rbindlist(generate_design_grid(search_space, 3)$transpose())
## cost kernel
## 1: 0.2 polynomial
## 2: 0.1 radial
## 3: 2.0 polynomial
## 4: 1.0 radial
## 5: 20.0 polynomial
## 6: 10.0 radial
Automatic Factor Level Transformation
A common use case is the necessity to specify a list of values that should all be tried (or sampled from). For this, the p_fct constructor's levels argument may be something other than a character vector. If, for example, only the values 0.1, 3, and 10 should be tried for the cost parameter (even when doing random search), this can be expressed as follows:
search_space = ps(
  cost = p_fct(c(0.1, 3, 10)),
  kernel = p_fct(c("polynomial", "radial"))
)
rbindlist(generate_design_grid(search_space, 3)$transpose())
## cost kernel
## 1: 0.1 polynomial
## 2: 0.1 radial
## 3: 3.0 polynomial
## 4: 3.0 radial
## 5: 10.0 polynomial
## 6: 10.0 radial
Parameter Dependencies (depends)
Some parameters are only relevant when another parameter has a certain value, or one of several values. The Support Vector Machine (SVM), for example, has a degree parameter that is only valid when kernel is "polynomial". This can be specified as follows:
search_space = ps(
  cost = p_dbl(-1, 1, trafo = function(x) 10^x),
  kernel = p_fct(c("polynomial", "radial")),
  degree = p_int(1, 3, depends = kernel == "polynomial")
)
rbindlist(generate_design_grid(search_space, 3)$transpose(), fill = TRUE)
## cost kernel degree
## 1: 0.1 polynomial 1
## 2: 0.1 polynomial 2
## 3: 0.1 polynomial 3
## 4: 0.1 radial NA
## 5: 1.0 polynomial 1
## 6: 1.0 polynomial 2
## 7: 1.0 polynomial 3
## 8: 1.0 radial NA
## 9: 10.0 polynomial 1
## 10: 10.0 polynomial 2
## 11: 10.0 polynomial 3
## 12: 10.0 radial NA
Creating Tuning ParamSets from other ParamSets
There is a way to define a tuning ParamSet for a Learner that already has parameter set information. This is done by setting values of the Learner's ParamSet to so-called TuneTokens, in the same way that other hyperparameters are set to specific values. It can be understood as tagging the hyperparameters for later tuning.
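As a minimal sketch of this mechanism (using the to_tune() helper from paradox):
learner = lrn("classif.rpart")
# assign TuneTokens: tag cp and minsplit for later tuning
learner$param_set$values$cp = to_tune(0.001, 0.1)
learner$param_set$values$minsplit = to_tune(1, 10)
# the tuning search space can then be derived from the learner itself
learner$param_set$search_space()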
Nested Resampling
It is crucial to have an additional layer of resampling when we perform hyperparameter tuning and/or feature selection for our model. Using the same data for both model selection and evaluation might severely bias the performance estimate, for instance because the test data leaks information about its structure into the model.
Nested resampling is a statistical procedure that separates the model selection steps from the process of estimating the performance of the model.
Fig. 3: Nested resampling scheme with 3-fold cross-validation in the outer resampling and 4-fold cross-validation in the inner resampling
Nested resampling process from above:
- Use a 3-fold cross-validation to get different testing and training data sets (outer resampling).
- Within the training data use a 4-fold cross-validation to get different inner testing and training data sets (inner resampling).
- Tune the hyperparameters using the inner data splits.
- Fit the learner on the outer training data set using the tuned hyperparameter configuration obtained with the inner resampling.
- Evaluate the performance of the learner on the outer testing data.
- Steps 2-5 are repeated for each of the three folds (outer resampling).
- The three performance values are aggregated for an unbiased performance estimate.
Execution
To execute the above algorithm, we will use the AutoTuner to wrap a learner and augment it with an automatic tuning process for a given set of hyperparameters.
= lrn("classif.rpart")
learner
# 4-fold cross-validation in the inner resampling loop
= rsmp("cv", folds = 4)
resampling
= msr("classif.ce")
measure
# hyperparameter configurations are proposed by grid search
= ps(cp = p_dbl(lower = 0.001, upper = 0.1))
search_space
# terminator triggered after 5 evaluations
= trm("evals", n_evals = 5)
terminator
= tnr("grid_search", resolution = 10)
tuner
= AutoTuner$new(learner, resampling, measure, terminator, tuner, search_space) at
Hyperparameter tuning is performed on each of the three outer training sets, so we receive three optimized hyperparameter configurations. To execute the nested resampling, we pass the AutoTuner to the resample() function.
= tsk("pima")
task
#3-fold cross-validation in the outer resampling loop
= rsmp("cv", folds = 3)
outer_resampling
# store_models = TRUE because we need the AutoTuner models to investigate the inner tuning
= resample(task, at, outer_resampling, store_models = TRUE) rr
## INFO [22:08:49.108] [mlr3] Applying learner 'classif.rpart.tuned' on task 'pima' (iter 1/3)
## INFO [22:08:49.151] [bbotk] Starting to optimize 1 parameter(s) with '<TunerGridSearch>' and '<TerminatorEvals> [n_evals=5, k=0]'
## INFO [22:08:49.154] [bbotk] Evaluating 1 configuration(s)
## INFO [22:08:49.165] [mlr3] Running benchmark with 4 resampling iterations
## INFO [22:08:49.170] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 4/4)
## INFO [22:08:49.188] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 1/4)
## INFO [22:08:49.205] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 3/4)
## INFO [22:08:49.223] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 2/4)
## INFO [22:08:49.242] [mlr3] Finished benchmark
## INFO [22:08:49.267] [bbotk] Result of batch 1:
## INFO [22:08:49.268] [bbotk] cp classif.ce warnings errors runtime_learners
## INFO [22:08:49.268] [bbotk] 0.089 0.2617188 0 0 0.02
## INFO [22:08:49.268] [bbotk] uhash
## INFO [22:08:49.268] [bbotk] 16649172-e4dc-4708-b395-67b0c37ebf27
## INFO [22:08:49.269] [bbotk] Evaluating 1 configuration(s)
## INFO [22:08:49.278] [mlr3] Running benchmark with 4 resampling iterations
## INFO [22:08:49.283] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 3/4)
## INFO [22:08:49.300] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 1/4)
## INFO [22:08:49.318] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 4/4)
## INFO [22:08:49.335] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 2/4)
## INFO [22:08:49.352] [mlr3] Finished benchmark
## INFO [22:08:49.390] [bbotk] Result of batch 2:
## INFO [22:08:49.392] [bbotk] cp classif.ce warnings errors runtime_learners
## INFO [22:08:49.392] [bbotk] 0.034 0.2539062 0 0 0.03
## INFO [22:08:49.392] [bbotk] uhash
## INFO [22:08:49.392] [bbotk] e38091de-c504-4961-98c4-f81f587b0389
## INFO [22:08:49.393] [bbotk] Evaluating 1 configuration(s)
## INFO [22:08:49.402] [mlr3] Running benchmark with 4 resampling iterations
## INFO [22:08:49.406] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 4/4)
## INFO [22:08:49.423] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 1/4)
## INFO [22:08:49.440] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 3/4)
## INFO [22:08:49.458] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 2/4)
## INFO [22:08:49.478] [mlr3] Finished benchmark
## INFO [22:08:49.505] [bbotk] Result of batch 3:
## INFO [22:08:49.507] [bbotk] cp classif.ce warnings errors runtime_learners
## INFO [22:08:49.507] [bbotk] 0.045 0.2636719 0 0 0.02
## INFO [22:08:49.507] [bbotk] uhash
## INFO [22:08:49.507] [bbotk] c22c5a56-02a8-423d-bdfb-e20e9268cddf
## INFO [22:08:49.508] [bbotk] Evaluating 1 configuration(s)
## INFO [22:08:49.517] [mlr3] Running benchmark with 4 resampling iterations
## INFO [22:08:49.522] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 1/4)
## INFO [22:08:49.539] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 4/4)
## INFO [22:08:49.555] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 2/4)
## INFO [22:08:49.572] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 3/4)
## INFO [22:08:49.590] [mlr3] Finished benchmark
## INFO [22:08:49.621] [bbotk] Result of batch 4:
## INFO [22:08:49.622] [bbotk] cp classif.ce warnings errors runtime_learners
## INFO [22:08:49.622] [bbotk] 0.056 0.2617188 0 0 0.03
## INFO [22:08:49.622] [bbotk] uhash
## INFO [22:08:49.622] [bbotk] f0ecc73f-1a3e-4610-8649-727539097ad2
## INFO [22:08:49.624] [bbotk] Evaluating 1 configuration(s)
## INFO [22:08:49.633] [mlr3] Running benchmark with 4 resampling iterations
## INFO [22:08:49.638] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 4/4)
## INFO [22:08:49.656] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 1/4)
## INFO [22:08:49.676] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 2/4)
## INFO [22:08:49.696] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 3/4)
## INFO [22:08:49.714] [mlr3] Finished benchmark
## INFO [22:08:49.741] [bbotk] Result of batch 5:
## INFO [22:08:49.742] [bbotk] cp classif.ce warnings errors runtime_learners
## INFO [22:08:49.742] [bbotk] 0.001 0.2617188 0 0 0.02
## INFO [22:08:49.742] [bbotk] uhash
## INFO [22:08:49.742] [bbotk] c256749c-c916-4519-87e4-0eb98c8d7ad0
## INFO [22:08:49.745] [bbotk] Finished optimizing after 5 evaluation(s)
## INFO [22:08:49.746] [bbotk] Result:
## INFO [22:08:49.747] [bbotk] cp learner_param_vals x_domain classif.ce
## INFO [22:08:49.747] [bbotk] 0.034 <list[2]> <list[1]> 0.2539062
## INFO [22:08:49.780] [mlr3] Applying learner 'classif.rpart.tuned' on task 'pima' (iter 3/3)
## INFO [22:08:49.811] [bbotk] Starting to optimize 1 parameter(s) with '<TunerGridSearch>' and '<TerminatorEvals> [n_evals=5, k=0]'
## INFO [22:08:49.813] [bbotk] Evaluating 1 configuration(s)
## INFO [22:08:49.822] [mlr3] Running benchmark with 4 resampling iterations
## INFO [22:08:49.826] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 1/4)
## INFO [22:08:49.846] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 3/4)
## INFO [22:08:49.868] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 2/4)
## INFO [22:08:49.885] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 4/4)
## INFO [22:08:49.906] [mlr3] Finished benchmark
## INFO [22:08:49.930] [bbotk] Result of batch 1:
## INFO [22:08:49.931] [bbotk] cp classif.ce warnings errors runtime_learners
## INFO [22:08:49.931] [bbotk] 0.034 0.2265625 0 0 0.01
## INFO [22:08:49.931] [bbotk] uhash
## INFO [22:08:49.931] [bbotk] 6c4cd602-d4f1-4cf8-a680-5ef7596d818a
## INFO [22:08:49.932] [bbotk] Evaluating 1 configuration(s)
## INFO [22:08:49.941] [mlr3] Running benchmark with 4 resampling iterations
## INFO [22:08:49.945] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 3/4)
## INFO [22:08:49.962] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 2/4)
## INFO [22:08:49.980] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 4/4)
## INFO [22:08:49.997] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 1/4)
## INFO [22:08:50.019] [mlr3] Finished benchmark
## INFO [22:08:50.049] [bbotk] Result of batch 2:
## INFO [22:08:50.050] [bbotk] cp classif.ce warnings errors runtime_learners
## INFO [22:08:50.050] [bbotk] 0.001 0.2304688 0 0 0.06
## INFO [22:08:50.050] [bbotk] uhash
## INFO [22:08:50.050] [bbotk] 563187f9-d7ab-4a47-b52a-b80a3dfd53aa
## INFO [22:08:50.051] [bbotk] Evaluating 1 configuration(s)
## INFO [22:08:50.059] [mlr3] Running benchmark with 4 resampling iterations
## INFO [22:08:50.064] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 4/4)
## INFO [22:08:50.081] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 2/4)
## INFO [22:08:50.098] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 3/4)
## INFO [22:08:50.115] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 1/4)
## INFO [22:08:50.134] [mlr3] Finished benchmark
## INFO [22:08:50.162] [bbotk] Result of batch 3:
## INFO [22:08:50.163] [bbotk] cp classif.ce warnings errors runtime_learners
## INFO [22:08:50.163] [bbotk] 0.012 0.2226562 0 0 0.03
## INFO [22:08:50.163] [bbotk] uhash
## INFO [22:08:50.163] [bbotk] b78f46bf-33c4-450c-8339-8186de327e50
## INFO [22:08:50.164] [bbotk] Evaluating 1 configuration(s)
## INFO [22:08:50.174] [mlr3] Running benchmark with 4 resampling iterations
## INFO [22:08:50.179] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 4/4)
## INFO [22:08:50.202] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 3/4)
## INFO [22:08:50.220] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 2/4)
## INFO [22:08:50.237] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 1/4)
## INFO [22:08:50.255] [mlr3] Finished benchmark
## INFO [22:08:50.283] [bbotk] Result of batch 4:
## INFO [22:08:50.285] [bbotk] cp classif.ce warnings errors runtime_learners
## INFO [22:08:50.285] [bbotk] 0.067 0.2324219 0 0 0.03
## INFO [22:08:50.285] [bbotk] uhash
## INFO [22:08:50.285] [bbotk] 226ac13a-be0e-44ca-9896-6eee9fabc3c8
## INFO [22:08:50.286] [bbotk] Evaluating 1 configuration(s)
## INFO [22:08:50.297] [mlr3] Running benchmark with 4 resampling iterations
## INFO [22:08:50.305] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 4/4)
## INFO [22:08:50.328] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 3/4)
## INFO [22:08:50.345] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 2/4)
## INFO [22:08:50.361] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 1/4)
## INFO [22:08:50.378] [mlr3] Finished benchmark
## INFO [22:08:50.405] [bbotk] Result of batch 5:
## INFO [22:08:50.407] [bbotk] cp classif.ce warnings errors runtime_learners
## INFO [22:08:50.407] [bbotk] 0.078 0.2324219 0 0 0.03
## INFO [22:08:50.407] [bbotk] uhash
## INFO [22:08:50.407] [bbotk] e216461b-1fb6-4644-a6b9-9291007fca55
## INFO [22:08:50.411] [bbotk] Finished optimizing after 5 evaluation(s)
## INFO [22:08:50.411] [bbotk] Result:
## INFO [22:08:50.412] [bbotk] cp learner_param_vals x_domain classif.ce
## INFO [22:08:50.412] [bbotk] 0.012 <list[2]> <list[1]> 0.2226562
## INFO [22:08:50.444] [mlr3] Applying learner 'classif.rpart.tuned' on task 'pima' (iter 2/3)
## INFO [22:08:50.475] [bbotk] Starting to optimize 1 parameter(s) with '<TunerGridSearch>' and '<TerminatorEvals> [n_evals=5, k=0]'
## INFO [22:08:50.477] [bbotk] Evaluating 1 configuration(s)
## INFO [22:08:50.486] [mlr3] Running benchmark with 4 resampling iterations
## INFO [22:08:50.490] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 2/4)
## INFO [22:08:50.509] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 4/4)
## INFO [22:08:50.533] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 1/4)
## INFO [22:08:50.550] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 3/4)
## INFO [22:08:50.568] [mlr3] Finished benchmark
## INFO [22:08:50.593] [bbotk] Result of batch 1:
## INFO [22:08:50.594] [bbotk] cp classif.ce warnings errors runtime_learners
## INFO [22:08:50.594] [bbotk] 0.045 0.2675781 0 0 0.02
## INFO [22:08:50.594] [bbotk] uhash
## INFO [22:08:50.594] [bbotk] ecd9dd73-6daa-4fda-9e40-e2632ea323f0
## INFO [22:08:50.595] [bbotk] Evaluating 1 configuration(s)
## INFO [22:08:50.604] [mlr3] Running benchmark with 4 resampling iterations
## INFO [22:08:50.609] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 4/4)
## INFO [22:08:50.634] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 2/4)
## INFO [22:08:50.654] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 1/4)
## INFO [22:08:50.673] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 3/4)
## INFO [22:08:50.691] [mlr3] Finished benchmark
## INFO [22:08:50.717] [bbotk] Result of batch 2:
## INFO [22:08:50.718] [bbotk] cp classif.ce warnings errors runtime_learners
## INFO [22:08:50.718] [bbotk] 0.056 0.2792969 0 0 0.02
## INFO [22:08:50.718] [bbotk] uhash
## INFO [22:08:50.718] [bbotk] c2866cc7-6373-4de8-844e-4f760917e25e
## INFO [22:08:50.719] [bbotk] Evaluating 1 configuration(s)
## INFO [22:08:50.728] [mlr3] Running benchmark with 4 resampling iterations
## INFO [22:08:50.734] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 1/4)
## INFO [22:08:50.756] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 4/4)
## INFO [22:08:50.775] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 2/4)
## INFO [22:08:50.793] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 3/4)
## INFO [22:08:50.812] [mlr3] Finished benchmark
## INFO [22:08:50.839] [bbotk] Result of batch 3:
## INFO [22:08:50.840] [bbotk] cp classif.ce warnings errors runtime_learners
## INFO [22:08:50.840] [bbotk] 0.1 0.2578125 0 0 0.03
## INFO [22:08:50.840] [bbotk] uhash
## INFO [22:08:50.840] [bbotk] e82a9fcd-ca94-4e5a-b8a7-a83b1f9ae207
## INFO [22:08:50.841] [bbotk] Evaluating 1 configuration(s)
## INFO [22:08:50.850] [mlr3] Running benchmark with 4 resampling iterations
## INFO [22:08:50.855] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 2/4)
## INFO [22:08:50.873] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 4/4)
## INFO [22:08:50.895] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 1/4)
## INFO [22:08:50.913] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 3/4)
## INFO [22:08:50.932] [mlr3] Finished benchmark
## INFO [22:08:50.964] [bbotk] Result of batch 4:
## INFO [22:08:50.965] [bbotk] cp classif.ce warnings errors runtime_learners
## INFO [22:08:50.965] [bbotk] 0.012 0.2539062 0 0 0.04
## INFO [22:08:50.965] [bbotk] uhash
## INFO [22:08:50.965] [bbotk] 3a85d117-16b2-4f03-9c1f-26a2adece9f0
## INFO [22:08:50.966] [bbotk] Evaluating 1 configuration(s)
## INFO [22:08:50.975] [mlr3] Running benchmark with 4 resampling iterations
## INFO [22:08:50.980] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 3/4)
## INFO [22:08:51.004] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 2/4)
## INFO [22:08:51.021] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 4/4)
## INFO [22:08:51.039] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 1/4)
## INFO [22:08:51.059] [mlr3] Finished benchmark
## INFO [22:08:51.083] [bbotk] Result of batch 5:
## INFO [22:08:51.084] [bbotk] cp classif.ce warnings errors runtime_learners
## INFO [22:08:51.084] [bbotk] 0.089 0.2578125 0 0 0.04
## INFO [22:08:51.084] [bbotk] uhash
## INFO [22:08:51.084] [bbotk] cb4403b2-774b-4101-a896-c56ac43fc7ec
## INFO [22:08:51.088] [bbotk] Finished optimizing after 5 evaluation(s)
## INFO [22:08:51.088] [bbotk] Result:
## INFO [22:08:51.089] [bbotk] cp learner_param_vals x_domain classif.ce
## INFO [22:08:51.089] [bbotk] 0.012 <list[2]> <list[1]> 0.2539062
Nested resampling is not restricted to hyperparameter tuning. One can swap the AutoTuner for an AutoFSelector and estimate the performance of a model which is fitted on an optimized feature subset.
Evaluation
With the created ResampleResult we can now inspect the executed resampling iterations.
We check the inner tuning results for stable hyperparameters. Unstable models might be observed if the data set is small and/or the number of resampling iterations is low.
extract_inner_tuning_results(rr)
## iteration cp classif.ce learner_param_vals x_domain task_id
## 1: 1 0.034 0.2539062 <list[2]> <list[1]> pima
## 2: 2 0.012 0.2539062 <list[2]> <list[1]> pima
## 3: 3 0.012 0.2226562 <list[2]> <list[1]> pima
## learner_id resampling_id
## 1: classif.rpart.tuned cv
## 2: classif.rpart.tuned cv
## 3: classif.rpart.tuned cv
Next, compare the predictive performances estimated on the outer resampling to the inner resampling. Significantly lower predictive performances on the outer resampling indicate that the models with the optimized hyperparameters overfit the data.
rr$score()
## task task_id learner learner_id
## 1: <TaskClassif[50]> pima <AutoTuner[42]> classif.rpart.tuned
## 2: <TaskClassif[50]> pima <AutoTuner[42]> classif.rpart.tuned
## 3: <TaskClassif[50]> pima <AutoTuner[42]> classif.rpart.tuned
## resampling resampling_id iteration prediction
## 1: <ResamplingCV[20]> cv 1 <PredictionClassif[20]>
## 2: <ResamplingCV[20]> cv 2 <PredictionClassif[20]>
## 3: <ResamplingCV[20]> cv 3 <PredictionClassif[20]>
## classif.ce
## 1: 0.2773438
## 2: 0.2304688
## 3: 0.2304688
The aggregated performance of all outer resampling iterations is essentially the unbiased performance of the model with the optimal hyperparameters found by grid search.
rr$aggregate()
## classif.ce
## 0.2460938
Note: nested resampling is computationally expensive.
Final Model
Finally, the AutoTuner is fitted on the full data set: it tunes the hyperparameters internally and then trains the learner with the best configuration found.
at$train(task)
## INFO [22:08:51.274] [bbotk] Starting to optimize 1 parameter(s) with '<TunerGridSearch>' and '<TerminatorEvals> [n_evals=5, k=0]'
## INFO [22:08:51.276] [bbotk] Evaluating 1 configuration(s)
## INFO [22:08:51.286] [mlr3] Running benchmark with 4 resampling iterations
## INFO [22:08:51.294] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 1/4)
## INFO [22:08:51.313] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 3/4)
## INFO [22:08:51.335] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 4/4)
## INFO [22:08:51.356] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 2/4)
## INFO [22:08:51.377] [mlr3] Finished benchmark
## INFO [22:08:51.409] [bbotk] Result of batch 1:
## INFO [22:08:51.411] [bbotk] cp classif.ce warnings errors runtime_learners
## INFO [22:08:51.411] [bbotk] 0.045 0.2604167 0 0 0.04
## INFO [22:08:51.411] [bbotk] uhash
## INFO [22:08:51.411] [bbotk] ef4b8462-df37-432e-a023-ca9461420990
## INFO [22:08:51.412] [bbotk] Evaluating 1 configuration(s)
## INFO [22:08:51.424] [mlr3] Running benchmark with 4 resampling iterations
## INFO [22:08:51.429] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 1/4)
## INFO [22:08:51.447] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 2/4)
## INFO [22:08:51.464] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 3/4)
## INFO [22:08:51.482] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 4/4)
## INFO [22:08:51.501] [mlr3] Finished benchmark
## INFO [22:08:51.526] [bbotk] Result of batch 2:
## INFO [22:08:51.528] [bbotk] cp classif.ce warnings errors runtime_learners
## INFO [22:08:51.528] [bbotk] 0.034 0.2604167 0 0 0.04
## INFO [22:08:51.528] [bbotk] uhash
## INFO [22:08:51.528] [bbotk] ff3f22ec-75f8-47da-a1bf-2fd493178c45
## INFO [22:08:51.529] [bbotk] Evaluating 1 configuration(s)
## INFO [22:08:51.538] [mlr3] Running benchmark with 4 resampling iterations
## INFO [22:08:51.542] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 4/4)
## INFO [22:08:51.561] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 2/4)
## INFO [22:08:51.579] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 1/4)
## INFO [22:08:51.598] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 3/4)
## INFO [22:08:51.616] [mlr3] Finished benchmark
## INFO [22:08:51.647] [bbotk] Result of batch 3:
## INFO [22:08:51.649] [bbotk] cp classif.ce warnings errors runtime_learners
## INFO [22:08:51.649] [bbotk] 0.001 0.2786458 0 0 0.04
## INFO [22:08:51.649] [bbotk] uhash
## INFO [22:08:51.649] [bbotk] ba037a87-e20f-4e74-8949-ab7d6b67329c
## INFO [22:08:51.650] [bbotk] Evaluating 1 configuration(s)
## INFO [22:08:51.659] [mlr3] Running benchmark with 4 resampling iterations
## INFO [22:08:51.663] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 1/4)
## INFO [22:08:51.681] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 2/4)
## INFO [22:08:51.700] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 3/4)
## INFO [22:08:51.718] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 4/4)
## INFO [22:08:51.736] [mlr3] Finished benchmark
## INFO [22:08:51.769] [bbotk] Result of batch 4:
## INFO [22:08:51.770] [bbotk] cp classif.ce warnings errors runtime_learners
## INFO [22:08:51.770] [bbotk] 0.023 0.2617188 0 0 0.06
## INFO [22:08:51.770] [bbotk] uhash
## INFO [22:08:51.770] [bbotk] e72b4b73-bcb2-40df-8a38-f63a1f9bb849
## INFO [22:08:51.771] [bbotk] Evaluating 1 configuration(s)
## INFO [22:08:51.782] [mlr3] Running benchmark with 4 resampling iterations
## INFO [22:08:51.786] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 4/4)
## INFO [22:08:51.809] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 3/4)
## INFO [22:08:51.828] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 1/4)
## INFO [22:08:51.845] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 2/4)
## INFO [22:08:51.863] [mlr3] Finished benchmark
## INFO [22:08:51.889] [bbotk] Result of batch 5:
## INFO [22:08:51.891] [bbotk] cp classif.ce warnings errors runtime_learners
## INFO [22:08:51.891] [bbotk] 0.056 0.2604167 0 0 0.03
## INFO [22:08:51.891] [bbotk] uhash
## INFO [22:08:51.891] [bbotk] ed3d52de-c3ba-4743-9f0b-692b34f5d515
## INFO [22:08:51.894] [bbotk] Finished optimizing after 5 evaluation(s)
## INFO [22:08:51.895] [bbotk] Result:
## INFO [22:08:51.896] [bbotk] cp learner_param_vals x_domain classif.ce
## INFO [22:08:51.896] [bbotk] 0.045 <list[2]> <list[1]> 0.2604167
The trained model can now be used to make predictions on new data.
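For example (a sketch; here we simply reuse the training task for illustration, in practice one would predict on genuinely new observations):
prediction = at$predict(task)
prediction$score(msr("classif.ce"))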
Tuning with Hyperband
Besides the more traditional tuning methods, the ecosystem around mlr3 offers another procedure for hyperparameter optimization called Hyperband implemented in the mlr3hyperband package.
Fig. 4: Visualization of how the training processes may look. The left plot corresponds to the non-selective trainer, while the right one to the selective trainer.
Hyperband is a budget-oriented procedure, weeding out suboptimal performing configurations early on during a partially sequential training process, increasing tuning efficiency as a consequence. For this, a combination of incremental resource allocation and early stopping is used: As optimization progresses, computational resources are increased for more promising configurations, while less promising ones are terminated early.
Hyperband:
- Consists of several brackets.
- Each bracket is placed at a unique spot on the trade-off between full exploration (many configurations, each evaluated only with a small budget) and strong selectivity (few configurations, each trained with a larger budget).
Thanks to the broad ecosystem of the mlr3verse, the learner itself does not require a natural budget parameter: a preprocessing step such as subsampling can provide one.
set.seed(52)
# extend "classif.rpart" with "subsampling" as preprocessing step
= po("subsample") %>>% lrn("classif.rpart")
ll
# extend hyperparameters of "classif.rpart" with subsampling fraction as budget
= ps(
search_space classif.rpart.cp = p_dbl(lower = 0.001, upper = 0.1),
classif.rpart.minsplit = p_int(lower = 1, upper = 10),
subsample.frac = p_dbl(lower = 0.1, upper = 1, tags = "budget")
)
Now plug the new learner with the extended hyperparameter set into a TuningInstanceSingleCrit the same way as usual. Hyperband terminates once all of its brackets are evaluated, so the Terminator in the tuning instance acts as an upper bound and should only be set to a low value if one is unsure of how long Hyperband will take to finish.
instance = TuningInstanceSingleCrit$new(
  task = tsk("iris"),
  learner = ll,
  resampling = rsmp("holdout"),
  measure = msr("classif.ce"),
  terminator = trm("none"), # hyperband terminates itself
  search_space = search_space
)
Initialize a new instance of the mlr3hyperband::mlr_tuners_hyperband class and start tuning. At every stage the budget per configuration is increased by a factor of eta, and only the best 1/eta configurations are carried over to the next stage.
library("mlr3hyperband")
## Warning: package 'mlr3hyperband' was built under R version 4.1.3
## Loading required package: mlr3tuning
## Loading required package: paradox
## Warning: package 'paradox' was built under R version 4.1.3
= tnr("hyperband", eta = 3)
tuner
# reduce logging output
::get_logger("bbotk")$set_threshold("warn")
lgr
$optimize(instance) tuner
## INFO [22:08:52.323] [mlr3] Running benchmark with 9 resampling iterations
## INFO [22:08:52.328] [mlr3] Applying learner 'subsample.classif.rpart' on task 'iris' (iter 1/1)
## INFO [22:08:52.385] [mlr3] Applying learner 'subsample.classif.rpart' on task 'iris' (iter 1/1)
## INFO [22:08:52.443] [mlr3] Applying learner 'subsample.classif.rpart' on task 'iris' (iter 1/1)
## INFO [22:08:52.504] [mlr3] Applying learner 'subsample.classif.rpart' on task 'iris' (iter 1/1)
## INFO [22:08:52.561] [mlr3] Applying learner 'subsample.classif.rpart' on task 'iris' (iter 1/1)
## INFO [22:08:52.622] [mlr3] Applying learner 'subsample.classif.rpart' on task 'iris' (iter 1/1)
## INFO [22:08:52.676] [mlr3] Applying learner 'subsample.classif.rpart' on task 'iris' (iter 1/1)
## INFO [22:08:52.725] [mlr3] Applying learner 'subsample.classif.rpart' on task 'iris' (iter 1/1)
## INFO [22:08:52.784] [mlr3] Applying learner 'subsample.classif.rpart' on task 'iris' (iter 1/1)
## INFO [22:08:52.885] [mlr3] Finished benchmark
## INFO [22:08:53.098] [mlr3] Running benchmark with 8 resampling iterations
## INFO [22:08:53.102] [mlr3] Applying learner 'subsample.classif.rpart' on task 'iris' (iter 1/1)
## INFO [22:08:53.156] [mlr3] Applying learner 'subsample.classif.rpart' on task 'iris' (iter 1/1)
## INFO [22:08:53.227] [mlr3] Applying learner 'subsample.classif.rpart' on task 'iris' (iter 1/1)
## INFO [22:08:53.279] [mlr3] Applying learner 'subsample.classif.rpart' on task 'iris' (iter 1/1)
## INFO [22:08:53.330] [mlr3] Applying learner 'subsample.classif.rpart' on task 'iris' (iter 1/1)
## INFO [22:08:53.386] [mlr3] Applying learner 'subsample.classif.rpart' on task 'iris' (iter 1/1)
## INFO [22:08:53.438] [mlr3] Applying learner 'subsample.classif.rpart' on task 'iris' (iter 1/1)
## INFO [22:08:53.493] [mlr3] Applying learner 'subsample.classif.rpart' on task 'iris' (iter 1/1)
## INFO [22:08:53.555] [mlr3] Finished benchmark
## INFO [22:08:53.721] [mlr3] Running benchmark with 5 resampling iterations
## INFO [22:08:53.726] [mlr3] Applying learner 'subsample.classif.rpart' on task 'iris' (iter 1/1)
## INFO [22:08:53.781] [mlr3] Applying learner 'subsample.classif.rpart' on task 'iris' (iter 1/1)
## INFO [22:08:53.845] [mlr3] Applying learner 'subsample.classif.rpart' on task 'iris' (iter 1/1)
## INFO [22:08:53.896] [mlr3] Applying learner 'subsample.classif.rpart' on task 'iris' (iter 1/1)
## INFO [22:08:53.946] [mlr3] Applying learner 'subsample.classif.rpart' on task 'iris' (iter 1/1)
## INFO [22:08:54.004] [mlr3] Finished benchmark
## classif.rpart.cp classif.rpart.minsplit subsample.frac learner_param_vals
## 1: 0.08171096 3 0.3333333 <list[6]>
## x_domain classif.ce
## 1: <list[3]> 0
Access the best found configuration through the instance object:
instance$result
## classif.rpart.cp classif.rpart.minsplit subsample.frac learner_param_vals
## 1: 0.08171096 3 0.3333333 <list[6]>
## x_domain classif.ce
## 1: <list[3]> 0
instance$result_learner_param_vals
## $subsample.frac
## [1] 0.3333333
##
## $subsample.stratify
## [1] FALSE
##
## $subsample.replace
## [1] FALSE
##
## $classif.rpart.xval
## [1] 0
##
## $classif.rpart.cp
## [1] 0.08171096
##
## $classif.rpart.minsplit
## [1] 3
instance$result_y
## classif.ce
## 0
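To reuse the best configuration outside the tuning instance, one could transfer it to a fresh pipeline learner and retrain on the full task; a minimal sketch (as_learner() is provided by mlr3pipelines):
# wrap the pipeline as a learner and apply the tuned parameter values
ll_tuned = as_learner(po("subsample") %>>% lrn("classif.rpart"))
ll_tuned$param_set$values = instance$result_learner_param_vals
ll_tuned$train(tsk("iris"))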
Feature Selection / Filtering
Two different approaches are emphasized in the literature: one is called Filtering, while the other is often referred to as feature subset selection or wrapper methods.
Filtering:
- Computes a rank of the features (e.g. based on the correlation to the response).
- Features are subset according to a certain criterion (e.g. an absolute number or a percentage of the number of variables).
- The selected features are then used to fit a model (with optional hyperparameters selected by tuning).
This calculation is usually cheaper than "feature subset selection" in terms of computation time. All filters are connected via the package mlr3filters.
Wrapper Methods:
- Selects a subset of the features.
- Evaluates the set by calculating the resampled predictive performance.
- Proposes a new set of features (or terminates).
A simple example is sequential forward selection. This is usually computationally very intensive: a lot of models are fitted, and all models need to be tuned before the performance is estimated.
Embedded Methods: Many learners internally select a subset of the features which they find helpful for prediction. These subsets can usually be queried:
library("mlr3verse")
= tsk("iris")
task = lrn("classif.rpart")
learner
# ensure that the learner selects features
stopifnot("selected_features" %in% learner$properties)
# fit a simple classification tree
= learner$train(task)
learner
# extract all features used in the classification tree:
$selected_features() learner
## [1] "Petal.Length" "Petal.Width"
Filters
Currently, only classification and regression tasks are supported.
First step: create a new R object using the class of the desired filter method. Each object of class Filter has a $calculate() method which computes the filter values and ranks them in descending order.
= flt("jmim")
filter
= tsk("iris")
task $calculate(task)
filter
as.data.table(filter)
## feature score
## 1: Petal.Width 1.0000000
## 2: Sepal.Length 0.6666667
## 3: Petal.Length 0.3333333
## 4: Sepal.Width 0.0000000
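The computed scores can then be used to subset a task; a small sketch (the cutoff of two features is arbitrary):
# keep the two top-ranked features according to the filter
top_features = names(head(filter$scores, 2))
task$select(top_features)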
We can also change the hyperparameters of the filter.
= flt("correlation")
filter_cor $param_set filter_cor
## <ParamSet>
## id class lower upper nlevels default value
## 1: use ParamFct NA NA 5 everything
## 2: method ParamFct NA NA 3 pearson
# change parameter 'method'
filter_cor$param_set$values = list(method = "spearman")
filter_cor$param_set
## <ParamSet>
## id class lower upper nlevels default value
## 1: use ParamFct NA NA 5 everything
## 2: method ParamFct NA NA 3 pearson spearman
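The re-parameterized filter is applied as before. Note that the correlation filter requires a regression task, so a sketch would use, e.g., mtcars:
filter_cor$calculate(tsk("mtcars"))
head(as.data.table(filter_cor), 3)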
Wrapper Methods
Wrapper feature selection is supported via the mlr3fselect extension package. At the heart of mlr3fselect are the R6 classes:
- FSelectInstanceSingleCrit: describes the feature selection problem and stores the results.
- FSelector: base class for implementations of feature selection algorithms.
= tsk("pima")
task print(task)
## <TaskClassif:pima> (768 x 9): Pima Indian Diabetes
## * Target: diabetes
## * Properties: twoclass
## * Features (8):
## - dbl (8): age, glucose, insulin, mass, pedigree, pregnant, pressure,
## triceps
= lrn("classif.rpart") learner
Choose a resampling strategy and a performance measure.
= rsmp("holdout")
hout = msr("classif.ce") measure
Choose the available budget for the feature selection. This is done by selecting one of the available Terminators (a combined example is sketched below):
- Terminate after a given time (TerminatorClockTime)
- Terminate after a given amount of iterations (TerminatorEvals)
- Terminate after a specific performance is reached (TerminatorPerfReached)
- Terminate when feature selection does not improve (TerminatorStagnation)
- A combination of the above in an ALL or ANY fashion (TerminatorCombo)
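For instance, a combined budget that stops after 20 evaluations or 60 seconds, whichever comes first, could be sketched as follows (the concrete values are arbitrary):
# stop as soon as ANY of the two conditions is met
trm("combo",
  terminators = list(trm("evals", n_evals = 20), trm("run_time", secs = 60)),
  any = TRUE
)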
= trm("evals", n_evals = 20)
evals20
= FSelectInstanceSingleCrit$new(
instance task = task,
learner = learner,
resampling = hout,
measure = measure,
terminator = evals20
)
instance
## <FSelectInstanceSingleCrit>
## * State: Not optimized
## * Objective: <ObjectiveFSelect:classif.rpart_on_pima>
## * Search Space:
## id class lower upper nlevels
## 1: age ParamLgl NA NA 2
## 2: glucose ParamLgl NA NA 2
## 3: insulin ParamLgl NA NA 2
## 4: mass ParamLgl NA NA 2
## 5: pedigree ParamLgl NA NA 2
## 6: pregnant ParamLgl NA NA 2
## 7: pressure ParamLgl NA NA 2
## 8: triceps ParamLgl NA NA 2
## * Terminator: <TerminatorEvals>
To start the feature selection, we still need to select an algorithm; these are defined via the FSelector class. The following algorithms are currently implemented in mlr3fselect:
- Random Search (FSelectorRandomSearch)
- Exhaustive Search (FSelectorExhaustiveSearch)
- Sequential Search (FSelectorSequential)
- Recursive Feature Elimination (FSelectorRFE)
- Design Points (FSelectorDesignPoints)
= fs("random_search") fselector
To start the feature selection, we simply pass the FSelectInstanceSingleCrit to the $optimize() method of the initialized FSelector.
# reduce logging output
::get_logger("bbotk")$set_threshold("warn")
lgr
$optimize(instance) fselector
## INFO [22:08:54.424] [mlr3] Running benchmark with 1 resampling iterations
## INFO [22:08:54.429] [mlr3] Applying learner 'select.classif.rpart' on task 'pima' (iter 1/1)
## INFO [22:08:54.492] [mlr3] Finished benchmark
## INFO [22:08:54.587] [mlr3] Running benchmark with 1 resampling iterations
## INFO [22:08:54.594] [mlr3] Applying learner 'select.classif.rpart' on task 'pima' (iter 1/1)
## INFO [22:08:54.656] [mlr3] Finished benchmark
## INFO [22:08:54.742] [mlr3] Running benchmark with 1 resampling iterations
## INFO [22:08:54.746] [mlr3] Applying learner 'select.classif.rpart' on task 'pima' (iter 1/1)
## INFO [22:08:54.809] [mlr3] Finished benchmark
## INFO [22:08:54.907] [mlr3] Running benchmark with 1 resampling iterations
## INFO [22:08:54.911] [mlr3] Applying learner 'select.classif.rpart' on task 'pima' (iter 1/1)
## INFO [22:08:54.972] [mlr3] Finished benchmark
## INFO [22:08:55.063] [mlr3] Running benchmark with 1 resampling iterations
## INFO [22:08:55.067] [mlr3] Applying learner 'select.classif.rpart' on task 'pima' (iter 1/1)
## INFO [22:08:55.133] [mlr3] Finished benchmark
## INFO [22:08:55.255] [mlr3] Running benchmark with 1 resampling iterations
## INFO [22:08:55.260] [mlr3] Applying learner 'select.classif.rpart' on task 'pima' (iter 1/1)
## INFO [22:08:55.323] [mlr3] Finished benchmark
## INFO [22:08:55.411] [mlr3] Running benchmark with 1 resampling iterations
## INFO [22:08:55.415] [mlr3] Applying learner 'select.classif.rpart' on task 'pima' (iter 1/1)
## INFO [22:08:55.475] [mlr3] Finished benchmark
## INFO [22:08:55.575] [mlr3] Running benchmark with 1 resampling iterations
## INFO [22:08:55.579] [mlr3] Applying learner 'select.classif.rpart' on task 'pima' (iter 1/1)
## INFO [22:08:55.640] [mlr3] Finished benchmark
## INFO [22:08:55.728] [mlr3] Running benchmark with 1 resampling iterations
## INFO [22:08:55.732] [mlr3] Applying learner 'select.classif.rpart' on task 'pima' (iter 1/1)
## INFO [22:08:55.794] [mlr3] Finished benchmark
## INFO [22:08:55.894] [mlr3] Running benchmark with 1 resampling iterations
## INFO [22:08:55.899] [mlr3] Applying learner 'select.classif.rpart' on task 'pima' (iter 1/1)
## INFO [22:08:55.960] [mlr3] Finished benchmark
## INFO [22:08:56.048] [mlr3] Running benchmark with 1 resampling iterations
## INFO [22:08:56.053] [mlr3] Applying learner 'select.classif.rpart' on task 'pima' (iter 1/1)
## INFO [22:08:56.115] [mlr3] Finished benchmark
## INFO [22:08:56.229] [mlr3] Running benchmark with 1 resampling iterations
## INFO [22:08:56.233] [mlr3] Applying learner 'select.classif.rpart' on task 'pima' (iter 1/1)
## INFO [22:08:56.294] [mlr3] Finished benchmark
## INFO [22:08:56.381] [mlr3] Running benchmark with 1 resampling iterations
## INFO [22:08:56.385] [mlr3] Applying learner 'select.classif.rpart' on task 'pima' (iter 1/1)
## INFO [22:08:56.459] [mlr3] Finished benchmark
## INFO [22:08:56.553] [mlr3] Running benchmark with 1 resampling iterations
## INFO [22:08:56.559] [mlr3] Applying learner 'select.classif.rpart' on task 'pima' (iter 1/1)
## INFO [22:08:56.624] [mlr3] Finished benchmark
## INFO [22:08:56.713] [mlr3] Running benchmark with 1 resampling iterations
## INFO [22:08:56.718] [mlr3] Applying learner 'select.classif.rpart' on task 'pima' (iter 1/1)
## INFO [22:08:56.777] [mlr3] Finished benchmark
## INFO [22:08:56.878] [mlr3] Running benchmark with 1 resampling iterations
## INFO [22:08:56.883] [mlr3] Applying learner 'select.classif.rpart' on task 'pima' (iter 1/1)
## INFO [22:08:56.942] [mlr3] Finished benchmark
## INFO [22:08:57.034] [mlr3] Running benchmark with 1 resampling iterations
## INFO [22:08:57.038] [mlr3] Applying learner 'select.classif.rpart' on task 'pima' (iter 1/1)
## INFO [22:08:57.117] [mlr3] Finished benchmark
## INFO [22:08:57.211] [mlr3] Running benchmark with 1 resampling iterations
## INFO [22:08:57.215] [mlr3] Applying learner 'select.classif.rpart' on task 'pima' (iter 1/1)
## INFO [22:08:57.275] [mlr3] Finished benchmark
## INFO [22:08:57.367] [mlr3] Running benchmark with 1 resampling iterations
## INFO [22:08:57.371] [mlr3] Applying learner 'select.classif.rpart' on task 'pima' (iter 1/1)
## INFO [22:08:57.447] [mlr3] Finished benchmark
## INFO [22:08:57.530] [mlr3] Running benchmark with 1 resampling iterations
## INFO [22:08:57.535] [mlr3] Applying learner 'select.classif.rpart' on task 'pima' (iter 1/1)
## INFO [22:08:57.598] [mlr3] Finished benchmark
## age glucose insulin mass pedigree pregnant pressure triceps
## 1: TRUE TRUE TRUE TRUE TRUE TRUE TRUE FALSE
## features classif.ce
## 1: age,glucose,insulin,mass,pedigree,pregnant,... 0.265625
instance$result_feature_set
## [1] "age" "glucose" "insulin" "mass" "pedigree" "pregnant" "pressure"
instance$result_y
## classif.ce
## 0.265625
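The found subset can then be used to train a final model on all observations; a short sketch:
# reduce the task to the selected features and fit the final model
task$select(instance$result_feature_set)
learner$train(task)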
Automating the Feature Selection
The AutoFSelector wraps a learner and augments it with an automatic feature selection for a given task. Analogously to the previous subsection, a new classification tree learner is created. This classification tree learner automatically starts a feature selection on the given task using an inner resampling (holdout).
= lrn("classif.rpart")
learner = trm("evals", n_evals = 10)
terminator = fs("random_search")
fselector
= AutoFSelector$new(
at learner = learner,
resampling = rsmp("holdout"),
measure = msr("classif.ce"),
terminator = terminator,
fselector = fselector
) at
## <AutoFSelector:classif.rpart.fselector>
## * Model: -
## * Parameters: xval=0
## * Packages: mlr3, mlr3fselect, rpart
## * Predict Type: response
## * Feature types: logical, integer, numeric, factor, ordered
## * Properties: importance, missings, multiclass, selected_features,
## twoclass, weights
We can now use the learner like any other learner, calling the $train() and $predict() methods. This time, however, we pass it to benchmark() to compare the optimized feature subset to the complete feature set. This way, the AutoFSelector performs nested resampling for the feature selection:
grid = benchmark_grid(
  task = tsk("pima"),
  learner = list(at, lrn("classif.rpart")),
  resampling = rsmp("cv", folds = 3)
)

bmr = benchmark(grid, store_models = TRUE)
## INFO [22:08:57.797] [mlr3] Running benchmark with 6 resampling iterations
## INFO [22:08:57.803] [mlr3] Applying learner 'classif.rpart.fselector' on task 'pima' (iter 1/3)
## INFO [22:08:57.889] [mlr3] Running benchmark with 1 resampling iterations
## INFO [22:08:57.893] [mlr3] Applying learner 'select.classif.rpart' on task 'pima' (iter 1/1)
## INFO [22:08:57.958] [mlr3] Finished benchmark
## INFO [22:08:58.059] [mlr3] Running benchmark with 1 resampling iterations
## INFO [22:08:58.064] [mlr3] Applying learner 'select.classif.rpart' on task 'pima' (iter 1/1)
## INFO [22:08:58.124] [mlr3] Finished benchmark
## INFO [22:08:58.234] [mlr3] Running benchmark with 1 resampling iterations
## INFO [22:08:58.239] [mlr3] Applying learner 'select.classif.rpart' on task 'pima' (iter 1/1)
## INFO [22:08:58.311] [mlr3] Finished benchmark
## INFO [22:08:58.420] [mlr3] Running benchmark with 1 resampling iterations
## INFO [22:08:58.425] [mlr3] Applying learner 'select.classif.rpart' on task 'pima' (iter 1/1)
## INFO [22:08:58.508] [mlr3] Finished benchmark
## INFO [22:08:58.597] [mlr3] Running benchmark with 1 resampling iterations
## INFO [22:08:58.602] [mlr3] Applying learner 'select.classif.rpart' on task 'pima' (iter 1/1)
## INFO [22:08:58.661] [mlr3] Finished benchmark
## INFO [22:08:58.756] [mlr3] Running benchmark with 1 resampling iterations
## INFO [22:08:58.764] [mlr3] Applying learner 'select.classif.rpart' on task 'pima' (iter 1/1)
## INFO [22:08:58.841] [mlr3] Finished benchmark
## INFO [22:08:58.935] [mlr3] Running benchmark with 1 resampling iterations
## INFO [22:08:58.939] [mlr3] Applying learner 'select.classif.rpart' on task 'pima' (iter 1/1)
## INFO [22:08:59.003] [mlr3] Finished benchmark
## INFO [22:08:59.103] [mlr3] Running benchmark with 1 resampling iterations
## INFO [22:08:59.109] [mlr3] Applying learner 'select.classif.rpart' on task 'pima' (iter 1/1)
## INFO [22:08:59.169] [mlr3] Finished benchmark
## INFO [22:08:59.256] [mlr3] Running benchmark with 1 resampling iterations
## INFO [22:08:59.261] [mlr3] Applying learner 'select.classif.rpart' on task 'pima' (iter 1/1)
## INFO [22:08:59.324] [mlr3] Finished benchmark
## INFO [22:08:59.436] [mlr3] Running benchmark with 1 resampling iterations
## INFO [22:08:59.442] [mlr3] Applying learner 'select.classif.rpart' on task 'pima' (iter 1/1)
## INFO [22:08:59.505] [mlr3] Finished benchmark
## INFO [22:08:59.578] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 3/3)
## INFO [22:08:59.597] [mlr3] Applying learner 'classif.rpart.fselector' on task 'pima' (iter 2/3)
## INFO [22:08:59.687] [mlr3] Running benchmark with 1 resampling iterations
## INFO [22:08:59.693] [mlr3] Applying learner 'select.classif.rpart' on task 'pima' (iter 1/1)
## INFO [22:08:59.760] [mlr3] Finished benchmark
## INFO [22:08:59.842] [mlr3] Running benchmark with 1 resampling iterations
## INFO [22:08:59.846] [mlr3] Applying learner 'select.classif.rpart' on task 'pima' (iter 1/1)
## INFO [22:08:59.909] [mlr3] Finished benchmark
## INFO [22:08:59.995] [mlr3] Running benchmark with 1 resampling iterations
## INFO [22:09:00.000] [mlr3] Applying learner 'select.classif.rpart' on task 'pima' (iter 1/1)
## INFO [22:09:00.087] [mlr3] Finished benchmark
## INFO [22:09:00.185] [mlr3] Running benchmark with 1 resampling iterations
## INFO [22:09:00.189] [mlr3] Applying learner 'select.classif.rpart' on task 'pima' (iter 1/1)
## INFO [22:09:00.251] [mlr3] Finished benchmark
## INFO [22:09:00.563] [mlr3] Running benchmark with 1 resampling iterations
## INFO [22:09:00.567] [mlr3] Applying learner 'select.classif.rpart' on task 'pima' (iter 1/1)
## INFO [22:09:00.639] [mlr3] Finished benchmark
## INFO [22:09:00.732] [mlr3] Running benchmark with 1 resampling iterations
## INFO [22:09:00.736] [mlr3] Applying learner 'select.classif.rpart' on task 'pima' (iter 1/1)
## INFO [22:09:00.798] [mlr3] Finished benchmark
## INFO [22:09:00.891] [mlr3] Running benchmark with 1 resampling iterations
## INFO [22:09:00.897] [mlr3] Applying learner 'select.classif.rpart' on task 'pima' (iter 1/1)
## INFO [22:09:00.967] [mlr3] Finished benchmark
## INFO [22:09:01.058] [mlr3] Running benchmark with 1 resampling iterations
## INFO [22:09:01.063] [mlr3] Applying learner 'select.classif.rpart' on task 'pima' (iter 1/1)
## INFO [22:09:01.123] [mlr3] Finished benchmark
## INFO [22:09:01.215] [mlr3] Running benchmark with 1 resampling iterations
## INFO [22:09:01.220] [mlr3] Applying learner 'select.classif.rpart' on task 'pima' (iter 1/1)
## INFO [22:09:01.284] [mlr3] Finished benchmark
## INFO [22:09:01.370] [mlr3] Running benchmark with 1 resampling iterations
## INFO [22:09:01.375] [mlr3] Applying learner 'select.classif.rpart' on task 'pima' (iter 1/1)
## INFO [22:09:01.442] [mlr3] Finished benchmark
## INFO [22:09:01.509] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 2/3)
## INFO [22:09:01.535] [mlr3] Applying learner 'classif.rpart.fselector' on task 'pima' (iter 3/3)
## INFO [22:09:01.628] [mlr3] Running benchmark with 1 resampling iterations
## INFO [22:09:01.635] [mlr3] Applying learner 'select.classif.rpart' on task 'pima' (iter 1/1)
## INFO [22:09:01.697] [mlr3] Finished benchmark
## INFO [22:09:01.781] [mlr3] Running benchmark with 1 resampling iterations
## INFO [22:09:01.786] [mlr3] Applying learner 'select.classif.rpart' on task 'pima' (iter 1/1)
## INFO [22:09:01.854] [mlr3] Finished benchmark
## INFO [22:09:01.939] [mlr3] Running benchmark with 1 resampling iterations
## INFO [22:09:01.943] [mlr3] Applying learner 'select.classif.rpart' on task 'pima' (iter 1/1)
## INFO [22:09:02.003] [mlr3] Finished benchmark
## INFO [22:09:02.091] [mlr3] Running benchmark with 1 resampling iterations
## INFO [22:09:02.096] [mlr3] Applying learner 'select.classif.rpart' on task 'pima' (iter 1/1)
## INFO [22:09:02.159] [mlr3] Finished benchmark
## INFO [22:09:02.244] [mlr3] Running benchmark with 1 resampling iterations
## INFO [22:09:02.254] [mlr3] Applying learner 'select.classif.rpart' on task 'pima' (iter 1/1)
## INFO [22:09:02.315] [mlr3] Finished benchmark
## INFO [22:09:02.400] [mlr3] Running benchmark with 1 resampling iterations
## INFO [22:09:02.405] [mlr3] Applying learner 'select.classif.rpart' on task 'pima' (iter 1/1)
## INFO [22:09:02.462] [mlr3] Finished benchmark
## INFO [22:09:02.560] [mlr3] Running benchmark with 1 resampling iterations
## INFO [22:09:02.564] [mlr3] Applying learner 'select.classif.rpart' on task 'pima' (iter 1/1)
## INFO [22:09:02.626] [mlr3] Finished benchmark
## INFO [22:09:02.720] [mlr3] Running benchmark with 1 resampling iterations
## INFO [22:09:02.724] [mlr3] Applying learner 'select.classif.rpart' on task 'pima' (iter 1/1)
## INFO [22:09:02.796] [mlr3] Finished benchmark
## INFO [22:09:02.881] [mlr3] Running benchmark with 1 resampling iterations
## INFO [22:09:02.886] [mlr3] Applying learner 'select.classif.rpart' on task 'pima' (iter 1/1)
## INFO [22:09:02.957] [mlr3] Finished benchmark
## INFO [22:09:03.040] [mlr3] Running benchmark with 1 resampling iterations
## INFO [22:09:03.045] [mlr3] Applying learner 'select.classif.rpart' on task 'pima' (iter 1/1)
## INFO [22:09:03.112] [mlr3] Finished benchmark
## INFO [22:09:03.186] [mlr3] Applying learner 'classif.rpart' on task 'pima' (iter 1/3)
## INFO [22:09:03.207] [mlr3] Finished benchmark
bmr$aggregate(msrs(c("classif.ce", "time_train")))
## nr resample_result task_id learner_id resampling_id iters
## 1: 1 <ResampleResult[22]> pima classif.rpart.fselector cv 3
## 2: 2 <ResampleResult[22]> pima classif.rpart cv 3
## classif.ce time_train
## 1: 0.2721354 0
## 2: 0.2760417 0
Note that we do not expect any significant differences since we only evaluated a small fraction of the possible feature subsets.
Good ol' German dataset example
set.seed(47)
library(patchwork)
load(url("https://statmath.wu.ac.at/~hornik/DTM/german.rda"))
# Creating task for german dataset
task <- as_task_classif(german, target = "Class", positive = "good")
task$feature_names
## [1] "Age" "Amount"
## [3] "Duration" "Employment_since"
## [5] "Foreign" "History"
## [7] "Housing" "Installment_rate"
## [9] "Job" "N_of_credits"
## [11] "N_of_liables" "Other_debtors_or_guarantors"
## [13] "Other_installment_plans" "Phone"
## [15] "Property" "Purpose"
## [17] "Residence_since" "Savings"
## [19] "Status_and_sex" "Status_of_checking_account"
# Subset the task to desired features
$select(c("Status_of_checking_account", "Age", "Amount"))
task
# Plots provided with mlr3viz
autoplot(task)
autoplot(task, type = "pairs")
## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.
## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.
## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.
## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.
library(mlr3learners)
## Warning: package 'mlr3learners' was built under R version 4.1.3
# Logistic Regression Learner
<- lrn("classif.log_reg")
lr_learner # Random Forest Learner
<- lrn("classif.ranger")
rf_learner
# train/test rows
train <- sample(task$nrow, 0.7 * task$nrow)
test <- setdiff(seq_len(task$nrow), train)

# Training the learners
lr_learner$train(task, row_ids = train)
print(lr_learner)
## <LearnerClassifLogReg:classif.log_reg>
## * Model: glm
## * Parameters: list()
## * Packages: mlr3, mlr3learners, stats
## * Predict Type: response
## * Feature types: logical, integer, numeric, character, factor, ordered
## * Properties: loglik, twoclass, weights
rf_learner$train(task, row_ids = train)
print(rf_learner)
## <LearnerClassifRanger:classif.ranger>
## * Model: ranger
## * Parameters: num.threads=1
## * Packages: mlr3, mlr3learners, ranger
## * Predict Type: response
## * Feature types: logical, integer, numeric, character, factor, ordered
## * Properties: hotstart_backward, importance, multiclass, oob_error,
## twoclass, weights
# Prediction on the test set
lr_pred <- lr_learner$predict(task, row_ids = test)
head(as.data.table(lr_pred))
## row_ids truth response
## 1: 3 good good
## 2: 6 good good
## 3: 7 good good
## 4: 11 bad good
## 5: 14 bad good
## 6: 19 bad bad
rf_pred <- rf_learner$predict(task, row_ids = test)
head(as.data.table(rf_pred))
## row_ids truth response
## 1: 3 good good
## 2: 6 good good
## 3: 7 good good
## 4: 11 bad good
## 5: 14 bad good
## 6: 19 bad bad
# Confusion Matrices
lr_pred$confusion
## truth
## response good bad
## good 191 76
## bad 11 22
rf_pred$confusion
## truth
## response good bad
## good 190 84
## bad 12 14
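The confusion matrices can be summarized further: precision, recall, and specificity are all built-in measures and can be scored in one call, for example:
lr_pred$score(msrs(c("classif.precision", "classif.recall", "classif.specificity")))
rf_pred$score(msrs(c("classif.precision", "classif.recall", "classif.specificity")))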
# Again autoplots from mlr3viz
p1 <- autoplot(lr_pred) + labs(title = "Logit")
p2 <- autoplot(rf_pred) + labs(title = "RFC")
p1 + p2
# Assessing performance
<- msr("classif.acc")
accuracy <- msr("classif.auc")
auc
$score(accuracy) lr_pred
## classif.acc
## 0.71
rf_pred$score(accuracy)
## classif.acc
## 0.68
# Default is Classification Error
lr_pred$score()
## classif.ce
## 0.29
rf_pred$score()
## classif.ce
## 0.32
# We can also get the prediction probabilities
$predict_type = "prob"
lr_learner$predict_type = "prob"
rf_learner
# We re-fit the model
$train(task, row_ids = train)
lr_learner$train(task, row_ids = train)
rf_learner
# We now get probabilities
lr_pred <- lr_learner$predict(task, row_ids = test)
head(as.data.table(lr_pred))
## row_ids truth response prob.good prob.bad
## 1: 3 good good 0.9206041 0.07939590
## 2: 6 good good 0.7912406 0.20875939
## 3: 7 good good 0.9196340 0.08036598
## 4: 11 bad good 0.6440619 0.35593812
## 5: 14 bad good 0.7088747 0.29112525
## 6: 19 bad bad 0.3970204 0.60297958
rf_pred <- rf_learner$predict(task, row_ids = test)
head(as.data.table(rf_pred))
## row_ids truth response prob.good prob.bad
## 1: 3 good good 0.9494838 0.05051616
## 2: 6 good good 0.8857721 0.11422791
## 3: 7 good good 0.9049654 0.09503457
## 4: 11 bad good 0.5834960 0.41650397
## 5: 14 bad good 0.5973205 0.40267950
## 6: 19 bad bad 0.3441812 0.65581882
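With probabilities available, the auc measure defined earlier can now be scored, and the decision threshold can be moved away from the default 0.5. A quick sketch (0.7 is an arbitrary illustrative value):
lr_pred$score(auc)
rf_pred$score(auc)

# Predict "good" only when prob.good >= 0.7
lr_pred$set_threshold(0.7)
lr_pred$confusion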
p1 <- autoplot(lr_pred, type = "roc") + labs(title = "Logit")
p2 <- autoplot(rf_pred, type = "roc") + labs(title = "RFC")
p1 + p2
# Autotuning the hyperparameters for the RFC
rf_learner$param_set
## <ParamSet>
## id class lower upper nlevels default
## 1: alpha ParamDbl -Inf Inf Inf 0.5
## 2: always.split.variables ParamUty NA NA Inf <NoDefault[3]>
## 3: class.weights ParamUty NA NA Inf
## 4: holdout ParamLgl NA NA 2 FALSE
## 5: importance ParamFct NA NA 4 <NoDefault[3]>
## 6: keep.inbag ParamLgl NA NA 2 FALSE
## 7: max.depth ParamInt 0 Inf Inf
## 8: min.node.size ParamInt 1 Inf Inf 1
## 9: min.prop ParamDbl -Inf Inf Inf 0.1
## 10: minprop ParamDbl -Inf Inf Inf 0.1
## 11: mtry ParamInt 1 Inf Inf <NoDefault[3]>
## 12: mtry.ratio ParamDbl 0 1 Inf <NoDefault[3]>
## 13: num.random.splits ParamInt 1 Inf Inf 1
## 14: num.threads ParamInt 1 Inf Inf 1
## 15: num.trees ParamInt 1 Inf Inf 500
## 16: oob.error ParamLgl NA NA 2 TRUE
## 17: regularization.factor ParamUty NA NA Inf 1
## 18: regularization.usedepth ParamLgl NA NA 2 FALSE
## 19: replace ParamLgl NA NA 2 TRUE
## 20: respect.unordered.factors ParamFct NA NA 3 ignore
## 21: sample.fraction ParamDbl 0 1 Inf <NoDefault[3]>
## 22: save.memory ParamLgl NA NA 2 FALSE
## 23: scale.permutation.importance ParamLgl NA NA 2 FALSE
## 24: se.method ParamFct NA NA 2 infjack
## 25: seed ParamInt -Inf Inf Inf
## 26: split.select.weights ParamUty NA NA Inf
## 27: splitrule ParamFct NA NA 2 gini
## 28: verbose ParamLgl NA NA 2 TRUE
## 29: write.forest ParamLgl NA NA 2 TRUE
## (second column block omitted: `parents` and `value` are empty for all
## parameters except num.random.splits, whose parent is splitrule,
## scale.permutation.importance, whose parent is importance, and
## num.threads, which is set to 1)
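Individual parameters can also be fixed directly, without any tuning. A small sketch (rf_custom and the chosen values are purely illustrative):
# Parameters can be set at construction time ...
rf_custom <- lrn("classif.ranger", num.trees = 1000)
# ... or later, through the param_set
rf_custom$param_set$values$max.depth <- 4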
# HyperParameter Configuration
search_space <- ps(
  max.depth = p_int(lower = 0, upper = 5),
  min.node.size = p_int(lower = 5, upper = 100)
)
<- trm("evals", n_evals = 10)
terminator <- tnr("random_search")
tuner
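Random search is only one of several built-in tuners, and an evaluation count is only one kind of budget. Two alternative sketches (the variable names are illustrative):
tuner_grid <- tnr("grid_search", resolution = 5) # 5 grid points per parameter
terminator_time <- trm("run_time", secs = 60)    # stop after 60 seconds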
$predict_type = "prob"
rf_learner# Autotune
<- AutoTuner$new(
at learner = rf_learner,
resampling = rsmp("holdout"),
measure = msr("classif.ce"),
search_space = search_space,
terminator = terminator,
tuner = tuner
)
# Training
at$train(task)
## INFO [22:09:11.050] [mlr3] Running benchmark with 1 resampling iterations
## INFO [22:09:11.055] [mlr3] Applying learner 'classif.ranger' on task 'german' (iter 1/1)
## INFO [22:09:11.148] [mlr3] Finished benchmark
## ... (the same three log lines repeat for each of the 10 random-search evaluations)
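Before predicting, it is worth inspecting what the tuner found. The chosen configuration and the archive of all evaluated configurations are stored on the AutoTuner:
at$tuning_result
as.data.table(at$archive)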
# New Prediction
new_pred <- at$predict(task, row_ids = test)

new_pred$confusion
## truth
## response good bad
## good 186 48
## bad 16 50
# New Classification Error
new_pred$score()
## classif.ce
## 0.2133333
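Keep in mind that at$train(task) above saw all rows, including our test set, so this error estimate is optimistic. For an unbiased estimate of the whole tuning procedure, the AutoTuner can itself be resampled (nested resampling); a minimal sketch:
rr <- resample(task, at, rsmp("cv", folds = 3), store_models = TRUE)
rr$aggregate(msr("classif.ce"))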
# ROC curve comparison
p3 <- autoplot(new_pred, type = "roc") + labs(title = "at_RFC")
p2 + p3
# --- More advanced ---
design = benchmark_grid(
  tasks = tsks(c("german_credit")),
  learners = lrns(c("classif.ranger", "classif.rpart", "classif.log_reg"),
    predict_type = "prob", predict_sets = c("train", "test")),
  resamplings = rsmps("cv", folds = 10)
)
print(design)
## task learner resampling
## 1: <TaskClassif[50]> <LearnerClassifRanger[38]> <ResamplingCV[20]>
## 2: <TaskClassif[50]> <LearnerClassifRpart[38]> <ResamplingCV[20]>
## 3: <TaskClassif[50]> <LearnerClassifLogReg[37]> <ResamplingCV[20]>
# We call the benchmark on our design
bmr = benchmark(design)
## INFO [22:09:12.910] [mlr3] Running benchmark with 30 resampling iterations
## INFO [22:09:12.914] [mlr3] Applying learner 'classif.ranger' on task 'german_credit' (iter 6/10)
## ... ("Applying learner" lines for the remaining 29 resampling iterations omitted)
## INFO [22:09:17.628] [mlr3] Finished benchmark
# We can choose many performance measures
measures = list(
  msr("classif.auc", predict_sets = "train", id = "auc_train"),
  msr("classif.auc", id = "auc_test")
)
tab = bmr$aggregate(measures)
print(tab)
## nr resample_result task_id learner_id resampling_id iters
## 1: 1 <ResampleResult[22]> german_credit classif.ranger cv 10
## 2: 2 <ResampleResult[22]> german_credit classif.rpart cv 10
## 3: 3 <ResampleResult[22]> german_credit classif.log_reg cv 10
## auc_train auc_test
## 1: 0.9984376 0.8018888
## 2: 0.7862836 0.7299559
## 3: 0.8396874 0.7877817
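Note the large gap between the random forest's train and test AUC: its train-set predictions are nearly perfect. Since tab is a data.table, the learners can also be ranked directly, e.g. by test AUC:
tab[order(-auc_test), .(learner_id, auc_train, auc_test)]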
# Plotting the benchmark results (wrt classification error)
autoplot(bmr)
# Lastly, the good-looking ROC curves
autoplot(bmr, type = "roc")