1.6 Introducción al paquete caret

Como ya se comentó en la Sección 1.2.2, el paquete caret (Classification And REgression Training, Kuhn, 2008) proporciona una interfaz unificada que simplifica el proceso de modelado empleando la mayoría de los métodos de AE implementados en R (actualmente admite 239 métodos; ver el Capítulo 6 del manual de este paquete). Además de proporcionar rutinas para los principales pasos del proceso, incluye también numerosas funciones auxiliares que permitirían implementar nuevos procedimientos.

En esta sección se describirán de forma esquemática las principales herramientas disponibles en este paquete, para más detalles se recomendaría consultar el manual del paquete caret. También está disponible una pequeña introducción en la vignette del paquete: A Short Introduction to the caret Package y una “chuleta”: Caret Cheat Sheet.

1.6.1 Métodos implementados

La función principal es train() (descrita en la siguiente subsección), que incluye un parámetro method que permite establecer el modelo mediante una cadena de texto. Podemos obtener información sobre los modelos disponibles con las funciones getModelInfo() y modelLookup() (puede haber varias implementaciones del mismo método con distintas configuraciones de hiperparámetros; también se pueden definir nuevos modelos, ver el Capítulo 13 del manual).

library(caret)
str(names(getModelInfo()))  # Listado de los métodos disponibles
##  chr [1:239] "ada" "AdaBag" "AdaBoost.M1" "adaboost" ...
# getModelInfo() devuelve coincidencias parciales por defecto
# names(getModelInfo("knn")) # 2 métodos
modelLookup("knn")  # Información sobre hiperparámetros
##   model parameter      label forReg forClass probModel
## 1   knn         k #Neighbors   TRUE     TRUE      TRUE

En la siguiente tabla se muestran los métodos actualmente disponibles:

Figura 1.16: Listado de los métodos disponiles en caret::train().

1.6.2 Herramientas

Este paquete permite, entre otras cosas:

  • Partición de los datos

    • createDataPartition(y, p = 0.5, list = TRUE, ...): crea particiones balanceadas de los datos.

      • En el caso de que la respuesta y sea categórica realiza el muestreo en cada clase. Para respuestas numéricas emplea cuantiles (definidos por el argumento groups = min(5, length(y))).

      • p: proporción de datos en la muestra de entrenamiento.

      • list: lógico; determina si el resultado es una lista con las muestras o un vector (o matriz) de índices

    • Funciones auxiliares: createFolds(), createMultiFolds(), groupKFold(), createResample(), createTimeSlices()

  • Análisis descriptivo: featurePlot()

  • Preprocesado de los datos:

    • La función principal es preProcess(x, method = c("center", "scale"), ...), aunque se puede integrar en el entrenamiento (función train()). Estimará los parámetros de las transformaciones con la muestra de entrenamiento y permitirá aplicarlas posteriormente de forma automática al hacer nuevas predicciones (p.e. en la muestra de test).

    • El parámetro method permite establecer una lista de procesados:

      • Imputación: "knnImpute", "bagImpute" o "medianImpute"

      • Creación y transformación de variables explicativas: "center", "scale", "range", "BoxCox", "YeoJohnson", "expoTrans", "spatialSign"

      • Selección de predictores y extracción de componentes: "corr", "nzv", "zv", "conditionalX", "pca", "ica"

    • Dispone de múltiples funciones auxiliares, como dummyVars() o rfe() (recursive feature elimination).

  • Entrenamiento y selección de los hiperparámetros del modelo:

    • La función principal es train(formula, data, method = "rf", trControl = trainControl(), tuneGrid = NULL, tuneLength = 3, ...)

      • trControl: permite establecer el método de remuestreo para la evaluación de los hiperparámetros y el método para seleccionar el óptimo, incluyendo las medidas de precisión. Por ejemplo trControl = trainControl(method = "cv", number = 10, selectionFunction = "oneSE").

        Los métodos disponibles son: "boot", "boot632", "optimism_boot", "boot_all", "cv", "repeatedcv", "LOOCV", "LGOCV", "timeslice", "adaptive_cv", "adaptive_boot" o "adaptive_LGOCV"

      • tuneLength y tuneGrid: permite establecer cuantos hiperparámetros serán evaluados (por defecto 3) o una rejilla con las combinaciones de hiperparámetros.

      • ... permite establecer opciones específicas de los métodos.

    • También admite matrices x, y en lugar de fórmulas (o recetas: recipe()).

    • Si se imputan datos en el preprocesado será necesario establecer na.action = na.pass.

  • Predicción: Una de las ventajas es que incorpora un único método predict() para objetos de tipo train con dos únicas opciones15 type = c("raw", "prob"), la primera para obtener predicciones de la respuesta y la segunda para obtener estimaciones de las probabilidades (en los métodos de clasificación que lo admitan).

    Además, si se incluyo un preprocesado en el entrenamiento, se emplearán las mismas transformaciones en un nuevo conjunto de datos newdata.

  • Evaluación de los modelos

    • postResample(pred, obs, ...): regresión

    • confusionMatrix(pred, obs, ...): clasificación

      • Funciones auxiliares: twoClassSummary(), prSummary()
  • Análisis de la importancia de los predictores:

    • varImp(): interfaz a las medidas específicas de los métodos de aprendizaje supervisado (Sección 15.1 del manual) o medidas genéricas (Sección 15.2).

