forked from perlatex/R_for_Data_Science
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathbayesian_glm.Rmd
395 lines (275 loc) · 8.53 KB
/
bayesian_glm.Rmd
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
# 贝叶斯广义线性模型 {#bayesian-glm}
```{r, message=FALSE, warning=FALSE}
library(tidyverse)
library(tidybayes)
library(rstan)
library(wesanderson)
rstan_options(auto_write = TRUE)
options(mc.cores = parallel::detectCores())
theme_set(
bayesplot::theme_default() +
ggthemes::theme_tufte() +
theme(plot.background = element_rect(fill = wes_palette("Moonrise2")[3],
color = wes_palette("Moonrise2")[3]))
)
```
## 广义线性模型
广义线性模型必须要明确的三个元素:
1. 响应变量的概率分布(Normal, Binomial, Poisson, Categorical, Multinomial, Poisson, Beta)
2. 预测变量的线性组合
$$
\eta = X \vec{\beta}
$$
3. 连接函数($g(.)$), 将期望值映射到预测变量的线性组合
$$
g(\mu) = \eta
$$
连接函数是可逆的(invertible)。连接函数的逆,将预测变量的线性组合**映射**到响应变量的期望值
$$
\mu = g^{-1}(\eta)
$$
通过值域来看,连接函数的逆,将 $\eta$ 从 $(-\infty, +\infty)$ 转换到特定的范围.
### 连接函数
```{r, echo=FALSE}
knitr::include_graphics(here::here("images", "Link_Functions.png"))
```
### 不同分布对应的函数
```{r, echo=FALSE}
knitr::include_graphics(here::here("images", "distributions_and_link_functions.png"))
```
## 研究生院录取时有性别歧视?
这是美国一所大学研究生院的录取人数。我们想看下是否存在性别歧视?
```{r}
UCBadmit <- readr::read_rds(here::here('demo_data', "UCBadmit.rds")) %>%
mutate(applicant_gender = factor(applicant_gender, levels = c("male", "female")))
UCBadmit
```
我们首先将申请者的性别作为预测变量,建立二项式回归模型如下
$$
\begin{align*}
\text{admit}_i & \sim \operatorname{Binomial}(n_i, p_i) \\
\text{logit}(p_i) & = \alpha_{\text{gender}[i]} \\
\alpha_j & \sim \operatorname{Normal}(0, 1.5),
\end{align*}
$$
这里连接函数`logit()`需要说明一下
$$
\begin{align*}
\text{logit}(p_{i}) &= \log\Big(\frac{p_{i}}{1 - p_{i}}\Big) = \alpha_{\text{gender}[i]}\\
\text{equivalent to,} \quad p_{i} &= \frac{1}{1 + \exp[- \alpha_{\text{gender}[i]}]} \\
& = \frac{\exp(\alpha_{\text{gender}[i]})}{1 + \exp (\alpha_{\text{gender}[i]})} \\
& = \text{inv_logit}(\alpha_{\text{gender}[i]})
\end{align*}
$$
R语言glm函数能够拟合一系列的广义线性模型
```{r}
model_logit <- glm(
cbind(admit, rejection) ~ 0 + applicant_gender,
data = UCBadmit,
family = binomial(link = "logit")
)
summary(model_logit)
```
### Stan 代码
```{r}
stan_program_A <- '
data {
int n;
int admit[n];
int applications[n];
int applicant_gender[n];
}
parameters {
real a[2];
}
transformed parameters {
vector[n] p;
for (i in 1:n) {
p[i] = inv_logit(a[applicant_gender[i]]);
}
}
model {
a ~ normal(0, 1.5);
for (i in 1:n) {
admit[i] ~ binomial(applications[i], p[i]);
}
}
'
stan_data <- UCBadmit %>%
compose_data()
fit01 <- stan(model_code = stan_program_A, data = stan_data)
```
```{r}
fit01
```
```{r}
inv_logit <- function(x) {
exp(x) / (1 + exp(x))
}
fit01 %>%
tidybayes::spread_draws(a[i]) %>%
pivot_wider(
names_from = i,
values_from = a,
names_prefix = "a_"
) %>%
mutate(
diff_a = a_1 - a_2,
diff_p = inv_logit(a_1) - inv_logit(a_2)
) %>%
pivot_longer(contains("diff")) %>%
group_by(name) %>%
tidybayes::mean_qi(value, .width = .89)
```
从这个模型的结果看,男性确实有优势。
- 从 log-odd 度量看,男性录取率高于女性录取率
- 从概率的角度看,男性的录取概率比女性高 12% 到 16%
下面我们来看模型拟合的情况
```{r, fig.width = 5, fig.asp = 0.618}
fit01 %>%
tidybayes::gather_draws(p[i]) %>%
tidybayes::mean_qi(.width = .89) %>%
ungroup() %>%
rename(Estimate = .value) %>%
bind_cols(UCBadmit) %>%
ggplot(aes(x = applicant_gender, y = ratio)) +
geom_point(aes(y = Estimate),
color = wes_palette("Moonrise2")[1],
shape = 1, size = 3
) +
geom_point(color = wes_palette("Moonrise2")[2]) +
geom_line(aes(group = dept),
color = wes_palette("Moonrise2")[2]) +
scale_y_continuous(limits = 0:1) +
facet_grid(. ~ dept) +
labs(x = NULL, y = 'Probability of admission')
```
我们能说存在性别歧视?我们发现一些违反直觉的问题:
- 原始数据中只有学院C和E,女性录取率略低于男性,但模型结果却表明,女性预期的录取概率比男性低14%。
- 同时,我们看到男性和女性申请的院系不一样,以下是各学院男女申请人数的比例。女性更倾向与选择A、B之外的学院,比如F学院,而这些学院申请人数比较多,因而男女录取率都很低,甚至不到10%.
也就说,大多女性选择录取率低的学院,从而拉低了女性整体的录取率。
```{r}
UCBadmit %>%
group_by(dept) %>%
mutate(proportion = applications / sum(applications)) %>%
select(dept, applicant_gender, proportion) %>%
pivot_wider(
names_from = dept,
values_from = proportion
) %>%
mutate(
across(where(is.double), round, digits = 2)
)
```
- 模型没有问题,而是我们的提问(对全体学院,男女平均录取率有什么差别?)是有问题的。
因此,我们的提问要修改为:**在每个院系内部,男女平均录取率的差别是多少?**
### 增加预测变量
增加院系项,也就说一个院系对应一个单独的截距,可以捕获院系之间的录取率差别。
$$
\begin{align*}
\text{admit}_i & \sim \operatorname{Binomial} (n_i, p_i) \\
\text{logit}(p_i) & = \alpha_{\text{gender}[i]} + \delta_{\text{dept}[i]} \\
\alpha_j & \sim \operatorname{Normal} (0, 1.5) \\
\delta_k & \sim \operatorname{Normal} (0, 1.5),
\end{align*}
$$
```{r}
stan_program_B <- '
data {
int n;
int n_dept;
int admit[n];
int applications[n];
int applicant_gender[n];
int dept[n];
}
parameters {
real a[2];
real b[n_dept];
}
transformed parameters {
vector[n] p;
for (i in 1:n) {
p[i] = inv_logit(a[applicant_gender[i]] + b[dept[i]]);
}
}
model {
a ~ normal(0, 1.5);
b ~ normal(0, 1.5);
for (i in 1:n) {
admit[i] ~ binomial(applications[i], p[i]);
}
}
'
stan_data <- UCBadmit %>%
compose_data()
fit02 <- stan(model_code = stan_program_B, data = stan_data)
```
```{r}
fit02
```
```{r}
inv_logit <- function(x) {
exp(x) / (1 + exp(x))
}
fit02 %>%
tidybayes::spread_draws(a[i]) %>%
pivot_wider(
names_from = i,
values_from = a,
names_prefix = "a_"
) %>%
mutate(
diff_a = a_1 - a_2,
diff_p = inv_logit(a_1) - inv_logit(a_2)
) %>%
pivot_longer(contains("diff")) %>%
group_by(name) %>%
tidybayes::mean_qi(value, .width = .89)
```
从第二个模型的结果看,男性没有优势,甚至不如女性。
- 从 log-odd 度量看,男性录取率低于女性录取率
- 从概率的角度看,男性的录取概率比女性低 2%
(增加了一个变量,剧情反转了。**辛普森佯谬**)
```{r, fig.width = 5, fig.asp = 0.618}
fit02 %>%
tidybayes::gather_draws(p[i]) %>%
tidybayes::mean_qi(.width = .89) %>%
ungroup() %>%
rename(Estimate = .value) %>%
bind_cols(UCBadmit) %>%
ggplot(aes(x = applicant_gender, y = ratio)) +
geom_point(aes(y = Estimate),
color = wes_palette("Moonrise2")[1],
shape = 1, size = 3
) +
geom_point(color = wes_palette("Moonrise2")[2]) +
geom_line(aes(group = dept),
color = wes_palette("Moonrise2")[2]) +
scale_y_continuous(limits = 0:1) +
facet_grid(. ~ dept) +
labs(x = NULL, y = 'Probability of admission')
```
从图看出,我们第二个模型能够捕捉到数据大部分特征,但还有改进空间。
## 作业
- 讲性别从模型中去除,然后看看模型结果
- 如果我们假定申请人性别影响**院系选择**和**录取率**,把院系当作中介,建立中介模型
```{r, fig.width = 5, fig.asp = 0.6}
library(ggdag)
dag_coords <-
tibble(name = c("G", "D", "A"),
x = c(1, 2, 3),
y = c(1, 2, 1))
dagify(D ~ G,
A ~ D + G,
coords = dag_coords) %>%
ggplot(aes(x = x, y = y, xend = xend, yend = yend)) +
geom_dag_text(color = wes_palette("Moonrise2")[4], family = "serif") +
geom_dag_edges(edge_color = wes_palette("Moonrise2")[4]) +
scale_x_continuous(NULL, breaks = NULL) +
scale_y_continuous(NULL, breaks = NULL)
```
```{r, echo = F, message = F, warning = F, results = "hide"}
ggplot2::theme_set(ggplot2::theme_grey())
pacman::p_unload(pacman::p_loaded(), character.only = TRUE)
```