|
192 | 192 | "# compute percentile interval of mean\n", |
193 | 193 | "A_seq = jnp.linspace(start=-3, stop=3.2, num=30)\n", |
194 | 194 | "post = m5_1.sample_posterior(random.PRNGKey(1), p5_1, sample_shape=(1000,))\n", |
| 195 | + "post.pop(\"mu\")\n", |
195 | 196 | "post_pred = Predictive(m5_1.model, post)(random.PRNGKey(2), A=A_seq)\n", |
196 | 197 | "mu = post_pred[\"mu\"]\n", |
197 | 198 | "mu_mean = jnp.mean(mu, 0)\n", |
|
551 | 552 | "outputs": [], |
552 | 553 | "source": [ |
553 | 554 | "post = m5_4.sample_posterior(random.PRNGKey(1), p5_4, sample_shape=(1000,))\n", |
| 555 | + "post.pop(\"mu\")\n", |
554 | 556 | "post_pred = Predictive(m5_4.model, post)(random.PRNGKey(2), A=d.A.values)\n", |
555 | 557 | "mu = post_pred[\"mu\"]\n", |
556 | 558 | "mu_mean = jnp.mean(mu, 0)\n", |
|
573 | 575 | "# call predictive without specifying new data\n", |
574 | 576 | "# so it uses original data\n", |
575 | 577 | "post = m5_3.sample_posterior(random.PRNGKey(1), p5_3, (int(1e4),))\n", |
| 578 | + "post.pop(\"mu\")\n", |
576 | 579 | "post_pred = Predictive(m5_3.model, post)(random.PRNGKey(2), M=d.M.values, A=d.A.values)\n", |
577 | 580 | "mu = post_pred[\"mu\"]\n", |
578 | 581 | "\n", |
|
1394 | 1397 | "source": [ |
1395 | 1398 | "xseq = jnp.linspace(start=dcc.N.min() - 0.15, stop=dcc.N.max() + 0.15, num=30)\n", |
1396 | 1399 | "post = m5_5.sample_posterior(random.PRNGKey(1), p5_5, sample_shape=(1000,))\n", |
| 1400 | + "post.pop(\"mu\")\n", |
1397 | 1401 | "post_pred = Predictive(m5_5.model, post)(random.PRNGKey(2), N=xseq)\n", |
1398 | 1402 | "mu = post_pred[\"mu\"]\n", |
1399 | 1403 | "mu_mean = jnp.mean(mu, 0)\n", |
|
1610 | 1614 | "source": [ |
1611 | 1615 | "xseq = jnp.linspace(start=dcc.N.min() - 0.15, stop=dcc.N.max() + 0.15, num=30)\n", |
1612 | 1616 | "post = m5_7.sample_posterior(random.PRNGKey(1), p5_7, sample_shape=(1000,))\n", |
| 1617 | + "post.pop(\"mu\")\n", |
1613 | 1618 | "post_pred = Predictive(m5_7.model, post)(random.PRNGKey(2), M=0, N=xseq)\n", |
1614 | 1619 | "mu = post_pred[\"mu\"]\n", |
1615 | 1620 | "mu_mean = jnp.mean(mu, 0)\n", |
|
0 commit comments