3.3 CART con el paquete rpart
La metodología CART está implementada en el paquete rpart
(acrónimo de Recursive PARTitioning)37, implementado por Therneau et al. (2013).
La función principal es rpart()
y habitualmente se emplea de la forma:
rpart(formula, data, method, parms, control, ...)
formula
: permite especificar la respuesta y las variables predictoras de la forma habitual; se suele establecer de la formarespuesta ~ .
para incluir todas las posibles variables explicativas.data
:data.frame
(opcional; donde se evaluará la fórmula) con la muestra de entrenamiento.method
: método empleado para realizar las particiones, puede ser"anova"
(regresión),"class"
(clasificación),"poisson"
(regresión de Poisson) o"exp"
(supervivencia), o alternativamente una lista de funciones (con componentesinit
,split
,eval
; ver la vignette User Written Split Functions). Por defecto se selecciona a partir de la variable respuesta enformula
, por ejemplo, si es un factor (lo recomendado en clasificación) empleamethod = "class"
.parms
: lista de parámetros opcionales para la partición en el caso de clasificación (o regresión de Poisson). Puede contener los componentesprior
(vector de probabilidades previas; por defecto las frecuencias observadas),loss
(matriz de pérdidas; con ceros en la diagonal y por defecto 1 en el resto) ysplit
(criterio de error; por defecto"gini"
o alternativamente"information"
).control
: lista de opciones que controlan el algoritmo de partición, por defecto se seleccionan mediante la funciónrpart.control()
, aunque también se pueden establecer en la llamada a la función principal, y los principales parámetros son:rpart.control(minsplit = 20, minbucket = round(minsplit/3), cp = 0.01, xval = 10, maxdepth = 30, ...)
cp
es el parámetro de complejidad \(\tilde \alpha\) para la poda del árbol, de forma que un valor de 1 se corresponde con un árbol sin divisiones, y un valor de 0, con un árbol de profundidad máxima. Adicionalmente, para reducir el tiempo de computación, el algoritmo empleado no realiza una partición si la proporción de reducción del error es inferior a este valor (valores más grandes simplifican el modelo y reducen el tiempo de computación).maxdepth
es la profundidad máxima del árbol (la profundidad de la raíz sería 0).minsplit
yminbucket
son, respectivamente, los números mínimos de observaciones en un nodo intermedio para particionarlo y en un nodo terminal.xval
es el número de grupos (folds) para validación cruzada.
Para más detalles, consultar la documentación de esta función o la vignette Introduction to Rpart.
3.3.1 Ejemplo: regresión
Emplearemos el conjunto de datos winequality
del paquete mpae
, que contiene información fisico-química
(fixed.acidity
, volatile.acidity
, citric.acid
, residual.sugar
, chlorides
, free.sulfur.dioxide
,
total.sulfur.dioxide
, density
, pH
, sulphates
y alcohol
) y sensorial (quality
) de una muestra de 1250 vinos portugueses de la variedad vinho verde (Cortez et al., 2009)
library(mpae)
# data(winequality, package = "mpae")
str(winequality)
## 'data.frame': 1250 obs. of 12 variables:
## $ fixed.acidity : num 6.8 7.1 6.9 7.5 8.6 7.7 5.4 6.8 6.1 5.5 ...
## $ volatile.acidity : num 0.37 0.24 0.32 0.23 0.36 0.28 0.59 0.16 0.28 0.2..
## $ citric.acid : num 0.47 0.34 0.13 0.49 0.26 0.63 0.07 0.36 0.27 0.2..
## $ residual.sugar : num 11.2 1.2 7.8 7.7 11.1 11.1 7 1.3 4.7 1.6 ...
## $ chlorides : num 0.071 0.045 0.042 0.049 0.03 0.039 0.045 0.034 0..
## $ free.sulfur.dioxide : num 44 6 11 61 43.5 58 36 32 56 23 ...
## $ total.sulfur.dioxide: num 136 132 117 209 171 179 147 98 140 85 ...
## $ density : num 0.997 0.991 0.996 0.994 0.995 ...
## $ pH : num 2.98 3.16 3.23 3.14 3.03 3.08 3.34 3.02 3.16 3.4..
## $ sulphates : num 0.88 0.46 0.37 0.3 0.49 0.44 0.57 0.58 0.42 0.42..
## $ alcohol : num 9.2 11.2 9.2 11.1 12 8.8 9.7 11.3 12.5 12.5 ...
## $ quality : int 5 4 5 7 5 4 6 6 8 5 ...
Como respuesta consideraremos la variable quality
, mediana de al menos 3 evaluaciones de la calidad del vino realizadas por expertos, que los evaluaron entre 0 (muy malo) y 10 (excelente) como puede observarse en el gráfico de barras de la Figura 3.3.
barplot(table(winequality$quality), xlab = "Calidad", ylab = "Frecuencia")
En primer lugar se selecciona el 80 % de los datos como muestra de entrenamiento y el 20 % restante como muestra de test:
set.seed(1)
<- nrow(winequality)
nobs <- sample(nobs, 0.8 * nobs)
itrain <- winequality[itrain, ]
train <- winequality[-itrain, ] test
Podemos obtener el árbol de decisión con las opciones por defecto con el comando:
<- rpart(quality ~ ., data = train) tree
Al imprimirlo se muestra el número de observaciones e información sobre los distintos nodos (número de nodo, condición que define la partición, número de observaciones en el nodo, función de pérdida y predicción), marcando con un *
los nodos terminales.
tree
## n= 1000
##
## node), split, n, deviance, yval
## * denotes terminal node
##
## 1) root 1000 768.960 5.8620
## 2) alcohol< 10.75 622 340.810 5.5868
## 4) volatile.acidity>=0.2575 329 154.760 5.3708
## 8) total.sulfur.dioxide< 98.5 24 12.500 4.7500 *
## 9) total.sulfur.dioxide>=98.5 305 132.280 5.4197
## 18) pH< 3.315 269 101.450 5.3532 *
## 19) pH>=3.315 36 20.750 5.9167 *
## 5) volatile.acidity< 0.2575 293 153.470 5.8294
## 10) sulphates< 0.475 144 80.326 5.6597 *
## 11) sulphates>=0.475 149 64.993 5.9933 *
## 3) alcohol>=10.75 378 303.540 6.3148
## 6) alcohol< 11.775 200 173.870 6.0750
## 12) free.sulfur.dioxide< 11.5 15 10.933 4.9333 *
## 13) free.sulfur.dioxide>=11.5 185 141.810 6.1676
## 26) volatile.acidity>=0.395 7 12.857 5.1429 *
## 27) volatile.acidity< 0.395 178 121.310 6.2079
## 54) citric.acid>=0.385 31 21.935 5.7419 *
## 55) citric.acid< 0.385 147 91.224 6.3061 *
## 7) alcohol>=11.775 178 105.240 6.5843 *
Para representarlo se pueden emplear las herramientas del paquete rpart
(ver Figura 3.4):
plot(tree)
text(tree)
Pero puede ser preferible emplear el paquete rpart.plot
(Milborrow, 2019) (ver Figura 3.5):
library(rpart.plot)
rpart.plot(tree)
Nos interesa conocer cómo se clasificaría a una nueva observación en los nodos terminales, junto con las predicciones correspondientes (la media de la respuesta en el nodo terminal). En los nodos intermedios, solo nos interesan las condiciones y el orden de las variables consideradas hasta llegar a las hojas. Para ello, puede ser útil imprimir las reglas:
rpart.rules(tree, style = "tall")
## quality is 4.8 when
## alcohol < 11
## volatile.acidity >= 0.26
## total.sulfur.dioxide < 99
##
## quality is 4.9 when
## alcohol is 11 to 12
## free.sulfur.dioxide < 12
##
## quality is 5.1 when
## alcohol is 11 to 12
## volatile.acidity >= 0.40
## free.sulfur.dioxide >= 12
##
## quality is 5.4 when
## alcohol < 11
## volatile.acidity >= 0.26
## total.sulfur.dioxide >= 99
## pH < 3.3
##
## quality is 5.7 when
## alcohol < 11
## volatile.acidity < 0.26
## sulphates < 0.48
##
## quality is 5.7 when
## alcohol is 11 to 12
## volatile.acidity < 0.40
## free.sulfur.dioxide >= 12
## citric.acid >= 0.39
##
## quality is 5.9 when
## alcohol < 11
## volatile.acidity >= 0.26
## total.sulfur.dioxide >= 99
## pH >= 3.3
##
## quality is 6.0 when
## alcohol < 11
## volatile.acidity < 0.26
## sulphates >= 0.48
##
## quality is 6.3 when
## alcohol is 11 to 12
## volatile.acidity < 0.40
## free.sulfur.dioxide >= 12
## citric.acid < 0.39
##
## quality is 6.6 when
## alcohol >= 12
Por defecto, el árbol se poda considerando cp = 0.01
, que puede ser adecuado en muchas situaciones.
Sin embargo, para seleccionar el valor óptimo de este hiperparámetro se puede emplear validación cruzada.
En primer lugar habría que establecer cp = 0
para construir el árbol completo, a la profundidad máxima. La profundidad máxima viene determinada por los valores de minsplit
y minbucket
, los cuales pueden ser ajustados manualmente dependiendo del número de observaciones o tratados como hiperparámetros; esto último no está implementado en rpart
, ni en principio en caret
38.
<- rpart(quality ~ ., data = train, cp = 0) tree
Posteriormente, podemos emplear la función printcp()
para obtener los valores de CP para los árboles (óptimos) de menor tamaño, junto con su error de validación cruzada xerror
(reescalado de forma que el máximo de rel error
es 1):
printcp(tree)
##
## Regression tree:
## rpart(formula = quality ~ ., data = train, cp = 0)
##
## Variables actually used in tree construction:
## [1] alcohol chlorides citric.acid
## [4] density fixed.acidity free.sulfur.dioxide
## [7] pH residual.sugar sulphates
## [10] total.sulfur.dioxide volatile.acidity
##
## Root node error: 769/1000 = 0.769
##
## n= 1000
##
## CP nsplit rel error xerror xstd
## 1 0.162047 0 1.000 1.002 0.0486
## 2 0.042375 1 0.838 0.858 0.0436
## 3 0.031765 2 0.796 0.828 0.0435
## 4 0.027487 3 0.764 0.813 0.0428
## 5 0.013044 4 0.736 0.770 0.0397
## 6 0.010596 6 0.710 0.782 0.0394
## 7 0.010266 7 0.700 0.782 0.0391
## 8 0.008408 9 0.679 0.782 0.0391
## 9 0.008139 10 0.671 0.801 0.0399
## 10 0.007806 11 0.663 0.800 0.0405
## 11 0.006842 13 0.647 0.798 0.0402
## 12 0.006738 15 0.633 0.814 0.0409
## [ reached getOption("max.print") -- omitted 48 rows ]
También plotcp()
para representarlos39 (ver Figura 3.6):
plotcp(tree)
La tabla con los valores de las podas (óptimas, dependiendo del parámetro de complejidad) está almacenada en la componente $cptable
:
head(tree$cptable, 10)
## CP nsplit rel error xerror xstd
## 1 0.1620471 0 1.00000 1.00203 0.048591
## 2 0.0423749 1 0.83795 0.85779 0.043646
## 3 0.0317653 2 0.79558 0.82810 0.043486
## 4 0.0274870 3 0.76381 0.81350 0.042814
## 5 0.0130437 4 0.73633 0.77038 0.039654
## 6 0.0105961 6 0.71024 0.78168 0.039353
## 7 0.0102661 7 0.69964 0.78177 0.039141
## 8 0.0084080 9 0.67911 0.78172 0.039123
## 9 0.0081392 10 0.67070 0.80117 0.039915
## 10 0.0078057 11 0.66256 0.80020 0.040481
A partir de esta misma tabla podríamos seleccionar el valor óptimo de forma automática, siguiendo el criterio de un error estándar de Breiman et al. (1984):
<- tree$cptable[,"xerror"]
xerror <- which.min(xerror)
imin.xerror # Valor óptimo
$cptable[imin.xerror, ] tree
## CP nsplit rel error xerror xstd
## 0.013044 4.000000 0.736326 0.770380 0.039654
# Límite superior "oneSE rule" y complejidad mínima por debajo de ese valor
<- xerror[imin.xerror] + tree$cptable[imin.xerror, "xstd"]
upper.xerror <- min(which(xerror <= upper.xerror))
icp <- tree$cptable[icp, "CP"] cp
Para obtener el modelo final (ver Figura 3.7) podamos el árbol con el valor de complejidad obtenido 0.01304 (que en este caso coincide con el valor óptimo).
<- prune(tree, cp = cp)
tree rpart.plot(tree)
Podríamos estudiar el modelo final, por ejemplo, mediante el método summary.rpart()
. Este método muestra, entre otras cosas, una medida (en porcentaje) de la importancia de las variables explicativas para la predicción de la respuesta (teniendo en cuenta todas las particiones, principales y secundarias, en las que se emplea cada variable explicativa). Como alternativa, podríamos emplear el siguiente código:
# summary(tree)
<- tree$variable.importance # Equivalente a caret::varImp(tree)
importance <- round(100*importance/sum(importance), 1)
importance >= 1] importance[importance
## alcohol density chlorides
## 36.1 21.7 11.3
## volatile.acidity total.sulfur.dioxide free.sulfur.dioxide
## 8.7 8.5 5.0
## residual.sugar sulphates citric.acid
## 4.0 1.9 1.1
## pH
## 1.1
El último paso sería evaluarlo en la muestra de test siguiendo los pasos descritos en la Sección 1.3.4. Representamos los valores observados frente a las predicciones (ver Figura 3.8):
<- test$quality
obs <- predict(tree, newdata = test)
pred # plot(pred, obs, xlab = "Predicción", ylab = "Observado")
plot(jitter(pred), jitter(obs), xlab = "Predicción", ylab = "Observado")
abline(a = 0, b = 1)
y calculamos medidas de error de las predicciones, bien empleando el paquete caret
:
::postResample(pred, obs) caret
## RMSE Rsquared MAE
## 0.81456 0.19695 0.65743
o con la función accuracy()
:
accuracy(pred, test$quality)
## me rmse mae mpe mape r.squared
## -0.0012694 0.8145614 0.6574264 -1.9523422 11.5767160 0.1920077
Como se puede observar, el ajuste del modelo es bastante deficiente. Esto es habitual en árboles de regresión, especialmente si son tan pequeños, y por ello solo se utilizan, por lo general, en un análisis exploratorio inicial o como base para modelos más avanzados como los mostrados en el siguiente capítulo. En problemas de clasificación es más habitual que se puedan llegar a obtener buenos ajustes con árboles de decisión.
Ejercicio 3.1 Como se comentó en la introducción del Capítulo 1, al emplear el procedimiento habitual en AE de particionar los datos no se garantiza la reproducibilidad/repetibilidad de los resultados, ya que dependen de la semilla. El modelo ajustado puede variar con diferentes semillas, especialmente si el conjunto de entrenamiento es pequeño, aunque generalmente no se observan cambios significativos en las predicciones.
Podemos ilustrar el efecto de la semilla en los resultados empleando el ejemplo anterior. Para ello, repite el ajuste de un árbol de regresión considerando distintas semillas y compara los resultados obtenidos.
La dificultad podría estar en cómo comparar los resultados.
Una posible solución sería mantener fija la muestra de test, de modo que no dependa de la semilla utilizada.
Por comodidad, considera las primeras ntest
observaciones del conjunto de datos como muestra de test.
Posteriormente, para cada semilla, selecciona la muestra de entrenamiento de la forma habitual y ajusta un árbol.
Finalmente, evalúa los resultados en la muestra de test.
Como base se puede utilizar el siguiente código:
<- 10
ntest <- winequality[1:ntest, ]
test <- winequality[-(1:ntest), ]
df <- nrow(df)
nobs # Para las distintas semillas
set.seed(semilla)
<- sample(nobs, 0.8 * nobs)
itrain <- df[itrain, ]
train # tree <- ...
Como observación final, en este caso el conjunto de datos no es muy grande y tampoco se obtuvo un buen ajuste con un árbol de regresión, por lo que sería de esperar que se observaran más diferencias.
Ejercicio 3.2 Como se indicó previamente, el paquete rpart
implementa la selección del parámetro de complejidad mediante validación cruzada.
Como alternativa, siguiendo la idea del Ejercicio 1.1, y considerando de nuevo el ejemplo anterior, particiona la muestra en datos de entrenamiento (70 %), de validación (15 %) y de test (15 %), para ajustar los árboles de decisión, seleccionar el parámetro de complejidad (el hiperparámetro) y evaluar las predicciones del modelo final, respectivamente.
Ejercicio 3.3 Una alternativa a particionar en entrenamiento y validación sería emplear bootstrap. La idea consiste en emplear una remuestra bootstrap del conjunto de datos de entrenamiento para ajustar el modelo y utilizar las observaciones no seleccionadas (se suelen denominar datos out-of-bag) como conjunto de validación.
set.seed(1)
<- nrow(winequality)
nobs <- sample(nobs, 0.8 * nobs)
itrain <- winequality[itrain, ]
train <- winequality[-itrain, ]
test # Indice muestra de entrenamiento bootstrap
set.seed(1)
<- nrow(train)
ntrain <- sample(ntrain, replace = TRUE)
itrain.boot <- train[itrain.boot, ] train.boot
La muestra bootstrap va a contener muchas observaciones repetidas y habrá observaciones no seleccionadas. La probabilidad de que una observación no sea seleccionada es \((1 - 1/n)^n \approx e^{-1} \approx 0.37\).
# Número de casos "out of bag"
- length(unique(itrain.boot)) ntrain
## [1] 370
# Muestra "out of bag"
<- train[-itrain.boot, ] oob
El procedimiento restante sería análogo al caso anterior, cambiando train
por train.boot
y validate
por oob
.
Sin embargo, lo recomendable sería repetir el proceso un número grande de veces y promediar los errores, especialmente cuando el tamaño muestral es pequeño (este enfoque se relaciona con el método de bagging, descrito en el siguiente capítulo).
No obstante, y por simplicidad, realiza el ajuste empleando una única muestra bootstrap y evalúa las predicciones en la muestra de test.
3.3.2 Ejemplo: modelo de clasificación
Para ilustrar los árboles de clasificación CART, podemos emplear los datos anteriores de calidad de vino, considerando como respuesta una nueva variable taste
que clasifica los vinos en “good” o “bad” dependiendo de si winequality$quality >= 6
(este conjunto de datos está disponible en mpae::winetaste
).
# data(winetaste, package = "mpae")
<- winequality[, colnames(winequality)!="quality"]
winetaste $taste <- factor(winequality$quality < 6,
winetastelabels = c('good', 'bad')) # levels = c('FALSE', 'TRUE')
str(winetaste)
## 'data.frame': 1250 obs. of 12 variables:
## $ fixed.acidity : num 6.8 7.1 6.9 7.5 8.6 7.7 5.4 6.8 6.1 5.5 ...
## $ volatile.acidity : num 0.37 0.24 0.32 0.23 0.36 0.28 0.59 0.16 0.28 0.2..
## $ citric.acid : num 0.47 0.34 0.13 0.49 0.26 0.63 0.07 0.36 0.27 0.2..
## $ residual.sugar : num 11.2 1.2 7.8 7.7 11.1 11.1 7 1.3 4.7 1.6 ...
## $ chlorides : num 0.071 0.045 0.042 0.049 0.03 0.039 0.045 0.034 0..
## $ free.sulfur.dioxide : num 44 6 11 61 43.5 58 36 32 56 23 ...
## $ total.sulfur.dioxide: num 136 132 117 209 171 179 147 98 140 85 ...
## $ density : num 0.997 0.991 0.996 0.994 0.995 ...
## $ pH : num 2.98 3.16 3.23 3.14 3.03 3.08 3.34 3.02 3.16 3.4..
## $ sulphates : num 0.88 0.46 0.37 0.3 0.49 0.44 0.57 0.58 0.42 0.42..
## $ alcohol : num 9.2 11.2 9.2 11.1 12 8.8 9.7 11.3 12.5 12.5 ...
## $ taste : Factor w/ 2 levels "good","bad": 2 2 2 1 2 2 1 1 1 2 ..
table(winetaste$taste)
##
## good bad
## 828 422
Como en el caso anterior, se contruyen las muestras de entrenamiento (80 %) y de test (20 %):
# set.seed(1)
# nobs <- nrow(winetaste)
# itrain <- sample(nobs, 0.8 * nobs)
<- winetaste[itrain, ]
train <- winetaste[-itrain, ] test
Al igual que en el caso anterior, podemos obtener el árbol de clasificación con las opciones por defecto (cp = 0.01
y split = "gini"
) con el comando:
<- rpart(taste ~ ., data = train) tree
Al imprimirlo, además de mostrar el número de nodo, la condición de la partición y el número de observaciones en el nodo, también se incluye el número de observaciones mal clasificadas, la predicción y las proporciones estimadas (frecuencias relativas en la muestra de entrenamiento) de las clases:
tree
## n= 1000
##
## node), split, n, loss, yval, (yprob)
## * denotes terminal node
##
## 1) root 1000 338 good (0.66200 0.33800)
## 2) alcohol>=10.117 541 100 good (0.81516 0.18484)
## 4) free.sulfur.dioxide>=8.5 522 87 good (0.83333 0.16667)
## 8) fixed.acidity< 8.55 500 73 good (0.85400 0.14600) *
## 9) fixed.acidity>=8.55 22 8 bad (0.36364 0.63636) *
## 5) free.sulfur.dioxide< 8.5 19 6 bad (0.31579 0.68421) *
## 3) alcohol< 10.117 459 221 bad (0.48148 0.51852)
## 6) volatile.acidity< 0.2875 264 102 good (0.61364 0.38636)
## 12) fixed.acidity< 7.45 213 71 good (0.66667 0.33333)
## 24) citric.acid>=0.265 160 42 good (0.73750 0.26250) *
## 25) citric.acid< 0.265 53 24 bad (0.45283 0.54717)
## 50) free.sulfur.dioxide< 42.5 33 13 good (0.60606 0.39394) *
## 51) free.sulfur.dioxide>=42.5 20 4 bad (0.20000 0.80000) *
## 13) fixed.acidity>=7.45 51 20 bad (0.39216 0.60784)
## 26) total.sulfur.dioxide>=150 26 10 good (0.61538 0.38462) *
## 27) total.sulfur.dioxide< 150 25 4 bad (0.16000 0.84000) *
## 7) volatile.acidity>=0.2875 195 59 bad (0.30256 0.69744)
## 14) pH>=3.235 49 24 bad (0.48980 0.51020)
## 28) chlorides< 0.0465 18 4 good (0.77778 0.22222) *
## 29) chlorides>=0.0465 31 10 bad (0.32258 0.67742) *
## 15) pH< 3.235 146 35 bad (0.23973 0.76027) *
También puede ser preferible emplear el paquete rpart.plot
para representarlo (ver Figura 3.9):
library(rpart.plot)
rpart.plot(tree) # Alternativa: rattle::fancyRpartPlot
Nos interesa cómo se clasificaría a una nueva observación, es decir, cómo se llegaría a los nodos terminales, así como su probabilidad estimada, que representa la frecuencia relativa de la clase más frecuente en el correspondiente nodo terminal. Para lograrlo, se puede modificar la información que se muestra en cada nodo (ver Figura 3.10):
rpart.plot(tree,
extra = 104, # show fitted class, probs, percentages
box.palette = "GnBu", # color scheme
branch.lty = 3, # dotted branch lines
shadow.col = "gray", # shadows under the node boxes
nn = TRUE) # display the node numbers
Al igual que en el caso de regresión, puede ser de utilidad imprimir las reglas con rpart.rules(tree, style = "tall")
.
También se suele emplear el mismo procedimiento para seleccionar un valor óptimo para el hiperparámetro de complejidad: se construye un árbol de decisión completo y se emplea validación cruzada para podarlo. Además, si el número de observaciones es grande y las clases están más o menos balanceadas, se podría aumentar los valores mínimos de observaciones en los nodos intermedios y terminales40, por ejemplo:
<- rpart(taste ~ ., data = train, cp = 0, minsplit = 30, minbucket = 10) tree
En este caso mantenemos el resto de valores por defecto:
<- rpart(taste ~ ., data = train, cp = 0) tree
Representamos los errores (reescalados) de validación cruzada (ver Figura 3.11):
plotcp(tree)
Para obtener el modelo final, seleccionamos el valor óptimo de complejidad siguiendo el criterio de un error estándar de Breiman et al. (1984) y podamos el árbol (ver Figura 3.12).
<- tree$cptable[,"xerror"]
xerror <- which.min(xerror)
imin.xerror <- xerror[imin.xerror] + tree$cptable[imin.xerror, "xstd"]
upper.xerror <- min(which(xerror <= upper.xerror))
icp <- tree$cptable[icp, "CP"]
cp <- prune(tree, cp = cp)
tree rpart.plot(tree)
Si nos interesase estudiar la importancia de los predictores, podríamos utilizar el mismo código de la Sección 3.3.1 (no evaluado):
::varImp(tree)
caret<- tree$variable.importance
importance <- round(100*importance/sum(importance), 1)
importance >= 1] importance[importance
El último paso sería evaluar el modelo en la muestra de test siguiendo los pasos descritos en la Sección 1.3.5.
El método predict.rpart()
devuelve por defecto (type = "prob"
) una matriz con las probabilidades de cada clase, por lo que habría que establecer type = "class"
para obtener la clase predicha (consultar la ayuda de esta función para más detalles).
<- test$taste
obs head(predict(tree, newdata = test))
## good bad
## 1 0.30256 0.69744
## 4 0.81516 0.18484
## 9 0.81516 0.18484
## 10 0.81516 0.18484
## 12 0.81516 0.18484
## 16 0.81516 0.18484
<- predict(tree, newdata = test, type = "class")
pred table(obs, pred)
## pred
## obs good bad
## good 153 13
## bad 54 30
::confusionMatrix(pred, obs) caret
## Confusion Matrix and Statistics
##
## Reference
## Prediction good bad
## good 153 54
## bad 13 30
##
## Accuracy : 0.732
## 95% CI : (0.673, 0.786)
## No Information Rate : 0.664
## P-Value [Acc > NIR] : 0.0125
##
## Kappa : 0.317
##
## Mcnemar's Test P-Value : 1.02e-06
##
## Sensitivity : 0.922
## Specificity : 0.357
## Pos Pred Value : 0.739
## Neg Pred Value : 0.698
## Prevalence : 0.664
## Detection Rate : 0.612
## Detection Prevalence : 0.828
## Balanced Accuracy : 0.639
##
## 'Positive' Class : good
##
Ejercicio 3.4 En este ejercicio se empleará el conjunto de datos mpae::bfan
del paquete mpae
utilizado anteriormente en las secciones 2.2 y 2.3.
Considerando como respuesta la variable indicadora bfan
, que clasifica a los individuos en Yes
o No
dependiendo de si su porcentaje de grasa corporal es superior al rango normal:
Particiona los datos, considerando un 80 % de las observaciones como muestra de aprendizaje y el 20 % restante como muestra de test.
Ajusta un árbol de decisión a los datos de entrenamiento seleccionando el parámetro de complejidad mediante la regla de un error estándar de Breiman.
Representa e interpreta el árbol resultante, estudiando la importancia de las variables predictoras.
Evalúa la precisión de las predicciones en la muestra de test (precisión, sensibilidad y especificidad) y la estimación de las probabilidades mediante el AUC.
3.3.3 Interfaz de caret
En caret
podemos ajustar un árbol CART seleccionando method = "rpart"
.
Por defecto, caret
realiza bootstrap de las observaciones para seleccionar el valor óptimo del hiperparámetro cp
(considerando únicamente tres posibles valores).
Si queremos emplear validación cruzada, como se hizo en el caso anterior, podemos emplear la función auxiliar trainControl()
, y para considerar un mayor rango de posibles valores podemos hacer uso del argumento tuneLength
(ver Figura 3.13).
library(caret)
# modelLookup("rpart") # Información sobre hiperparámetros
set.seed(1)
<- trainControl(method = "cv", number = 10)
trControl <- train(taste ~ ., method = "rpart", data = train,
caret.rpart tuneLength = 20, trControl = trControl)
caret.rpart
## CART
##
## 1000 samples
## 11 predictor
## 2 classes: 'good', 'bad'
##
## No pre-processing
## Resampling: Cross-Validated (10 fold)
## Summary of sample sizes: 901, 900, 900, 900, 900, 900, ...
## Resampling results across tuning parameters:
##
## cp Accuracy Kappa
## 0.000000 0.70188 0.34873
## 0.005995 0.73304 0.38706
## 0.011990 0.74107 0.38785
## 0.017985 0.72307 0.33745
## 0.023980 0.73607 0.36987
## 0.029975 0.73407 0.35064
## 0.035970 0.73207 0.34182
## 0.041965 0.73508 0.34227
## 0.047960 0.73508 0.34227
## 0.053955 0.73508 0.34227
## 0.059950 0.73508 0.34227
## 0.065945 0.73508 0.34227
## 0.071940 0.73508 0.34227
## 0.077935 0.73508 0.34227
## 0.083930 0.73508 0.34227
## 0.089925 0.73508 0.34227
## 0.095920 0.73508 0.34227
## 0.101915 0.73508 0.34227
## 0.107910 0.72296 0.29433
## 0.113905 0.68096 0.10877
##
## Accuracy was used to select the optimal model using the largest value.
## The final value used for the model was cp = 0.01199.
ggplot(caret.rpart, highlight = TRUE)
El modelo final se devuelve en la componente $finalModel
(ver Figura 3.14):
$finalModel caret.rpart
## n= 1000
##
## node), split, n, loss, yval, (yprob)
## * denotes terminal node
##
## 1) root 1000 338 good (0.66200 0.33800)
## 2) alcohol>=10.117 541 100 good (0.81516 0.18484)
## 4) free.sulfur.dioxide>=8.5 522 87 good (0.83333 0.16667)
## 8) fixed.acidity< 8.55 500 73 good (0.85400 0.14600) *
## 9) fixed.acidity>=8.55 22 8 bad (0.36364 0.63636) *
## 5) free.sulfur.dioxide< 8.5 19 6 bad (0.31579 0.68421) *
## 3) alcohol< 10.117 459 221 bad (0.48148 0.51852)
## 6) volatile.acidity< 0.2875 264 102 good (0.61364 0.38636)
## 12) fixed.acidity< 7.45 213 71 good (0.66667 0.33333)
## 24) citric.acid>=0.265 160 42 good (0.73750 0.26250) *
## 25) citric.acid< 0.265 53 24 bad (0.45283 0.54717)
## 50) free.sulfur.dioxide< 42.5 33 13 good (0.60606 0.39394) *
## 51) free.sulfur.dioxide>=42.5 20 4 bad (0.20000 0.80000) *
## 13) fixed.acidity>=7.45 51 20 bad (0.39216 0.60784)
## 26) total.sulfur.dioxide>=150 26 10 good (0.61538 0.38462) *
## 27) total.sulfur.dioxide< 150 25 4 bad (0.16000 0.84000) *
## 7) volatile.acidity>=0.2875 195 59 bad (0.30256 0.69744)
## 14) pH>=3.235 49 24 bad (0.48980 0.51020)
## 28) chlorides< 0.0465 18 4 good (0.77778 0.22222) *
## 29) chlorides>=0.0465 31 10 bad (0.32258 0.67742) *
## 15) pH< 3.235 146 35 bad (0.23973 0.76027) *
rpart.plot(caret.rpart$finalModel)
Para utilizar la regla de “un error estándar” se puede añadir selectionFunction = "oneSE"
en las opciones de entrenamiento41:
set.seed(1)
<- trainControl(method = "cv", number = 10,
trControl selectionFunction = "oneSE")
<- train(taste ~ ., method = "rpart", data = train,
caret.rpart tuneLength = 20, trControl = trControl)
# ggplot(caret.rpart, highlight = TRUE)
caret.rpart
## CART
##
## 1000 samples
## 11 predictor
## 2 classes: 'good', 'bad'
##
## No pre-processing
## Resampling: Cross-Validated (10 fold)
## Summary of sample sizes: 901, 900, 900, 900, 900, 900, ...
## Resampling results across tuning parameters:
##
## cp Accuracy Kappa
## 0.000000 0.70188 0.34873
## 0.005995 0.73304 0.38706
## 0.011990 0.74107 0.38785
## 0.017985 0.72307 0.33745
## 0.023980 0.73607 0.36987
## 0.029975 0.73407 0.35064
## 0.035970 0.73207 0.34182
## 0.041965 0.73508 0.34227
## 0.047960 0.73508 0.34227
## 0.053955 0.73508 0.34227
## 0.059950 0.73508 0.34227
## 0.065945 0.73508 0.34227
## 0.071940 0.73508 0.34227
## 0.077935 0.73508 0.34227
## 0.083930 0.73508 0.34227
## 0.089925 0.73508 0.34227
## 0.095920 0.73508 0.34227
## 0.101915 0.73508 0.34227
## 0.107910 0.72296 0.29433
## 0.113905 0.68096 0.10877
##
## Accuracy was used to select the optimal model using the one SE rule.
## The final value used for the model was cp = 0.10192.
Como cabría esperar, el modelo resultante es más simple (ver Figura 3.15):
rpart.plot(caret.rpart$finalModel)
Adicionalmente, representamos la importancia de los predictores (ver Figura 3.16):
<- varImp(caret.rpart)
var.imp plot(var.imp)
Finalmente, calculamos las predicciones con el método predict.train()
y posteriormente evaluamos su precisión con confusionMatrix()
:
<- predict(caret.rpart, newdata = test)
pred confusionMatrix(pred, test$taste)
## Confusion Matrix and Statistics
##
## Reference
## Prediction good bad
## good 153 54
## bad 13 30
##
## Accuracy : 0.732
## 95% CI : (0.673, 0.786)
## No Information Rate : 0.664
## P-Value [Acc > NIR] : 0.0125
##
## Kappa : 0.317
##
## Mcnemar's Test P-Value : 1.02e-06
##
## Sensitivity : 0.922
## Specificity : 0.357
## Pos Pred Value : 0.739
## Neg Pred Value : 0.698
## Prevalence : 0.664
## Detection Rate : 0.612
## Detection Prevalence : 0.828
## Balanced Accuracy : 0.639
##
## 'Positive' Class : good
##
También podríamos calcular las estimaciones de las probabilidades (añadiendo el argumento type = "prob"
) y, por ejemplo, generar la curva ROC con pROC::roc()
(ver Figura 3.17):
library(pROC)
<- predict(caret.rpart, newdata = test, type = "prob")
p.est <- roc(response = obs, predictor = p.est$good)
roc_tree roc_tree
##
## Call:
## roc.default(response = obs, predictor = p.est$good)
##
## Data: p.est$good in 166 controls (obs good) > 84 cases (obs bad).
## Area under the curve: 0.72
plot(roc_tree, xlab = "Especificidad", ylab = "Sensibilidad")
A la vista de los resultados, los árboles de clasificación no parecen muy adecuados para este problema.
Sobre todo por la mala clasificación en la categoría "bad"
(en este último ajuste se obtuvo un valor de especificidad de 0.3571 en la muestra de test), lo que podría ser debido a que las clases están desbalanceadas y en el ajuste recibe mayor peso el error en la clase mayoritaria.
Para tratar de evitarlo, se podría emplear una matriz de pérdidas a través del componente loss
del argumento parms
(en el Ejercicio 5.1 se propone una aproximación similar).
También se podría pensar en cambiar el criterio de optimalidad.
Por ejemplo, emplear el coeficiente kappa en lugar de la precisión (solo habría que establecer metric = "Kappa"
en la llamada a la función train()
), o el área bajo la curva ROC (ver Ejercicio 5.3).
Ejercicio 3.5 Continuando con el Ejercicio 3.4, emplea caret
para clasificar los individuos mediante un árbol de decisión.
Utiliza la misma partición de los datos, seleccionando el parámetro de complejidad mediante validación cruzada con 10 grupos y el criterio de un error estándar de Breiman.
Representa e interpreta el árbol resultante (comentando la importancia de las variables) y evalúa la precisión de las predicciones en la muestra de test.
Bibliografía
Los parámetros
maxsurrogate
,usesurrogate
ysurrogatestyle
serían de utilidad si hay datos faltantes.↩︎Realmente en la tabla de texto se muestra el valor mínimo de CP, ya que se obtendría la misma solución para un rango de valores de CP (desde ese valor hasta el anterior, sin incluirlo), mientras que en el gráfico generado por
plotcp()
se representa la media geométrica de los extremos de ese intervalo.↩︎Otra opción, más interesante para regresión, sería considerar estos valores como hiperparámetros.↩︎
En principio también se podría utilizar la regla de un error estándar estableciendo
method = "rpart1SE"
en la llamada atrain()
, perocaret
implementa internamente este método y en ocasiones no se obtienen los resultados esperados.↩︎