1.6.3 Ejemplo

Como ejemplo consideraremos el problema de regresión anterior empleando KNN en caret:

data(Boston, package = "MASS")
library(caret)

Particionamos los datos:

set.seed(1)
itrain <- createDataPartition(Boston$medv, p = 0.8, list = FALSE)
train <- Boston[itrain, ]
test <- Boston[-itrain, ]

Entrenamiento, con preprocesado de los datos (se almacenan las transformaciones para volver a aplicarlas en la predicción con nuevos datos) y empleando validación cruzada con 10 grupos para la selección de hiperparámetros:

set.seed(1)
knn <- train(medv ~ ., data = train,
             method = "knn",
             preProc = c("center", "scale"),
             tuneGrid = data.frame(k = 1:10),
             trControl = trainControl(method = "cv", number = 10))
plot(knn) # Alternativamente: ggplot(knn, highlight = TRUE)
Raíz del error cuadrático medio de validación cruzada dependiendo del valor del hiperparámetro.

Figura 1.17: Raíz del error cuadrático medio de validación cruzada dependiendo del valor del hiperparámetro.

knn$bestTune
##   k
## 3 3
knn$finalModel
## 3-nearest neighbor regression model

Importancia de las variables (interpretación del modelo final):

varImp(knn)
## loess r-squared variable importance
## 
##         Overall
## lstat    100.00
## rm        88.26
## indus     36.29
## ptratio   33.27
## tax       30.58
## crim      28.33
## nox       23.44
## black     21.29
## age       20.47
## rad       17.16
## zn        15.11
## dis       14.35
## chas       0.00

Evaluación del modelo final en la muestra de test:

postResample(predict(knn, newdata = test), test$medv)
##     RMSE Rsquared      MAE 
## 4.960971 0.733945 2.724242

1.6.4 Desarrollo futuro

Como comenta el autor del paquete caret (y coautor en Kuhn y Johnson, 2013):

“While I’m still supporting caret, the majority of my development effort has gone into the tidyverse modeling packages (called tidymodels)”.

— Max Kuhn (actualmente ingeniero de software en RStudio).

este paquete ha dejado de desarrollarse de forma activa, aunque consideramos que la alternativa tidymodels (Kuhn y Wickham, 2022) todavía está en fase de desarrollo16 y su uso requiere de más tiempo de aprendizaje. Este es uno de los motivos por los que se ha optado por mantener el uso de caret en este libro, aunque la intención es incluir apéndices adicionales en próximas ediciones ilustrando el uso de otras herramientas (como tidymodels, ver Kuhn y Silge, 2022; o incluso mlr3, Becker et al., 2021).

References

Becker, M., Binder, M., Bischl, B., Lang, M., Pfisterer, F., Reich, N. G., Richter, J., Schratz, P., Sonabend, R., y Pulatov, D. (2021). mlr3 book. https://mlr3book.mlr-org.com
Kuhn, M. (2008). Building Predictive Models in R Using the caret Package. Journal of Statistical Software, 28(5), 1-26. https://doi.org/10.18637/jss.v028.i05
Kuhn, M., y Johnson, K. (2013). Applied predictive modeling (Vol. 26). Springer. https://doi.org/10.1007/978-1-4614-6849-3
Kuhn, M., y Silge, J. (2022). Tidy Modeling with R. O’Reilly Media. https://www.tmwr.org
Kuhn, M., y Wickham, H. (2022). tidymodels: Easily Install and Load the Tidymodels Packages. https://CRAN.R-project.org/package=tidymodels

  1. En lugar de la variedad de opciones que emplean los distintos paquetes (e.g.: type = "response", "class", "posterior", "probability"… ).↩︎

  2. Sin embargo, desde la publicación del libro Kuhn y Silge (2022), disponible en línea en https://www.tmwr.org, ya podríamos considerar que ha superado la fase inicial de desarrollo.↩︎