From 327e3011c1aea6e8affdc5ad2f1c9816e832c6ab Mon Sep 17 00:00:00 2001 From: Alex Phillips Date: Thu, 17 Aug 2023 12:01:09 +0100 Subject: [PATCH 1/2] test: add tests for resampling --- test/generate_test_data.py | 13 +++ .../eight_schools_resampled_samples.npy | Bin 0 -> 32128 bytes test/test_resample.py | 108 ++++++++++++++++++ 3 files changed, 121 insertions(+) create mode 100644 test/test_models/eight_schools/eight_schools_resampled_samples.npy create mode 100644 test/test_resample.py diff --git a/test/generate_test_data.py b/test/generate_test_data.py index db4a535..59e1682 100644 --- a/test/generate_test_data.py +++ b/test/generate_test_data.py @@ -52,3 +52,16 @@ log_weights = new_logProbs - logProbs log_weights = log_weights - logsumexp(log_weights) np.save(os.path.join(model_path, "eight_schools_log_weights.npy"), log_weights) + +nsamples = samples.shape[0]*samples.shape[1] +tmp_samples = samples.reshape((nsamples, 1, samples.shape[2])) +tmp_log_weights = log_weights.reshape((nsamples)) + +rng = np.random.default_rng(seed=0) +resampled_iterations = rng.choice( + nsamples, + size=nsamples, + p=np.exp(tmp_log_weights)) + +resampled_samples = tmp_samples[resampled_iterations, :, :] +np.save(os.path.join(model_path, "eight_schools_resampled_samples.npy"), resampled_samples) diff --git a/test/test_models/eight_schools/eight_schools_resampled_samples.npy b/test/test_models/eight_schools/eight_schools_resampled_samples.npy new file mode 100644 index 0000000000000000000000000000000000000000..3d07bd5184e90d6586c66a681fe6b51a593ccfe1 GIT binary patch literal 32128 zcmb_l1#=Zi*N(g08~3Ym7a-U?xCeK4cU|1wA-KCc1iyp3JHahjump*}r@Qn0fv+aD zRa;~?A(`p!bIxPED79(51})9ahMA2D?b3B{=Ru)mHKFCZl?jc~gm&vcXvm7HE+F@qXH^Mn3q~s&d7TDIusidrj+)wI4_sGYh-^eHVt^8+X`m zDAHoznp_yF4vGRYxmp(}2(=qYQG`)P|I(_ts z!8KLF{O(ILrPqgERcqf-hy7>H)>|yWA?Ge~mKFwP+d_vE6;Xt^D_0$Mp|$qOslmcGbrX;;}ogP2US+L$P3Ne%S%3 zpG@Z~@9z0R6^Qar%XEr9Z7y3Cx4trWj=%hA%-i0#cKwn@zH_x%lmFKwHE+G&^fgbK z*La9wX|Wqt7|nb2dSWF9vBwP*hy$>m|&epy}5 zRf+f-QaQ;=wO^W;`Qn^$nj5~nDdi`RGU8I5(laW(ibm^uC*OLkj=|ZU^c^Z|5YT>9 zw^|xEY}?+o_n4pl=+feX|MIPBR82ZQ>2Zo0&%ZbD_}E#8Wxl1Z=G{1H8gu(#%ik&P zc$U*$YP888z1PmY_RtGuqDS0m{=*ISWhO@? zxAwqn`msC5Yq7HF+MT<7e6jlYua;{r1!3Hi3jejRQv-iS{=GEH6Q^c(I=1;ip|twZ zW8;qWzf!|)lj3!>yj$)B}; zh3*(|bNbyEaq+Uf*R({_XT6L~FHgHnu7qT#U{&KzR(R9GJN?e9a-i?#WDyyFm732L z&93;P*QU+ChMo6BsT<`|@4VLI*c_W5USWZ-nUGdx)~+C&|B`2Y;&m|Mzg%dLx>$p< zh^sNLz6K3FKDz(6Fh^=%egCt&z1-2{qUqLAH#PXaS?>!6_8=T|fP2St_jVdFIOamr zuE8GgpD?1WX}mXjk1g0-ZOj8x&K8~LMK343d)grWW_vp+ zZwu;$RrNxH%QrI~e|;~-FLuw0IHW0gPTnwoVPQWrxY)$kZA>^SO0(U9oK@IL%U*Xt zk6EP?<`q;5z``~~Rr0bD@H=|&@u+S-xH;#OeVvmoa^lu&ojSP(W9H7Qrz0okNyT4l z%)IOt1YLu|Ni}WWNY}%bv~HEUol3K2INM;<^y%}^yQ#6Q+l`EDV zS|F{PvG4V#w9BSS_s5tuI^iap^{15<8-j`-pX|yDFyPC!UK6*qR-uLY>5P*DUz%JG zSyXA39*ivdwf9B%;eO3J@3Sszp{3=q{YDVZHvIKDK<|b!|IT*V5myHM*i4JT$9mNw zX4mtUJM6TYb?%7|vUjE+E>kP3y3gy_zF@A&-}zPae6tV~PJF(vugbRM>mo6))Z7-c zZduB+Cbxe3I5^ETq2hwjTEjiSfBen1@-X%VSvK+{kMdL6Y=144d)Mu#`AZ-9`R9~R zV=U=;MDLz@zJj~Vj=OrrDzQ)fz0tnXiv2yxx+BV^``F~tI;3?NSSEX?8k^2Ot~`IR z4oCH?yiSJcaAToa^5I=B$k_Te-ZCu+$GiQyQEI6hLV_BUdVf`gjrwB+({jA9vvq~g zQ+plc%gb1SD3UHa(F(8c1wU3e&O{=#kL3II+*xVN$BB7MYU!ZeH>k&g=pr4b>yLUcp!$@?oXU@(-nOK_Is_h_+^^*;?3&{)4oWF=l93- z+gN0(+HhU{cQfqdrOm3m3992L_uA}s**N#1w63)0_P8VeO!O8l52UN*jK1D0z8v~( z;`{rn8B)`~x)tx+E|nfQ$W%q;j_8Wk@`w;`&{MRv3o(=P_t+JsFEod)U#ig(7o_hW zw!0k5^_KtI&;g$8iRNEpPlV3$n~ckHWh(13b{CQYTE zov*k3E`{t!w>dJ?7Kdjn$`5of6Z^8S9m=QPSNS^Lmq==NXVWJ1on2Zah{c zZa7qWS*)xoRsK|?VFo08*kgRRu?)Jt^gL61M2!Vmqow52k?5C1mG*8N&UM+*p~>AS zB+-}sJ8Qt)p#%M$VU(8~a2mMFQ9>g7=v__J&3#cQ{4?~`s(~T+Y?Y)cRS*i>dvT+y z*0lgTP8U2g;sCuJ^OJefQOVoc*ELI`syDlKuWYGd?-j#Wy-9>yivA9|6ibofHs>uotL`0?{H9PQZuxH164O4lEHxvvKTHOn?$idCcX*ShPo zSLjjhW&Jrz`x)_#9$&i^QMg)v$C$G1g0b-0xSFQIFxb#>x#6oXzI~BX^gE(4t50VB zc9jvoD&6Y#AT$Qaqte}liEdqPLFl}8XDzwg*8>ae*&O- zw{G%sFMpZW>0t(0oEtu9NoCD!O(Mdtxm`M5GXatAW4EvM_LAqYa(vqYR2({g%Zn2A zeQJuR$A^ANv}lHwcD2W}nFEVLn&@P0a893QyUU6?s&SH^Rks}B*d+n;3eL!b>j%U3 zNbP-P%KL&=N2Ad3k)m$bEen@_1#y)}{4;&^ny~-$gok1u_EcjQYY!VHfU5lPXsrQH z_N%5px}bq%w!Pk}XL_V`XufFcuCG#VJpIO1L*X_wc>Mh>R&w6hUMGh%j>n5(>yM?L zbH|h~4UYzOjX~oj+POpTgn&QxU%q(RXK?3{odUtC+82!hfmhy3o1|DI?0Q2TMi6L~ z(UJ@o>l7MvoilQJ_2zzXD@QBB#z*uMmuVl3ZsQ+OJs7tiH2WHxp~0g(-TG)-53sf? z@R$arT5b&5FwhTczEh!j6o=~Upxx$#zb*1^TOW4DqonU;s@aBum6tYwaY)SVw$lGz zARbL+4>l0P>DSN)YO%N%HC_c8?Araf{jmNRF=YhLsWwg+U^)uE3tXs}{ z=?^Jzp6|AkS`7TqeDHYcmvntbjNH8IucW1R>X5S@)Kwl>UEhEk?^gEhcikV|PWfD& z@lXd=Hl+|Fc;7ongKg}XpPmYBV0TQ9b8CkE>#fs@e%vJlydS^giKEv)9auJ9kJr`cMQC*h`(^mF;z}@F zPRA{e7!i&!tYMrHCHm(9I*ghWTX%~?41V@_JN@uc6`Iw|t#x>>24e%ubg%j;LG*p` zaVS18>eoIuJseK=%=>DVDP>;{b}B!-K&qyu&NkN`ON-}^&97Y+2kC!nuu~5Y>Eg4) zRx#+z2BVsOxcPE&WvgdSSXOU*$K34qQnuCE)cw6)NgXtH^VfHNAzjLuYz&&O73(YW zrPOKij~;J(XkanE^Q|@4?n#fk&;l)`0YA@l|0GpGe`C2)`BEY6u0tws^y^o*ev3Ef zOv%fpzSdiKB7aEvZtD&>$h^PZkYmbCy*@JD!v$7s5DN6g^wr@*+EwvVeEs#lUnS|d zrMH;pKIBXF{%s!nVzHeZLv3Zq9~(@ZAJxHPxEB29KT;rNIn#@o>5f;cRGUv}BSf8h z5so2j;QSVZR_b+Imh=t6l>JjQw$D+@H+>sEO~ip7Wy|>nxPbSM^D-oUKC})4o35MtYvuXu zSZi-7_U|^^l0Mx}w)T$f73cg@^4L@3^~fd;psD%Et{oaVeN~--O^^9uR_hZR54YGR z-ML1M_0#Xt``fyV+@c(*QP#bm%O==i^8Ku@k*=>yPc3)L^PA>Kyo`ONx}TI&e?a^{ z)A*XEF@Ym2P?yHT4x7!uzwD_7gCD7>v>5T0#>dw46H#|y(TXYqhI-&Sm8mgpW5L?5MVYZ^M&sS%mx1^g*S_8O)&98f;Vn%hbndVNqLB7? zBs;+aaDI9F8|H8PvE}xEbK2^BM17~eT0AdbbUhpUJ3GV|4L5glJv8x;^uhXZjjb`6 zra|XhjdmDsDPQ|LyK0ZgS*B+(NsoHIb19kEji}p)I>)pEiKgQzSDJ*&^WOamzFR*Q z9W2tneMr$`!R;l3!}n+v-*1zh>+fCs-W=403_EQ~#FpjXhBj)SVXALa7vBJI9?58do|aJ=KFLEir7(}?SGFbcN+yjw4!idg3hOM{i| zR<&zk_CMO#9krM`Z+iFwgBC{qB7O0;vDxhI4K!kX-txn>V;;@ZZfL;k`d(j9H4RB6 zTxQG4^O}8)4A}m{?{P$>9Eta(GpxbJnGQ+r81tX`pF2-cmpydO-?^yj-B5$NjjhU@n)O*aMP>Pl(HBSeT4v7v=z_pq#eBXays zRQ+8l5Vdj=X1*`wEq?z4A)qdn(3E5(nO0a@4YtBz&-`oO>_g!GptWV|u0FWipLTxF zc=+{`GpEt>jg{cA?Slim--Zc6K|M`Ntr*$)aIHAS1aux4h^vrn8X3tDAe$`yNqPzSrM7c8M)AD!8#yc>$BeeU==Mtq<0$ryfmtzT?l zh^Pa1sBbKEZ?o%i5JKkMDZAsi9&;RRD<41r))?9*=1%p;qOvpy?heOgTWZ{j1EHT> z!@q`85EgIh9rCh|HD*6-_`s^6Gji8;Klu5cKbCoS2wpwX27G?p8!jhN8+K*l52;0_ z&-_y^)ECRm2wL+&t}Cs~bAEVts#n*u{Z;a_<14q8J@19m43JwN3O*M<2mmX89u*V7 z;La(ZXf9*`^v-+}|NPugOc_nUx z%>xp_B#*9pLbc=m*9pViVApDbc4vMt+B?0kdcq+Hacuk1gwFSE7jN7(BVg>R5$g#C zURuopY`kyS!5&`;&g)pt3ybT~!S&1ynXJsEs?h#>(TKexRdN#s<2LY;y*ByI+H%kn z$#)z5SMKa7)4dn392fb>5d^a!0H3(({ z&l;-Zml*I8i5Z5Gz88KKNP&YApN%ipU??4*wcGil90P=3c;Xp@Nv(A7U(+u9T3j@q zx4Y}Jd!sA1H9Bb9$~+he9Po}rIBUd?#p5=Oy*{J#xI^W8LxcwV2`xJiWk6AfSI7Pi zutS^S*-fU=oZ_@-`^1Lf(RfcVdD1Osn0>daLvbqjB*2%va-PX)*Frc-e)s z!qMe=+d7v^$Kyn(tK;G7kznodlRz{*YoBjGKH-bc$nJS$cBWovosRJKIlb<2O5F|y10+BbqOd*Qi( z5FE=ST%od?FiwLfH~%(%TP+d)8o%AVvZ?`B3tg&BC~`&uwQ0eNJxcm~pOX1D9J*I- zJMLc(hQtVnlY}3tpAL#?szuSyc83!xX;2gyS+6$rV;mQbFORDU1T>{afa5m~ewfK% z#)YxsTu4nsAYZVNI6|xWV9Q9v(bQ2rtsLqQ8}ZOOMa*wb)c3N&vnLD#SYce}A?DtY z090^zv~fUSg4}uOHmcM=OcoPUMi#VlMVnpQ⋘-gY!D=DuQkMZD#_+93B^^cz?k0 z_-6@tLTz&p;WEQm<9;I^YzxP4qtH zk6i?()^H8R;rf?4+guIAAa?M?W=k*TQiHNk3x}K;IHF{Zq`rj4{~oX6OYpM5O zj<+j)-uEGxw)>*!lKD9tw2EmdhakMoytBVA8%lmZAM{Rd+hb2-Z-FOjA_V`qFdQ?g zl$sw?>W&mju%Fv!f5rJcrTGA*@lKj<4~$9(Lt@|JtF`sihcmpiO$*aiQ+t39fWN@nf)twbjImvb>@~c z{xFI2*&Z{?*E-&ZAsHVo-Y zld6BY?5{W;p){XIt2V67OD$fVrWarl3TMB((4&|2q955ECFbndFr53LiZ`nNNJc`R zO1CyI!8`t#ay+SV2=SAv{eCDnBYv?qjg9@61dI9@8VO=q9KH`KkT@Pv>8B~1j-~dQ zVd%~Y!{4R>m9-WwbiTVX;6B|T=G|@4=&;ZKNOWy~Y+hr%{-kTBsbT7&o%0Xa%l}Y&FeDilnsDO_IvBc&52J6qDJMTPVm{A zgOTodW2oP!0*TjQqe}2TV~9^7HNoU9R99I&t>lSj_dEDaz2$`AYE@CS1|Lkm7Y5bY z*)afKYycu$vgl#z;>qW~NVihi0Im@@+vFnC{YaX0e4@dQi5&)w%XSP+A}H_y-9h zx`m7Sb1f3o-Oe97Dij=lYL#t@&2zXrdrE+qKRsg6=d1cb>y5;(l5F6ZLz0-Uzgvp+ zI6MSXh98^${hbFkr_!K)Fb;2qGQ!ddYgV<=W!AJO_#a+X9z@L2EdMsRp4 zBnJFFgc=1NSjkUuzVQf^{cHdEpzZ6nF>hA-q5Ff(UXg#iaN!Cygv8^Y+|z!a|DPc6 zKH{uD&fTMhwILWx&M~UL8gq&MuQ@=4HI=FqTQxUVydEjtpK{*rjRvGjiL#$>A^5A< zQfPcT%Q)wbKm5BaTyRjM!9$`qPSl8ympv!o>Ue~hkNgtE^VyPsF?)`58uuv28LT6*Wg=)HiMpR55Nx^10uI;;LSF^rB2L0D|KM7 zOzd|v3>f~f--!kuSWIx)y*6QR43`M<`eEX{9PttxuiGr|V&#oBW!D{P9=B2AcyB@o z7z>?|Uh;G0MuPXnM?Ilsf}DC9G0$H2LC=dDK7Bp!gF6f!yh?wjgO#Cg-x4khb22J8P3 zK{YfUJ`Bhs9S`SqF1ibSNI^W&=*l$b+e0|d`KPmM1`H$Q6ag)vJq&pcnQ7WmWr z05G}DiCk+8V}MkaqrCiH*ZBP@mN-E0xA8jhJBwzonli=?F+~5xcXPt@V&0GSg-cA};5GXjEJDsE9?`l_Yd}wNQBCGP? zKtG{}n3OFgvW4aDCwQe|GjLtmSYLsIzth3X=fLp?*;evWqMO_^v|!)0j)gT+I}zhM z#GwSgr#Xx@*kiQFIzj~`&mAwyKhTzb>mfeBS_3wQ)*9nqvff;Tj$$a`a@BsA(!QF@ zTqpFx4?Dk}AlAea%e#<(V7nLSp@sfEp~j;AM~=lu>q>CTSTvl}lS+aa%3T`&_4l_N zNyi$x(Gi$(v-el^Z>nePyjWrYgX#9E;=$KxyqlP_hiFAz+!`lvjsH@x!{wN7zgH!^ za*kRi_HvQWyC^i-{={;(Mkn^8DNf+KZ=wZ1btMK|FD8XUX+A}1y@mXBP5G!CBbXej zPkL!=9kQ?JmevEDw>8kb8@tZ?TbsMm29oc5NKzr{(c$>t z0apfQ5-&vU?40Vv<1+qsq6OG?-PBQoi8RIi=@tcIFE2iOU_=GBf$6<7HsH)MeTiA6^j0k1|wRV9B9A4Y@%S%rLku0L@UgGiDe%tMX1_SnQ zusP^of=h&fP-{B+bwdASVWqe}jPqK_8iB9n1>-^5tzYF^sIlw9Rx=;c!|*=2b_n7o zv?k=mADj<59fG6pjy5jqxyuVJ34!|VuG1fgV7D!6J4}Si%)S&#|pqP|YKfJlV z_k-DMKd`6z@lP-aU80s&kLX*IP8?j~j*SZmSlkqZaFT^N&#)Bp|7I09e|gaYLzbRz zsY?E0;<#Te51H$^?iirq0K2)&^*^I>O$<)(9_WJdnY+|bYP{M3q50c2uviGymMcJ z)tyNhnQku!R8Ct_VV_>;+qxA<$(wgZe4cp3G~@H`3A@$jB~o9N-f~HU>5qOp4x{?H ztvMl010zAdbu*XUM%>%vF=Ujzh5Y>r5mB|=L2P@#7hM!490+ed;*p2IQ6EGJK0iGa zZx$cRt)7!5H6_?1Ro3DD{R4!A_(4T7nqj51Aa-Z3rmq+Fs7WbS^jpgHFnjBp6bM4) z20tR6=p%JYO9MSHjmlcW*y(>6U2oK~M2kHM@VtgpxN z30MF9^|cW?xplOk7(SNp*Tm<>3%)Y3mSwsdV)5tOyuUj1U2gZ?S8@J`%T?CuKxm$E=mvT|Oy^P0fa^;MFmJ0BeMVRSR!^yr^<$}4;C@53 z0(Tu;GH>S@aCYMy{lhnZC8hg&rRU$qQvZlX=KhGdn7gvUM-`e*d*)QSgDX0Rg+6U+ zXC`pJ8fF4NyW$QPqEEYCCZ2pxmrKC+lhomB$~$9cn!V}%I2{h&AmOc}r|2KQ(j2ycgt<>k=9E$~;P}CoU`#QR zAagV6h^airxB6$Q$_UnD)-u<9tA0wH7yfC3H#-9_q zddaWHG2wtM=9YJBxoO5ZQ{Opu7td^dWNMkTdbLwOJqSF_btRtRgYLBZ!53PrIh+xb zWjl1+qeJH=%eJ$iFpkUejZE=!Ee(zvLW337=n~Ldn#XbH3M~-a;9y* z#Py-$2v_qVCBe2h!Q#I4iei>^RiFEnScso``_~e zTZNMrX@sw#Z#?exs`;#QEnhf}Vv1K!S?o{xy&qoan*~3G9=+mhm;ChL>$%1Dk9ZybaLMktjkAUH)b7D*M0s*3bIUk2JZw7p5zDSGu z`-;n!&dQSJKRth?uwMZF=&N;UQ7upUw}UoplU%6`$>+}?SMp;543!PI9(9HDW@ndPZ$|5t>tmwO*|CyK?G`pN=TR^FVK2#s*PMz+_Ao}E`G9So@qgn)9k+7< zxt{1-6$qaCP%_x@(WA0a?y`6IxRXnrWaCha6?UWI#NQX;D~B<{m2@WDe{>^BX0p&Z z-IGN<*%1bYHjfRr7y9SC7`fBG>C4pH17yx`xmhxK{rMTZ}X#kxRJ5@ptu-OU0k^lNm|M@}UxJjBCSv1bx|Eve0)jhYD zcNBO8^#Q+`ZETo7HjE>Les$L4 zFsb;ve2yxabBe^=^VLSYPvV1^?63Wj7@^lc8KSuUPs0SVG|!$Pd)bl3Aux=K`|<3A zSqU#>5Ij?Kg3!J6(P7Q!`72+CxM50rn&=+r1g@MD4-&tccg=Q3NU30g;POo55?mlX zc9!?tBg4*|K$%={HHuz2$k_{vue?0<*~D?TmcJyf^X>UtdPVZwfV1}E@7?Mlmu*p; z85e8=rPs?6E9%f@1&jHpx)IBI<~MJa=LRw(?CRIlAB1|GRp0Ue8@InD1>x8Z0K0kO=s@;*-RO^bElyxJH~)c@3>aP2t@`N?ns zAcuP5U#Y?;^M?2f{K_U$W~xyq((?=7WunaFDm~tN$lPE0B3hj52|+TqhZ3Hyv>!q$ zJM{g*Hg2*bjjI-UA*jX(pyfunE)x<}vj_KiRQSI-8Y**qdx{avepGZWOw7^!|C%_z z=4DV^@2EI`uCU)g_>c>w8f1Slj3_1evW!rWSnjX+WgQlfygmJ!Bc`?@hsfgJQdP2_ z$jifI&XbUCi0d&Qkq?OZEpLP=uFsjv7_1*2V0H(e-w`k{qGTocs##y1t&@c=)e4_n zI(2slER?t|%%2`-EmBWmXCN=x`V!b0!QJyneTy3m%jJS8*Gau zq)yGa_QS-_-Is49OE#d+4TSTIQBh;Ixq{gPW7S%mUtDv;(mTFbbN5W1RKr>LvV*+9 z=#URVS}@yE{lgAm@VZRrB5PloXIX>UWI9m)?!Izf?>^&$_U)94ZkX18c!$T4w!a9uZiO&z{wMCyB~A{)4( z4e(_~fHMZ+iw*Y!xAARqMq`o#PUz@}8Fxv@nCOFN6E`0pzx0%e)Xu82=SbIi`e>8Q z(H~6Q@AoPI%x+sQI}lv&)X^OgjXL(xURDVo+C?486(qOZ*v=f>FFe2xD@j$I^Un_Y zjNGxk3S6=BN@-dWPUubJ-KTT864!^mju1F-Of+~bfU_5tHH=C>^wa|u45%y$6#n+{ z*0|bdey!>^NiW^J;lSdF#PNXBu|W zD6a1uKW511tJB`7B zUDe>ai$7%` z`c8lOmlFwz{&~op7nJ{+grC+$Zp4hqYHu(b&a|$ipJu+o&03Ip>gsy>GtV^`Zvc~% z8XuHGc;(N-{-oWLh}GY>wNNL&BRj2oWG$`yg~6`jgC(z~|uf0Q~EHpn8Q(8Yn$qa(z@!PsA`Hi|_)Dn|Y~4J*M0y z{_}>rBXt}h)x@`OpTk&laqi0gLT^HTqCt#k3^qgX(@PFx{ok2*t`PZNteGKycd_Zf zqZ87LFJ3b5+pWyy9@h<@t%Kj0y3X44{NvkqrXK$VJzJfo!QtF-4;JX(n|?eSb@W{s zN5%2`xqB%fgU-z@)Yfh~TWBIwrhPlv8~nVQ7B2X@rm>huaYe`HnF(BfsRqP4lJq%R zI-7 z7`rzdo-EL1vQcL6&_~kmGuyzt>3UFlR(8ABPpHdHdku+WZ#ynrh7SKYj!z5Q~F!s^WNr^R3^Q9|`m=W7TzS)>Ldvm0kTlb7!(!D_F_GsScb5a=ztWDiZJ}}N#tn(B8 z%9HUzr+LOtHtyf_Y|PwmChjv$i$f;Gllk_uD(PRtME*pc54evuS|fO-Bgw)ieLh6y zc>K>O5KFaU@mXJaeKrLsq^SizF-a}_lkMVUjx%kHgVK1|cS}-U%y7lJbEE$=Z>B{> z7ME5}h5yK2?H=F-MoasTjS+b8v?~&YB~A>uua@l@ura*o6*&JQZB zZ&JFRE6t~bkc{9)gbEJhe>(Xc4-&dCYlGk$b3-9=E`CW*hCTID&3=X ze;?Okgb#fm@wzPcB05>u6wjOExh}D`D93=d(?MWUXs>#aV6l<&eiC2J{8A@EOZ+T0;Qi;f9257= z*NqlAgoUN0=Q4DF6p6}5p9GPTwZuR5J597 z#_0N&__8L2;myLgsh=oklv4IK=j!^Xq!rKvF&2fb~ zT9MOHm2%WrW6@rRYD~CR%R}hN+v-8CkqRAsH5kT%j!x-Bj_ig4$+-_DYfjgq2NxcY z4mP{+{)p`%V05u{qte1x5FCa0cM-3;7qy1o2EPM^a&C=ic@N=ETTZ$kHp6gEpxo~@GBD6w^APD z{M5F&QLZ-VJ#kyUSr-dbW5FrMT%h!N=002n^^;s5s^EF5v>&JLg^MTWmU$}i+_iR; z3wmoTy%BmZZ2eunVd~a<%!1oq&Vt`_vA~ssi zM}_kamChHPPVSNY2eVBqCh|_9pS+J^sls-LqAnYV-h5s<|9Kk2E@ z5)HE>GMzN_seLnoc)ydH6!iL3rbT6yJU0lAqv> z8~7le;swX#*dpxv+a7bHUYR(KyV?_6UrPQdW@ohu_7b>nvJsTNdt-p58|su_*XQa3 zN5%PbrR$N>{hHEwT1ubyO7p89*S_jhEz2MLeu>GHlN@gSydCA(>dr;Hxctmh{+vcZ z)O{%5mgPR2wUC+Jq*Jl0n3w-W;j=xhlwI~R&&wn|*5;=abndPd{pQiKU^b7`KEBw} zkp#;Nt&kGAAtZ0~W7ED5W@iH0dI%lIlPIjFxY9F!6LHDC(yD3hA5A>ppo>v)J(AM; zd!Cb3BM7FShr2K7qQ#=-6yR2<67vJ)DDpfzZLpYgvO@5s*{xdN%LU>Dvzcdl3BSo& zEk@C;4i9F!DxQbT^QCraz{0BJbQtdKv~qoRW;OgNj=7)x%!j~MPh84wky&mHwa<>%3 z|13Y>*p?gd$|Q2Vy@fBD^3nJ`D%O7_k*n$|^LuFe1tExC3}f;Xxw|VuU{;o#nzBmf z`E^Er#rbe7sFDnd>n&JZ@Ro_5C45|XNq!pnnAv?Th!s3=XcRWu(8U6z!{&TRhL`X! znkRwDbi>+3!;F1&2vwX9;5k7nt;9VVRf2G;ECoj-yMg;!>V=_*;`y7!t8i~E1@-3! zU@66rM2zyqWs7I ztc@0V&#&#kw#Pfl;pBSguN6>(2W!}r#VmBvGg4Y)ha9y~{3Mud2-#*~sA zkZ6(XOZAprW0UnPROa_{IaNkwx~?O0MHP{c@GDvIyfoI(Kcpi0BoGk7WaK?>|QFFFiACk`1Amw~z^1-wCQ3dt-!=)UkYc-o~>jQRs*pMtwnb(2!Q6i@{$x-CQ z|L~DHpSPLv)R^wKX*^iW?zQ@H;Qk59+hDP6)$Xamby^R+v7Ta_2Orjn``u`N@xGFB zm^eSA3NPU?;zSP3RUJ4T`OsPM`va8bw|K5cf33K$z&QadrtHb{AmOuF8VM(gy&iA> z$i(wke_P62k9Ro%EWX-;?$hADzz(v+{ZJz#1>Zc;O5FQaRflJF{5B3<>?(38v*eo+Ty=UanNM@XbvyD6V(pzBmQ_JxcR=O7kC_ zk2(~E(X7C)2ok&z-EYEi8Jj4<_Z#BCVvVx4>rjwC+tF67FY$9z9b>+3${*u>Z^bah z{R>Lxl`8E|-m``X(Tc8^MAzyy97cJ$L@!?K?JsaWd#$>} zUl)#JWZ%smZjKI|xYmQ;5A>!`qV!7dNV?B(Wrf=gPPH_U)4o2o>!OCZuTmv)XO6pr z>rVda#d;d+F7ur6vGHKHSlqATB+j{S@uI)#6D{ft<)L(B*9$%gR6PG)>3k%m{dp{= zy~0cbHkP9smPr5oiDH?xZ4Kc0JoZ`)>tpq1Xax=0vB1g*Emlx${^nZ2LJu(}1QS9i zF_P}Crt)0F0w77q3hovm;KP}NI958~{_jBYYRrtOHpSJ7<=+7p) zt>IGmNLzWua{BQ7Vd}yHxwg8>eaJRZb)z{_58^}>hi@i+A4O|Balh$CTMQxiBtOPO z&gwE?@yW8qlG67lE4`k@{nk#RpIV$O?u%}x!lfQ`Vb3~k$$4rJe$mrD;(ku@fssr4 l&f7=6VAq3un4}W-DEiPm%&z5DFwaA2{jSpYohY40@jrJ Date: Thu, 17 Aug 2023 12:01:34 +0100 Subject: [PATCH 2/2] feat: add resample method Add resample method to resample from log_weights --- retrospectr/__init__.py | 2 +- retrospectr/resampling.py | 26 ++++++++++++++++++++++++++ 2 files changed, 27 insertions(+), 1 deletion(-) create mode 100644 retrospectr/resampling.py diff --git a/retrospectr/__init__.py b/retrospectr/__init__.py index 863d295..b53fce8 100644 --- a/retrospectr/__init__.py +++ b/retrospectr/__init__.py @@ -1 +1 @@ -all = ["importance_weights"] +all = ["importance_weights", "resampling"] diff --git a/retrospectr/resampling.py b/retrospectr/resampling.py new file mode 100644 index 0000000..917948d --- /dev/null +++ b/retrospectr/resampling.py @@ -0,0 +1,26 @@ +import numpy as np +import cmdstanpy +from retrospectr.importance_weights import extract_samples + + +def resample(samples, log_weights, seed=0): + + if isinstance(samples, cmdstanpy.CmdStanMCMC): + samples = extract_samples(samples) + + rng = np.random.default_rng(seed=seed) + niters = log_weights.shape[0] + nchains = log_weights.shape[1] + nparams = samples.shape[2] + + nsamples = niters*nchains + flat_log_weights = log_weights.reshape((nsamples)) + + resampled_iterations = rng.choice( + nsamples, + size=nsamples, + p=np.exp(flat_log_weights)) + + flat_samples = samples.reshape(nsamples, 1, nparams) + resampled_samples = flat_samples[resampled_iterations, :] + return resampled_samples