7.4 Regresión spline adaptativa multivariante
La regresión spline adaptativa multivariante, en inglés multivariate adaptive regression splines [MARS; Friedman (1991)], es un procedimiento adaptativo para problemas de regresión que puede verse como una generalización tanto de la regresión lineal por pasos (stepwise linear regression) como de los árboles de decisión CART.
El modelo MARS es un spline multivariante lineal:
\[m(\mathbf{x}) = \beta_0 + \sum_{m=1}^M \beta_m h_m(\mathbf{x})\]
(es un modelo lineal en transformaciones \(h_m(\mathbf{x})\) de los predictores originales), donde las bases \(h_m(\mathbf{x})\) se construyen de forma adaptativa empleando funciones bisagra (hinge functions)
\[ h(x) = (x)_+ = \left\{ \begin{array}{ll}
x & \mbox{si } x > 0 \\
0 & \mbox{si } x \leq 0
\end{array}
\right.\]
y considerando como posibles nodos los valores observados de los predictores
(en el caso univariante se emplean las bases de potencias truncadas con \(d=1\) descritas en la Sección 7.2.1, pero incluyendo también su versión simetrizada).
Vamos a empezar explicando el modelo MARS aditivo (sin interacciones), que funciona de forma muy parecida a los árboles de decisión CART, y después lo extenderemos al caso con interacciones. Asumimos que todas las variables predictoras son numéricas. El proceso de construcción del modelo es un proceso iterativo hacia delante (forward) que empieza con el modelo \[\hat m(\mathbf{x}) = \hat \beta_0 \] donde \(\hat \beta_0\) es la media de todas las respuestas, para a continuación considerar todos los puntos de corte (knots) posibles \(x_{ji}\) con \(i = 1, 2, \ldots, n\), \(j = 1, 2, \ldots, p\), es decir, todas las observaciones de todas las variables predictoras de la muestra de entrenamiento. Para cada punto de corte \(x_{ji}\) (combinación de variable y observación) se consideran dos bases: \[ \begin{aligned} h_1(\mathbf{x}) = h(x_j - x_{ji}) \\ h_2(\mathbf{x}) = h(x_{ji} - x_j) \end{aligned}\] y se construye el nuevo modelo \[\hat m(\mathbf{x}) = \hat \beta_0 + \hat \beta_1 h_1(\mathbf{x}) + \hat \beta_2 h_2(\mathbf{x})\] La estimación de los parámetros \(\beta_0, \beta_1, \beta_2\) se realiza de la forma estándar en regresión lineal, minimizando \(\mbox{RSS}\). De este modo se construyen muchos modelos alternativos y entre ellos se selecciona aquel que tenga un menor error de entrenamiento. En la siguiente iteración se conservan \(h_1(\mathbf{x})\) y \(h_2(\mathbf{x})\) y se añade una pareja de términos nuevos siguiendo el mismo procedimiento. Y así sucesivamente, añadiendo de cada vez dos nuevos términos. Este procedimiento va creando un modelo lineal segmentado (piecewise) donde cada nuevo término modeliza una porción aislada de los datos originales.
El tamaño de cada modelo es el número términos (funciones \(h_m\)) que este incorpora. El proceso iterativo se para cuando se alcanza un modelo de tamaño \(M\), que se consigue después de incorporar \(M/2\) cortes. Este modelo depende de \(M+1\) parámetros \(\beta_m\) con \(m=0,1,\ldots,M\). El objetivo es alcanzar un modelo lo suficientemente grande para que sobreajuste los datos, para a continuación proceder a su poda en un proceso de eliminación de variables hacia atrás (backward deletion) en el que se van eliminando las variables de una en una (no por parejas, como en la construcción). En cada paso de poda se elimina el término que produce el menor incremento en el error. Así, para cada tamaño \(\lambda = 0,1,\ldots, M\) se obtiene el mejor modelo estimado \(\hat{m}_{\lambda}\).
La selección óptima del valor del hiperparámetro \(\lambda\) puede realizarse por los procedimientos habituales tipo validación cruzada. Una alternativa mucho más rápida es utilizar validación cruzada generalizada (GCV) que es una aproximación de la validación cruzada leave-one-out mediante la fórmula \[\mbox{GCV} (\lambda) = \frac{\mbox{RSS}}{(1-M(\lambda)/n)^2}\] donde \(M(\lambda)\) es el número de parámetros efectivos del modelo, que depende del número de términos más el número de puntos de corte utilizados penalizado por un factor (2 en el caso aditivo que estamos explicando, 3 cuando hay interacciones).
Hemos explicado una caso particular de MARS: el modelo aditivo. El modelo general sólo se diferencia del caso aditivo en que se permiten interacciones, es decir, multiplicaciones entre las variables \(h_m(\mathbf{x})\). Para ello, en cada iteración durante la fase de construcción del modelo, además de considerar todos los puntos de corte, también se consideran todas las combinaciones con los términos incorporados previamente al modelo, denominados términos padre. De este modo, si resulta seleccionado un término padre \(h_l(\mathbf{x})\) (incluyendo \(h_0(\mathbf{x}) = 1\)) y un punto de corte \(x_{ji}\), después de analizar todas las posibilidades, al modelo anterior se le agrega \[\hat \beta_{m+1} h_l(\mathbf{x}) h(x_j - x_{ji}) + \hat \beta_{m+2} h_l(\mathbf{x}) h(x_{ji} - x_j)\] Recordando que en cada caso se vuelven a estimar todos los parámetros \(\beta_i\).
Al igual que \(\lambda\), también el grado de interacción máxima permitida se considera un hiperparámetro del problema, aunque lo habitual es trabajar con grado 1 (modelo aditivo) o interacción de grado 2. Una restricción adicional que se impone al modelo es que en cada producto no puede aparecer más de una vez la misma variable \(X_j\).
Aunque el procedimiento de construcción del modelo realiza búsquedas exhaustivas y en consecuencia puede parecer computacionalmente intratable, en la práctica se realiza de forma razonablemente rápida, al igual que ocurría en CART. Una de las principales ventajas de MARS es que realiza una selección automática de las variables predictoras. Aunque inicialmente pueda haber muchos predictores, y este método es adecuado para problemas de alta dimensión, en el modelo final van a aparecer muchos menos (pueden aparecer más de una vez). Además, si se utiliza un modelo aditivo su interpretación es directa, e incluso permitiendo interacciones de grado 2 el modelo puede ser interpretado. Otra ventaja es que no es necesario realizar un prepocesado de los datos, ni filtrando variables ni transformando los datos. Que haya predictores con correlaciones altas no va a afectar a la construcción del modelo (normalmente seleccionará el primero), aunque sí puede dificultar su interpretación. Aunque hemos supuesto al principio de la explicación que los predictores son numéricos, se pueden incorporar variables predictoras cualitativas siguiendo los procedimientos estándar. Por último, se puede realizar una cuantificación de la importancia de las variables de forma similar a como se hace en CART.
En conclusión, MARS utiliza splines lineales con una selección automática de los puntos de corte mediante un algoritmo avaricioso similar al empleado en los árboles CART, tratando de añadir más puntos de corte donde aparentemente hay más variaciones en la función de regresión y menos puntos donde esta es más estable.
7.4.1 MARS con el paquete earth
Actualmente el paquete de referencia para MARS es earth
(Enhanced Adaptive Regression Through Hinges)50.
La función principal es earth()
y se suelen considerar los siguientes argumentos:
earth(formula, data, glm = NULL, degree = 1, ...)
formula
ydata
(opcional): permiten especificar la respuesta y las variables predictoras de la forma habitual (e.g.respuesta ~ .
; también admite matrices). Admite respuestas multidimensionales (ajustará un modelo para cada componente) y categóricas (las convierte en multivariantes), también predictores categóricos, aunque no permite datos faltantes.glm
: lista con los parámetros del ajuste GLM (e.g.glm = list(family = binomial)
).degree
: grado máximo de interacción; por defecto 1 (modelo aditivo).
Otros parámetros que pueden ser de interés (afectan a la complejidad del modelo en el crecimiento, a la selección del modelo final o al tiempo de computación; para más detalles ver help(earth)
):
nk
: número máximo de términos en el crecimiento del modelo (dimensión \(M\) de la base); por defectomin(200, max(20, 2 * ncol(x))) + 1
(puede ser demasiado pequeña si muchos de los predictores influyen en la respuesta).thresh
: umbral de parada en el crecimiento (se interpretaría comocp
en los árboles CART); por defecto 0.001 (si se establece a 0 la única condición de parada será alcanzar el valor máximo de términosnk
).fast.k
: número máximo de términos padre considerados en cada paso durante el crecimiento; por defecto 20, si se establece a 0 no habrá limitación.linpreds
: índice de variables que se considerarán con efecto lineal.nprune
: número máximo de términos (incluida la intersección) en el modelo final (después de la poda); por defecto no hay límite (se podrían incluir todos los creados durante el crecimiento).pmethod
: método empleado para la poda; por defecto"backward"
. Otras opciones son:"forward"
,"seqrep"
,"exhaustive"
(emplea los métodos de selección implementados en paqueteleaps
),"cv"
(validación cruzada, empleandonflod
) y"none"
para no realizar poda.nfold
: número de grupos de validación cruzada; por defecto 0 (no se hace validación cruzada).varmod.method
: permite seleccionar un método para estimar las varianzas y por ejemplo poder realizar contrastes o construir intervalos de confianza (para más detalles ver?varmod
o la vignette “Variance models in earth”).
Utilizaremos como ejemplo inicial los datos de MASS:mcycle
(ver Figura 7.12):
# data(mcycle, package = "MASS")
library(earth)
<- earth(accel ~ times, data = mcycle)
mars # mars
summary(mars)
## Call: earth(formula=accel~times, data=mcycle)
##
## coefficients
## (Intercept) -90.992956
## h(19.4-times) 8.072585
## h(times-19.4) 9.249999
## h(times-31.2) -10.236495
##
## Selected 4 of 6 terms, and 1 of 1 predictors
## Termination condition: RSq changed by less than 0.001 at 6 terms
## Importance: times
## Number of terms at each degree of interaction: 1 3 (additive model)
## GCV 1119.813 RSS 133670.3 GRSq 0.5240328 RSq 0.5663192
plot(mars)
Por defecto, se representa un resumen de los errores de validación en la selección del modelo, la distribución empírica y el gráfico QQ de los residuos, y los residuos frente a las predicciones (en la muestra de entrenamiento).
Podemos representar el ajuste obtenido (ver Figura 7.13):
plot(accel ~ times, data = mcycle, col = 'darkgray')
lines(mcycle$times, predict(mars))
Como con las opciones por defecto el ajuste no es muy bueno (aunque puede ser suficiente para un análisis preliminar), podríamos forzar la complejidad del modelo en el crecimiento (minspan = 1
permite que todas las observaciones sean potenciales nodos; ver Figura 7.14):
<- earth(accel ~ times, data = mcycle, minspan = 1, thresh = 0)
mars2 summary(mars2)
## Call: earth(formula=accel~times, data=mcycle, minspan=1, thresh=0)
##
## coefficients
## (Intercept) -6.274366
## h(times-14.6) -25.333056
## h(times-19.2) 32.979264
## h(times-25.4) 153.699248
## h(times-25.6) -145.747392
## h(times-32) -30.041076
## h(times-35.2) 13.723887
##
## Selected 7 of 12 terms, and 1 of 1 predictors
## Termination condition: Reached nk 21
## Importance: times
## Number of terms at each degree of interaction: 1 6 (additive model)
## GCV 623.5209 RSS 67509.03 GRSq 0.7349776 RSq 0.7809732
plot(accel ~ times, data = mcycle, col = 'darkgray')
lines(mcycle$times, predict(mars2))
Como siguiente ejemplo consideramos los datos de carData::Prestige
(ver Figura 7.15):
# data(Prestige, package = "carData")
<- earth(prestige ~ education + income + women, data = Prestige,
mars degree = 2, nk = 40)
summary(mars)
## Call: earth(formula=prestige~education+income+women, data=Prestige, degree=2,
## nk=40)
##
## coefficients
## (Intercept) 19.9845240
## h(education-9.93) 5.7683265
## h(income-3161) 0.0085297
## h(income-5795) -0.0080222
## h(women-33.57) 0.2154367
## h(income-5299) * h(women-4.14) -0.0005163
## h(income-5795) * h(women-4.28) 0.0005409
##
## Selected 7 of 31 terms, and 3 of 3 predictors
## Termination condition: Reached nk 40
## Importance: education, income, women
## Number of terms at each degree of interaction: 1 4 2
## GCV 53.08737 RSS 3849.355 GRSq 0.8224057 RSq 0.8712393
plot(mars)
Para representar los efectos de las variables importa las herramientas del paquete plotmo
(del mismo autor; válido también para la mayoría de los modelos tratados en este libro, incluyendo mgcv::gam()
; ver Figura 7.16):
plotmo(mars)
## plotmo grid: education income women
## 10.54 5930 13.6
También podemos obtener la importancia de las variables (función evimp()
) y representarla gráficamente (método plot.evimp()
; ver Figura 7.17):
<- evimp(mars)
varimp varimp
## nsubsets gcv rss
## education 6 100.0 100.0
## income 5 36.0 40.3
## women 3 16.3 22.0
plot(varimp)
Para finalizar, destacar que podríamos tener en cuenta este modelo como punto de partida para ajustar un modelo GAM más flexible (como se mostró en la Sección 7.3). Por ejemplo:
# library(mgcv)
<- gam(prestige ~ s(education) + s(income) + s(women), data = Prestige, select = TRUE)
gam summary(gam)
##
## Family: gaussian
## Link function: identity
##
## Formula:
## prestige ~ s(education) + s(income) + s(women)
##
## Parametric coefficients:
## Estimate Std. Error t value Pr(>|t|)
## (Intercept) 46.8333 0.6461 72.49 <2e-16 ***
## ---
## Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
##
## Approximate significance of smooth terms:
## edf Ref.df F p-value
## s(education) 2.349 9 9.926 < 2e-16 ***
## s(income) 6.289 9 7.420 < 2e-16 ***
## s(women) 1.964 9 1.309 0.00149 **
## ---
## Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
##
## R-sq.(adj) = 0.856 Deviance explained = 87.1%
## GCV = 48.046 Scale est. = 42.58 n = 102
<- gam(prestige ~ s(education) + s(income, women), data = Prestige)
gam2 summary(gam2)
##
## Family: gaussian
## Link function: identity
##
## Formula:
## prestige ~ s(education) + s(income, women)
##
## Parametric coefficients:
## Estimate Std. Error t value Pr(>|t|)
## (Intercept) 46.833 0.679 68.97 <2e-16 ***
## ---
## Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
##
## Approximate significance of smooth terms:
## edf Ref.df F p-value
## s(education) 2.802 3.489 25.09 <2e-16 ***
## s(income,women) 4.895 6.286 10.03 <2e-16 ***
## ---
## Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
##
## R-sq.(adj) = 0.841 Deviance explained = 85.3%
## GCV = 51.416 Scale est. = 47.032 n = 102
anova(gam, gam2, test = "F")
## Analysis of Deviance Table
##
## Model 1: prestige ~ s(education) + s(income) + s(women)
## Model 2: prestige ~ s(education) + s(income, women)
## Resid. Df Resid. Dev Df Deviance F Pr(>F)
## 1 88.325 3849.1
## 2 91.225 4388.3 -2.9001 -539.16 4.3661 0.00705 **
## ---
## Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
plotmo(gam2)
## plotmo grid: education income women
## 10.54 5930 13.6
En la Figura 7.18 (generada con plotmo::plotmo()
) se representan los efectos parciales de las componentes, y en la Figura 7.19 el efecto parcial de la interacción (empleando plot()
):
plot(gam2, scheme = 2, select = 2)
Ejercicio 7.3 Comentar brevemente los resultados del ajuste del modelo GAM del ejemplo anterior.
¿Observas algo extraño en el contraste ANOVA?
(Probar a ejecutar anova(gam2, gam, test = "F")
.)
7.4.2 MARS con el paquete caret
Emplearemos como ejemplo el conjunto de datos earth::Ozone1
:
# data(ozone1, package = "earth")
<- ozone1
df set.seed(1)
<- nrow(df)
nobs <- sample(nobs, 0.8 * nobs)
itrain <- df[itrain, ]
train <- df[-itrain, ] test
caret
implementa varios métodos basados en earth
, en este caso emplearemos el algoritmo original:
library(caret)
# names(getModelInfo("[Ee]arth")) # 4 métodos
modelLookup("earth")
## model parameter label forReg forClass probModel
## 1 earth nprune #Terms TRUE TRUE TRUE
## 2 earth degree Product Degree TRUE TRUE TRUE
Para selección de los hiperparámetros óptimos consideramos una rejilla de búsqueda personalizada:
<- expand.grid(degree = 1:2,
tuneGrid nprune = floor(seq(2, 20, len = 10)))
set.seed(1)
<- train(O3 ~ ., data = train, method = "earth",
caret.mars trControl = trainControl(method = "cv", number = 10),
tuneGrid = tuneGrid)
caret.mars
## Multivariate Adaptive Regression Spline
##
## 264 samples
## 9 predictor
##
## No pre-processing
## Resampling: Cross-Validated (10 fold)
## Summary of sample sizes: 238, 238, 238, 236, 237, 239, ...
## Resampling results across tuning parameters:
##
## degree nprune RMSE Rsquared MAE
## 1 2 4.842924 0.6366661 3.803870
## 1 4 4.558953 0.6834467 3.488040
## 1 6 4.345781 0.7142046 3.413213
## 1 8 4.256592 0.7295113 3.220256
## 1 10 4.158604 0.7436812 3.181941
## 1 12 4.128416 0.7509562 3.142176
## 1 14 4.069714 0.7600561 3.061458
## 1 16 4.058769 0.7609245 3.058843
## 1 18 4.058769 0.7609245 3.058843
## 1 20 4.058769 0.7609245 3.058843
## 2 2 4.842924 0.6366661 3.803870
## 2 4 4.652783 0.6725979 3.540031
## [ reached getOption("max.print") -- omitted 8 rows ]
##
## RMSE was used to select the optimal model using the smallest value.
## The final values used for the model were nprune = 10 and degree = 2.
ggplot(caret.mars, highlight = TRUE)
Podemos analizar el modelo final con las herramientas de earth
:
summary(caret.mars$finalModel)
## Call: earth(x=matrix[264,9], y=c(4,13,16,3,6,2...), keepxy=TRUE, degree=2,
## nprune=10)
##
## coefficients
## (Intercept) 11.6481994
## h(dpg-15) -0.0743900
## h(ibt-110) 0.1224848
## h(17-vis) -0.3363332
## h(vis-17) -0.0110360
## h(101-doy) -0.1041604
## h(doy-101) -0.0236813
## h(wind-3) * h(1046-ibh) -0.0023406
## h(humidity-52) * h(15-dpg) -0.0047940
## h(60-humidity) * h(ibt-110) -0.0027632
##
## Selected 10 of 21 terms, and 7 of 9 predictors (nprune=10)
## Termination condition: Reached nk 21
## Importance: humidity, ibt, dpg, doy, wind, ibh, vis, temp-unused, ...
## Number of terms at each degree of interaction: 1 6 3
## GCV 13.84161 RSS 3032.585 GRSq 0.7846289 RSq 0.8199031
Representamos los efectos parciales de las componentes, separando los efectos principales (ver Figura 7.21) de las interacciones (ver Figura 7.22):
# plotmo(caret.mars$finalModel)
plotmo(caret.mars$finalModel, degree2 = 0, caption = "")
## plotmo grid: vh wind humidity temp ibh dpg ibt vis doy
## 5770 5 64.5 62 2046.5 24 169.5 100 213.5
plotmo(caret.mars$finalModel, degree1 = 0, caption = "")
Finalmente evaluamos la precisión de las predicciones en la muestra de test con el procedimiento habitual:
<- predict(caret.mars, newdata = test)
pred accuracy(pred, test$O3)
## me rmse mae mpe mape r.squared
## 0.4817913 4.0952444 3.0764376 -14.1288949 41.2602037 0.7408061
References
Desarrollado a partir de la función
mda::mars()
de T. Hastie y R. Tibshirani. Utiliza este nombre porque MARS está registrado para un uso comercial por Salford Systems.↩︎