# Moving Beyond Linearity
```{r}
#| echo: false
set.seed(1234)
source("_common.R")
```
This lab will look at the various ways we can introduce non-linearity into our model through preprocessing. Methods include polynomial expansion, step functions, and splines.
The GAMs section is a work in progress since GAMs are now supported in [parsnip](https://github.com/tidymodels/parsnip/pull/512).
This chapter will use [parsnip](https://www.tidymodels.org/start/models/) for model fitting and [recipes and workflows](https://www.tidymodels.org/start/recipes/) to perform the transformations.
```{r}
#| message: false
library(tidymodels)
library(ISLR)
Wage <- as_tibble(Wage)
```
## Polynomial Regression and Step Functions
Polynomial regression can be thought of as doing polynomial expansion on a variable and passing that expansion into a linear regression model. We will be very explicit about this formulation in this chapter. `step_poly()` allows us to do a polynomial expansion on one or more variables.
The following step will take `age` and replace it with the variables `age`, `age^2`, `age^3`, and `age^4` since we set `degree = 4`.
```{r}
rec_poly <- recipe(wage ~ age, data = Wage) %>%
  step_poly(age, degree = 4)
```
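To peek at what the recipe produces, we can `prep()` and `bake()` it (a quick inspection step; it is not needed for the workflow we build below).
```{r}
# prep() estimates the transformation from the data and bake() applies it.
# `age` is replaced by the four columns age_poly_1, ..., age_poly_4.
rec_poly %>%
  prep() %>%
  bake(new_data = Wage) %>%
  glimpse()
```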
This recipe is combined with a linear regression specification to create a workflow object.
```{r}
lm_spec <- linear_reg() %>%
  set_mode("regression") %>%
  set_engine("lm")

poly_wf <- workflow() %>%
  add_model(lm_spec) %>%
  add_recipe(rec_poly)
```
This object can now be `fit()`.
```{r}
poly_fit <- fit(poly_wf, data = Wage)
poly_fit
```
And we can pull the coefficients using `tidy()`.
```{r}
tidy(poly_fit)
```
I was lying when I said that `step_poly()` returned `age`, `age^2`, `age^3`, and `age^4`. What is actually happening is that it returns variables that form a basis of orthogonal polynomials, which means that each of the columns is a linear combination of the variables `age`, `age^2`, `age^3`, and `age^4`. We can see this by using `poly()` directly; `raw = FALSE` is the default.
```{r}
poly(1:6, degree = 4, raw = FALSE)
```
These variables don't have the format we might have assumed, but this is still a well-reasoned transformation.
We can get the raw polynomial transformation by setting `raw = TRUE`.
```{r}
poly(1:6, degree = 4, raw = TRUE)
```
These transformations align with what we would expect. It is still recommended to stick with the default of `raw = FALSE` unless you have a reason not to.
One of the benefits of using `raw = FALSE` is that the resulting variables are uncorrelated, which is a desirable quality when using a linear regression model.
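We can verify this with a quick check: the columns of the orthogonal basis are pairwise uncorrelated, while the raw polynomial columns are strongly correlated with each other.
```{r}
# Correlations between orthogonal polynomial columns are zero up to
# floating point noise; raw polynomial columns are highly correlated.
round(cor(poly(1:6, degree = 4, raw = FALSE)), 10)
round(cor(poly(1:6, degree = 4, raw = TRUE)), 3)
```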
You can get the raw polynomials by setting `options = list(raw = TRUE)` in `step_poly()`.
```{r}
rec_raw_poly <- recipe(wage ~ age, data = Wage) %>%
  step_poly(age, degree = 4, options = list(raw = TRUE))

raw_poly_wf <- workflow() %>%
  add_model(lm_spec) %>%
  add_recipe(rec_raw_poly)
raw_poly_fit <- fit(raw_poly_wf, data = Wage)
tidy(raw_poly_fit)
```
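Both bases span the same space of fourth-degree polynomials, so while the coefficients differ, the fitted values agree. A quick sanity check confirms this:
```{r}
# Predictions from the orthogonal and raw polynomial fits are identical
# up to numerical precision.
all.equal(
  predict(poly_fit, new_data = Wage)$.pred,
  predict(raw_poly_fit, new_data = Wage)$.pred
)
```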
Let us try something new and visualize the polynomial fit on our data. We can do this easily because we only have one predictor and one response. We start by creating a tibble covering the range of `age`, then predict with it; this gives us the regression curve. We additionally add confidence intervals by setting `type = "conf_int"`, which we can do since we are using a linear regression model.
```{r}
age_range <- tibble(age = seq(min(Wage$age), max(Wage$age)))

regression_lines <- bind_cols(
  augment(poly_fit, new_data = age_range),
  predict(poly_fit, new_data = age_range, type = "conf_int")
)
regression_lines
```
We will then use `ggplot2` to visualize the fitted line and confidence interval. The green line is the regression curve and the dashed blue lines are the confidence interval.
```{r}
#| fig-alt: |
#|   Scatter chart with age on the x-axis and wage on the y-axis.
#|   Wages are fairly normally distributed around wage == 100, with
#|   another blob around wage == 275. A dark green curve follows
#|   the middle of the data, with two dotted curves following
#|   closely around it.
Wage %>%
  ggplot(aes(age, wage)) +
  geom_point(alpha = 0.2) +
  geom_line(aes(y = .pred), color = "darkgreen",
            data = regression_lines) +
  geom_line(aes(y = .pred_lower), data = regression_lines,
            linetype = "dashed", color = "blue") +
  geom_line(aes(y = .pred_upper), data = regression_lines,
            linetype = "dashed", color = "blue")
```
The regression curve is now a curve instead of a line, as we would have gotten with a simple linear regression model. Notice furthermore that the confidence bands are tighter where there is a lot of data and that they are wider towards the ends of the data.
Let us take that one step further and see what happens to the regression line once we go past the domain it was trained on. The previous plot showed individuals within the age range 18-80. Let us see what happens once we push this to 18-100. This is not an impossible range, but it is an unrealistic one.
```{r}
#| fig-alt: |
#|   Scatter chart with age on the x-axis and wage on the y-axis.
#|   Wages are fairly normally distributed around wage == 100, with
#|   another blob around wage == 275. A dark green curve follows
#|   the middle of the data, with two dotted curves following
#|   closely around it. The range for age has been increased beyond
#|   the data points; the green curve trails negative and the
#|   dotted lines quickly move away from the green curve.
wide_age_range <- tibble(age = seq(18, 100))

regression_lines <- bind_cols(
  augment(poly_fit, new_data = wide_age_range),
  predict(poly_fit, new_data = wide_age_range, type = "conf_int")
)

Wage %>%
  ggplot(aes(age, wage)) +
  geom_point(alpha = 0.2) +
  geom_line(aes(y = .pred), color = "darkgreen",
            data = regression_lines) +
  geom_line(aes(y = .pred_lower), data = regression_lines,
            linetype = "dashed", color = "blue") +
  geom_line(aes(y = .pred_upper), data = regression_lines,
            linetype = "dashed", color = "blue")
```
And we see that the curve starts diverging; once we get to age `r regression_lines %>% filter(.pred < 0) %>% slice(1) %>% pull(age)`, the predicted `wage` turns negative. The confidence bands also get wider and wider as we move farther away from the data.
We can also think of this problem as a classification problem, and we will do that just now by setting ourselves the task of predicting whether an individual earns more than $250,000 per year. We will add a new factor variable denoting this response.
```{r}
Wage <- Wage %>%
  mutate(high = factor(wage > 250,
                       levels = c(TRUE, FALSE),
                       labels = c("High", "Low")))
```
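Only a small fraction of individuals earn more than $250,000, so the two classes are quite unbalanced, which is worth keeping in mind when we look at the predicted probabilities below.
```{r}
# The High class is rare compared to Low.
Wage %>%
  count(high)
```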
We cannot use the polynomial expansion recipe `rec_poly` we created earlier, since it had `wage` as the response and we now want `high` as the response.
We also have to create a logistic regression specification that we will use as our classification model.
```{r}
rec_poly <- recipe(high ~ age, data = Wage) %>%
  step_poly(age, degree = 4)

lr_spec <- logistic_reg() %>%
  set_engine("glm") %>%
  set_mode("classification")

lr_poly_wf <- workflow() %>%
  add_model(lr_spec) %>%
  add_recipe(rec_poly)
```
This polynomial logistic regression workflow can now be fitted and used for prediction as usual.
```{r}
lr_poly_fit <- fit(lr_poly_wf, data = Wage)
predict(lr_poly_fit, new_data = Wage)
```
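As a side note, `augment()` binds the predictions to the original data, which makes it easy to pipe the result into a yardstick metric. This is a quick sketch, and since we are predicting on the training data, the resulting accuracy is optimistic (and inflated further by the rarity of the `High` class).
```{r}
# Class predictions bound to the data, then evaluated with accuracy().
# Training-set accuracy only; don't read this as a performance estimate.
augment(lr_poly_fit, new_data = Wage) %>%
  accuracy(truth = high, estimate = .pred_class)
```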
If we want, we can also get the underlying probability predictions for the two classes, and confidence intervals for those probabilities, by setting `type = "prob"` and `type = "conf_int"` respectively.
```{r}
predict(lr_poly_fit, new_data = Wage, type = "prob")
predict(lr_poly_fit, new_data = Wage, type = "conf_int")
```
We can use these to visualize the probability curve for the classification model.
```{r}
#| warning: false
#| fig-alt: |
#|   Line chart with age on the x-axis and .pred_High on the
#|   y-axis. The green curve starts at zero for low values of age.
#|   Local maxima are seen at 35 and 60. The curve goes back to
#|   zero around 80. Two blue dotted lines represent the confidence
#|   interval around the green curve. This confidence interval
#|   stays within about 1% of the green curve except when age is
#|   larger than 60, where it quickly widens.
regression_lines <- bind_cols(
  augment(lr_poly_fit, new_data = age_range, type = "prob"),
  predict(lr_poly_fit, new_data = age_range, type = "conf_int")
)

regression_lines %>%
  ggplot(aes(age)) +
  ylim(c(0, 0.2)) +
  geom_line(aes(y = .pred_High), color = "darkgreen") +
  geom_line(aes(y = .pred_lower_High), color = "blue", linetype = "dashed") +
  geom_line(aes(y = .pred_upper_High), color = "blue", linetype = "dashed") +
  geom_jitter(aes(y = (high == "High") / 5), data = Wage,
              shape = "|", height = 0, width = 0.2)
```
Next, let us take a look at the step function and how to fit a model using it as a preprocessor. You can create step functions in a couple of different ways. `step_discretize()` will convert a numeric variable into a factor variable with `n` bins, where `n` is specified with `num_breaks`. The bins will contain approximately the same number of points each, according to the training data set.
```{r}
rec_discretize <- recipe(high ~ age, data = Wage) %>%
  step_discretize(age, num_breaks = 4)

discretize_wf <- workflow() %>%
  add_model(lr_spec) %>%
  add_recipe(rec_discretize)

discretize_fit <- fit(discretize_wf, data = Wage)
discretize_fit
```
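If you are curious which bins were created, you can `prep()` the recipe and `bake()` the training data; each of the four bins contains roughly the same number of observations.
```{r}
# Each bin holds approximately a quarter of the training observations.
rec_discretize %>%
  prep() %>%
  bake(new_data = Wage) %>%
  count(age)
```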
If you already know where you want the step function to break, then you can use `step_cut()` and supply the breaks manually.
```{r}
rec_cut <- recipe(high ~ age, data = Wage) %>%
  step_cut(age, breaks = c(30, 50, 70))

cut_wf <- workflow() %>%
  add_model(lr_spec) %>%
  add_recipe(rec_cut)

cut_fit <- fit(cut_wf, data = Wage)
cut_fit
```
## Splines
In order to fit regression splines, or in other words, to use splines as preprocessors when fitting a linear model, we use `step_bs()` to construct the matrices of basis functions. The `bs()` function is used under the hood, and arguments such as `knots` can be passed to `bs()` by supplying a named list to `options`.
```{r}
rec_spline <- recipe(wage ~ age, data = Wage) %>%
  step_bs(age, options = list(knots = c(25, 40, 60)))
```
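To see the basis functions that `step_bs()` creates, we can call `splines::bs()` directly on `age` (a quick illustration of the underlying transformation).
```{r}
# bs() returns a matrix of B-spline basis function values at each age;
# step_bs() turns these columns into predictors.
library(splines)
head(bs(Wage$age, knots = c(25, 40, 60)))
```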
We already have the linear regression specification `lm_spec`, so we can create the workflow, fit the model, and predict with it as we have seen in the previous chapters.
```{r}
spline_wf <- workflow() %>%
  add_model(lm_spec) %>%
  add_recipe(rec_spline)

spline_fit <- fit(spline_wf, data = Wage)

predict(spline_fit, new_data = Wage)
```
Lastly, we can plot the spline fit on top of the data.
```{r}
#| fig-alt: |
#|   Scatter chart with age on the x-axis and wage on the y-axis.
#|   Wages are fairly normally distributed around wage == 100, with
#|   another blob around wage == 275. A dark green curve follows
#|   the middle of the data, with two dotted curves following
#|   closely around it.
regression_lines <- bind_cols(
  augment(spline_fit, new_data = age_range),
  predict(spline_fit, new_data = age_range, type = "conf_int")
)

Wage %>%
  ggplot(aes(age, wage)) +
  geom_point(alpha = 0.2) +
  geom_line(aes(y = .pred), data = regression_lines, color = "darkgreen") +
  geom_line(aes(y = .pred_lower), data = regression_lines,
            linetype = "dashed", color = "blue") +
  geom_line(aes(y = .pred_upper), data = regression_lines,
            linetype = "dashed", color = "blue")
```
## GAMs
The GAMs section is a work in progress since GAMs are now supported in [parsnip](https://github.com/tidymodels/parsnip/pull/512).
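In the meantime, here is a minimal sketch of what a GAM fit can look like with parsnip's `gen_additive_mod()`, assuming the mgcv engine is installed; the smooth term `s(age)` goes in the formula passed to `fit()`.
```{r}
#| eval: false
# A sketch only: fit a GAM of wage on a smooth function of age
# using the mgcv engine.
gam_spec <- gen_additive_mod() %>%
  set_mode("regression") %>%
  set_engine("mgcv")

gam_fit <- gam_spec %>%
  fit(wage ~ s(age), data = Wage)

tidy(gam_fit)
```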