Skip to content

Commit 812c303

Browse files
committed
remove deterministic in predictive
1 parent ee25cb9 commit 812c303

2 files changed

Lines changed: 6 additions & 1 deletion

File tree

.github/workflows/main.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ jobs:
88
- name: Check out
99
uses: actions/checkout@v2
1010
- name: Build and Deploy Nikola
11-
working-directory: ./site
1211
uses: getnikola/nikola-action@v4
1312
with:
1413
dry_run: false
14+
working-directory: ./site

notebooks/05_the_many_variables_and_the_spurious_waffles.ipynb

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,7 @@
192192
"# compute percentile interval of mean\n",
193193
"A_seq = jnp.linspace(start=-3, stop=3.2, num=30)\n",
194194
"post = m5_1.sample_posterior(random.PRNGKey(1), p5_1, sample_shape=(1000,))\n",
195+
"post.pop(\"mu\")\n",
195196
"post_pred = Predictive(m5_1.model, post)(random.PRNGKey(2), A=A_seq)\n",
196197
"mu = post_pred[\"mu\"]\n",
197198
"mu_mean = jnp.mean(mu, 0)\n",
@@ -551,6 +552,7 @@
551552
"outputs": [],
552553
"source": [
553554
"post = m5_4.sample_posterior(random.PRNGKey(1), p5_4, sample_shape=(1000,))\n",
555+
"post.pop(\"mu\")\n",
554556
"post_pred = Predictive(m5_4.model, post)(random.PRNGKey(2), A=d.A.values)\n",
555557
"mu = post_pred[\"mu\"]\n",
556558
"mu_mean = jnp.mean(mu, 0)\n",
@@ -573,6 +575,7 @@
573575
"# call predictive without specifying new data\n",
574576
"# so it uses original data\n",
575577
"post = m5_3.sample_posterior(random.PRNGKey(1), p5_3, (int(1e4),))\n",
578+
"post.pop(\"mu\")\n",
576579
"post_pred = Predictive(m5_3.model, post)(random.PRNGKey(2), M=d.M.values, A=d.A.values)\n",
577580
"mu = post_pred[\"mu\"]\n",
578581
"\n",
@@ -1394,6 +1397,7 @@
13941397
"source": [
13951398
"xseq = jnp.linspace(start=dcc.N.min() - 0.15, stop=dcc.N.max() + 0.15, num=30)\n",
13961399
"post = m5_5.sample_posterior(random.PRNGKey(1), p5_5, sample_shape=(1000,))\n",
1400+
"post.pop(\"mu\")\n",
13971401
"post_pred = Predictive(m5_5.model, post)(random.PRNGKey(2), N=xseq)\n",
13981402
"mu = post_pred[\"mu\"]\n",
13991403
"mu_mean = jnp.mean(mu, 0)\n",
@@ -1610,6 +1614,7 @@
16101614
"source": [
16111615
"xseq = jnp.linspace(start=dcc.N.min() - 0.15, stop=dcc.N.max() + 0.15, num=30)\n",
16121616
"post = m5_7.sample_posterior(random.PRNGKey(1), p5_7, sample_shape=(1000,))\n",
1617+
"post.pop(\"mu\")\n",
16131618
"post_pred = Predictive(m5_7.model, post)(random.PRNGKey(2), M=0, N=xseq)\n",
16141619
"mu = post_pred[\"mu\"]\n",
16151620
"mu_mean = jnp.mean(mu, 0)\n",

0 commit comments

Comments
 (0)