From 53dd49a47072b38102bda006b20f00c12505463c Mon Sep 17 00:00:00 2001 From: Jeroen Van Goey Date: Tue, 27 Feb 2024 22:05:54 +0200 Subject: [PATCH] Add integration test and add documentation site (#27) * docs: update conda command * test: add integration test * test: add requirements-dev.txt * ci: remove default_language_version from .pre-commit-config.yaml * ci: lint pytest-multiversion.yml * ci: update pyopenms version (to install in python 3.10+) * ci: allow pyopenms to be installed in python 3.8+ * ci: remove deepspeed from requirements.txt * ci: use environment markers for pyopenms and scikit-learn See: https://peps.python.org/pep-0508/#environment-markers * ci: remove torchaudio and torchvision * docs: Publish docs via GitHub Pages * ci: install ninja-build * ci: Setup Ninja for Python 3.11 * ci: move install of ninja-build before installing requirements * ci: use newer version of Levenstein on Python 3.11 * ci: update pyarrow version due to security vurnability See: https://github.com/instadeepai/InstaNovo/security/dependabot/2 * ci: upgeade jiwer version and remove specific levensthein version for python 3.11 * ci: installing ninja-build no longer needed * ci: update actions/setup-python to v5 * docs: upgrade pip before building docs * ci: upgrade pytorch to be able to run tests on python 3.11 * docs: switch to latest mkdocs-material docker image * ci: upgrade actions/checkout to v4 * docs: upgrade pip before building docs (via forked GitHub Action) * ci: pin version of libCST * docs: split out requirements for docs * docs: add mkdocs.yml * docs: change gitlab to github url * docs: add assets and index page * ci: replace gitlab with github * ci: cache pip dependencies * ci: add cache-dependency-path * feat: add environment.yml file --- .coveragerc | 23 + .github/workflows/docs.yml | 21 + .github/workflows/pytest-multiversion.yml | 42 +- .github/workflows/python-publish.yml | 3 +- .gitignore | 1 - .pre-commit-config.yaml | 4 +- README.md | 2 +- docs/LICENSE.md | 1 + docs/assets/instadeep-logo.png | Bin 0 -> 15059 bytes docs/gen_ref_nav.py | 47 ++ docs/index.md | 1 + environment.yml | 12 + instanovo/transformer/predict.py | 6 +- instanovo/transformer/train.py | 2 +- mkdocs.yml | 87 ++++ .../getting_started_with_instanovo.ipynb | 22 +- requirements-dev.txt | 13 + requirements-docs.txt | 10 + requirements.txt | 15 +- setup.cfg | 2 +- tests/__init__.py | 0 tests/conftest.py | 62 +++ tests/integration/__init__.py | 0 tests/integration/model_test.py | 488 ++++++++++++++++++ 24 files changed, 814 insertions(+), 50 deletions(-) create mode 100644 .coveragerc create mode 100644 .github/workflows/docs.yml create mode 100644 docs/LICENSE.md create mode 100644 docs/assets/instadeep-logo.png create mode 100644 docs/gen_ref_nav.py create mode 100644 docs/index.md create mode 100644 environment.yml create mode 100644 mkdocs.yml create mode 100644 requirements-dev.txt create mode 100644 requirements-docs.txt create mode 100644 tests/__init__.py create mode 100644 tests/conftest.py create mode 100644 tests/integration/__init__.py create mode 100644 tests/integration/model_test.py diff --git a/.coveragerc b/.coveragerc new file mode 100644 index 0000000..d113ede --- /dev/null +++ b/.coveragerc @@ -0,0 +1,23 @@ +[html] +directory = coverage + +[run] +source = instanovo +omit = + */__init__.py + *_test.py + +[report] +omit = + __init__.py + *_test.py + + +exclude_lines = + # Don't complain if tests don't hit defensive assertion code: + raise AssertionError + raise NotImplementedError + + # Don't complain if non-runnable code isn't run: + if 0: + if __name__ == .__main__.: diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml new file mode 100644 index 0000000..c88b26f --- /dev/null +++ b/.github/workflows/docs.yml @@ -0,0 +1,21 @@ +name: Publish docs via GitHub Pages +on: + push: + branches: + - main + - tests + +jobs: + build: + name: Deploy docs + runs-on: ubuntu-latest + steps: + - name: Checkout main + uses: actions/checkout@v4 + - name: Deploy docs + uses: BioGeek/mkdocs-deploy-gh-pages@master + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + CONFIG_FILE: mkdocs.yml + EXTRA_PACKAGES: build-base + REQUIREMENTS: requirements-docs.txt diff --git a/.github/workflows/pytest-multiversion.yml b/.github/workflows/pytest-multiversion.yml index 2af76fd..a9d24ef 100644 --- a/.github/workflows/pytest-multiversion.yml +++ b/.github/workflows/pytest-multiversion.yml @@ -1,17 +1,16 @@ # This workflow will install Python dependencies, run tests and lint with a variety of Python versions # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python -name: Python package +name: Test on multiple Python versions on: push: - branches: [ "main" ] + branches: ["main"] pull_request: - branches: [ "main" ] + branches: ["main"] jobs: build: - runs-on: ubuntu-latest strategy: fail-fast: false @@ -19,20 +18,21 @@ jobs: python-version: ["3.8", "3.9", "3.10", "3.11"] steps: - - uses: actions/checkout@v3 - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v3 - with: - python-version: ${{ matrix.python-version }} - - name: Install dependencies - run: | - python -m pip install --upgrade pip - python -m pip install -r requirements.txt - python -m pip install -r requirements-dev.txt - - name: Lint with pre-commit - run: | - pre-commit run --all-files -c .pre-commit-config.yaml - - name: Test with pytest - run: | - pytest -v --alluredir=allure_results --cov-report=html --cov --cov-config=.coveragerc --random-order - + - uses: actions/checkout@v4 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + cache: "pip" # caching pip dependencies + cache-dependency-path: "**/requirements*.txt" + - name: Install dependencies + run: | + python -m pip install --upgrade pip + python -m pip install -r requirements.txt + python -m pip install -r requirements-dev.txt + - name: Lint with pre-commit + run: | + pre-commit run --all-files -c .pre-commit-config.yaml + - name: Test with pytest + run: | + pytest -v --alluredir=allure_results --cov-report=html --cov --cov-config=.coveragerc --random-order diff --git a/.github/workflows/python-publish.yml b/.github/workflows/python-publish.yml index 77d2618..6944bbe 100644 --- a/.github/workflows/python-publish.yml +++ b/.github/workflows/python-publish.yml @@ -21,11 +21,12 @@ jobs: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Set up Python uses: actions/setup-python@v4 with: python-version: "3.x" + cache: "pip" # caching pip dependencies - name: Install dependencies run: | python -m pip install --upgrade pip diff --git a/.gitignore b/.gitignore index 9df4bf2..69d85af 100644 --- a/.gitignore +++ b/.gitignore @@ -165,7 +165,6 @@ docs/reference # Other folders checkpoints/ data/ -docs/ docs_public/ logs/ mlruns/ diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index e05e808..c1ab354 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,6 +1,3 @@ -default_language_version: - python: python3.8 - default_stages: [commit] repos: @@ -58,6 +55,7 @@ repos: rev: v1.6.0 hooks: - id: mypy + additional_dependencies: ["types-requests"] - repo: https://github.com/compilerla/conventional-pre-commit rev: v2.4.0 diff --git a/README.md b/README.md index f56563f..4bb0d11 100644 --- a/README.md +++ b/README.md @@ -38,7 +38,7 @@ if you have installed: ```bash -conda create -n instanovo python=3.8 +conda env create -f environment.yml conda activate instanovo ``` diff --git a/docs/LICENSE.md b/docs/LICENSE.md new file mode 100644 index 0000000..af216c6 --- /dev/null +++ b/docs/LICENSE.md @@ -0,0 +1 @@ +--8<-- "LICENSE.md" diff --git a/docs/assets/instadeep-logo.png b/docs/assets/instadeep-logo.png new file mode 100644 index 0000000000000000000000000000000000000000..05c5dff410e8cba9fcd6e003bbbce8e123dd434b GIT binary patch literal 15059 zcmYLwWk6I<)b|1cDk1^`O6npZsdR&McStLpOE)Ya4I`b^{t2o&(}-6keT?0+7RJQ?V}C;2`I^vxCH?g8cFf&TAdGz5Xr zkK=sKHEc<#3$nEn{)lXR4WnvXsumXM=R;%yF35RDIPCxqshEcUYX60>^W6Cow8H}m z!21rJnVP-vkWg1lII<*^dmEeLk2(v;@)0V4=^LCianq(>M<8(lak7?AlP7rbt;WUz zWZMuo=~perqS+&;z2xJflxTg_@Oq4bw9Qw;2cX6M5hWFmH00Qs7Rh#Qq20Dr@*aY3 zA))fx>qH_#rAiP3d)EXy}cr__yWnVy!|MAx8z7J+MIj zQAJ^l)$@$&Bvamci)1ViJ@UuoawonhiOhmftCriLit;>shIAm}B<>91~jifI%mU<6-bz<8!sq7zcSlYiAB+E_it>h-~IMkR6z9|L%s z{rW8bL`e#B8!=nr$cX$KllC;eQ^tUz*2^ruW#-W97Fuj-|uXiJyxl>XA?{d2wT3cNfJt^n}43gQ-YCPyCJu-SRX9@}rUt*feFBD{j zG>f_(YKL@%Xcmg%2@E)B3oqhGQ0X|O6N27s8tGBd`ak4?%VPzeurjQCHc7??fJv?K zmf^#pN$CTA|CjgxHr3wZo$i;@uP8G-K>vCfWkFC|;0dso>Qt@5)r%fNamoSEjf#9P zRM(}a+)ee42oOBSg!ErNDZAe@8T79<)2=KN5~@t8yI10oIu*9uFW3F$Faelr${uXB z4|F-eQeOeq4q4B)RZbA#{NGpd+ZMz;o$>cTs<0x^TO!W`@f(9#o%-w6fsVPyukM2q z*al-GUQK2?-q3t-9QMweiZNJT-viA7`z79NLG)Iyhrb_Hu>Ma!UW<(=lSyl$or?2MYI!F zTR4G1y9{ezUJ!y3ASoQN>PZ}S+NEhFt#eRyBjr9himg~&kUZ~P0pob(`(DR%U)Gn> zB~RpnUzn-cR3+40?9D4iihUh&&sRE7xZZ**|XNiO~87 z;EUSA^asw}bB+6SSXU1>wzBKeVXJazjA{pP6YyT-7w?%xpn7x0j}d_BZi|566zoQ0 zw5q!LgHKk^&JlnKdWOX^g{%zThP~ka5;P`GxEkaMh(&nKSaHp zq2y>22_G(X&su~@-h}0chHJTu{iavx>$AV=UBf{;^W-NkZMx7>jiB!H7ySY zmDxk%JU|N8p=H0u9#x$a9eR@;Ook&jdJ8q1OEchmm086>+rD}9_YB*kaW}IhO|zOY z)f9^f#b+~&GD&sWNw?bn2O~;fzUm&qurI!Tnc;RHbk4P(6J{_P*RYMVRaPO83D+`m zfHn0WKlGp9`mNWoZv|(9sOGhC-dp{^! z2PKm(n$23x9QevFa{>o?Mvn>=ye*zOp|PS4(Mw0!j-=HRgv?FQx1xRoqjnxx^1UQf z>)Hu?91Y={g2>N{p%IW2doR_ZI7*b=3? ziN~~{hv788_eNNB`}DzfVzh)hQa59l+>De*Sl^upaOT`2X z)K;s9;tJP=rJey{aTpg%ZjhODeAEILskAQntHky6tw#E$5$SnO8n2#G_TC%T+uT!) z1}<6x&;qqFt=HeZ^DA68I;KBUYRnhad1i)7L;6&QTefu9w1v{qOdM)j-o;aL44eh) zE~I&6EhZ%kj|?skB5{c!tzYe}IbWXD1;{%vhqbg1yZ8XBo4O1K{~`WsVaxOVz#{{K ztKr>?$$kF9Nz)twf1QIp+{n;EQkQNFMr8P5JQ+{Ddrl^zw0GBp)a?*M6+u)JBJ-Gr zI`hhtLS;{;W!Il~teBiQB@b+PEZO?;Jxl_VDV`A-#g5!>%_2>o8cKd_#g)|lSZq@| zc#YlO;2Zd?TleuG{T9rGGg9G9I73iiJiat`|Jm6)vqu)+`v9ib zeEcy3C&x%DQLruF6*0V*$cihuD=**Rt8msq4zK8Dp7zp^k(Unb`*{CfOnBqhlIaws zTa^7$R<;$PoUPLlcbNWXs~s&X?`n9) zriN78RG74*bRZa7@a4|cTq{rTo6F1z`VsTXKY->pN|97HHW+j{P zHBuOECN@{e&w`-{OH3;Lf_wIEh9u|EH_C}WyMZ$KZ%<)QnVf>&4;#smlZx)30$63- zZ+K#ZtW+k7TCnLt>#$V+j9vvf6DSGHHqO*#2IX5zfj{cQmF(ZLS)TF67c;C1f~}Xg zmLeuodQ8FIMb7!Q8lF*_|8DHmWvF`OvszDfgZ)>=UJwE~^Y2sVpvZoQo5Qs>3sas# z?*nda(S+kV$AO!gYN5%CMR%VG$tKuo+>Xo0#g^W%@hg=f|Ib_?`N`~$sa*H( znK|`2rMu`gCyb6oRaR|%&nwi~$xv%2@f%zj5y9f;rLEHJ5HVY2*FE@Mw2>WONVa1o zXstjL0ZOParYkbt*8kk6JG-1u@V6J`Fx2g+6D{j`tr|P#s%1nvzL-CwN_?i9bfunL zi6jqyY-cB$yQn2M^cjp6e*4#t)O$IHdY^}S6v^S6dhGNATJ6N6twJU`sokFy4$9|A z+T#kM7>6FU9&1#V=-!@2I|!r*vHHJ3GtVXnr7GngWGFY97wU_OGIo|>pf=(M9|THN z82*i9spVW4ULJ223`(rvV(?Cm$c@)9Vx3A6@*5p^g{caI3O^HI8Y#5k=Kr(2nz7#O z3Dbh;E#CnA_xbe<+-t%uv@|b5dv^{Z*}XH&{cV@6#xto3$gVR1!!;~hycrf-^?pZ6 z^2De1gcbF;mb}3_oXpsyeJE|u=&jIq)1opKBgsk67CGioqBH*^(Pl%Yg@~+N9$0%R zwsCEC)|mzokOUr1fl3YKRB&xdw(^qL9oXz{4slc}gba@yN6VQ*8EPLz&V2;)XF;O3 zZv-8Z<*%rVNb(2Msf<~>A;N5hq<2=pf-G8nFzs{V&ocCaS$HmoF^4hGV@rI0N_etD z`&avsik@jnQ<;^y#jYt^<)L-&=T8e?v4N%NL{(Hd<@Q4L9ye}!%)-iaVxjE_M8Dd^ zQ(R)wO)#>%?C|DF;vK~=8|};GXr9;wx*^8vM$~b6c{f&(tKXA(z0cFOzKg$_{1yX!vYuydisPL5 zJf}jOD=kd9t>i+b+BKSX@jd{6Nc~cmNS^GsHXfp#HK_K|gOB+ckL6g2iSU*-VpA;a zxxP&7eL(h$fcH|AnPbk^YyP358lg!9D%XD`=gqO(<+TX?-g-%Ygoar+RdoQ%0Vvx7 zTz|Y}O{22H@w9CcB4-c$w+|obqWCEr0=;A%|)wX(H z+sIzq*PIWv8`89IhVXI4HbeV9J+c;l=^-KRt*mJ7U;eV8OfV+w(HUw$Pm@3siFA_w zb^8?u(8%TCTlj7y42Fy%P+?RlrEa#*O!;iy*_Mf|xkE-}X@_EQ(|h;O3xGWhK`^h8 zno?w1w5~p!1JS!mK>wK|Rzl1va>E*d zy#!mmr-O}FgR+G=G`g>0s8L5#C|#FYGJYmF8754H;jr!sM?o zIXULOCe}YyqpqNdU9}z8vQzKU#ZHeaB%4DZ~qM~ zf&vH}$g3D~?%6p#OKo##NM9ITzp4(kWjWYSr?xbQ(()s~xleDXzgX%8Fe?^&*F8rR z+nYlJ`$8G6JGuKl6#F{8u+}#$-BHTa`lmkvFz%-0;Cj|1GavEYP?2&1KaYhdera=13; z=vl51M%VmAecL|g<-B(*KXVyV&g}*t{0#zCs=s{yUIOGiTe<5Tn&X)9AR|nhH%^LkcviShZJ_sk-DXQI^*W1Zy0_8Tbp;rx!ktXk_-Iz;Zg(*j)FKZ94z== zkcMQx%HPhaEAKn5h6U;IOF~uH{e>Xi1R+x8N-uLDhnh=JOp*<{)q&5cWhM%g zI@%wuuujI`XEY36`oWB9J21B-h(-(_10CZ!d9!Dv2FZyA9JFt&$UrF_jI zP7+-(y8Lh*T%Bvqn<`7ay#cT33@AbQdknH&v#i#&X&wT5^?yYMZwsXj;i3k{wja3n-(9(2N<83du8hzA$hT*wU-)n4cxqAt0Of3O?!fk#ON%(T-E&9tmZ-z0o z$C6ilETc;X*E22beLiskrH;7Dg1YeFf%maVse(GtaW_F1*#(T zL7E-H>+bFU1Rq-vSkJMUIU~ZUEuv98Cy^z(3$~dhxB`Q0l+s|JBvV}Lx!E4>J2qea zJ$l`Y%-1jT)Q|x$XP=>fL4zk>Wk&T-k-+BtaDAx&lSRCl;4w5wJ4Y(5X|6!x#PA z^+9?G#0H;>tp^cCew4!RG0|*Tt*dOVF==wrJX20nSu^|DDKk}<7wt)%eV<0b>Lfh@ zP{KLK<}O_;+%*);>5-WZqM(C;ML83RPC>@Xt@fC2)qDg7=Cj_B19Dca-{k0AIKh+g zzxY>;wU(KJQ+Kdi=ZQV6tlMnUc(T3{RMWi@Px}Gd*hb7^EL5#W~_vkO_ ze1h1|Zntx#bmmUSM&P(xm;c(*0Gd2j)nfjq)wM(yFVRTiC4oUivNe&PdT!lYgn7c!ToB~F$h3akI=SKJ8@WS|C9r}p0 zbZ3v{eN4b{286O#mk0J%5kO_kYewasXsxT`wL|}0x?wDS-ctKAi5dw7k(&<^%VY5M z#+@&W&Y!P6HzTlaX7F+J)zC2y4kwKdXJ4S^*33xFT!Y-c@ET0o?!ii6TRa&Pd;=9% z+hov5-O$fln%t`sG|B6g@fe&wwQK|B`xtShnRi^EcRCruMZ+o*V2?D$g}5`zBOJ|@ zD#3NT+v~x}hZ+E>J$|=}akJ-6-;?kohx+FQcFBpKckN2uL9Uefpa~Z=swJ=y|kKLonQ=kvUSgiLHoC4Hl6>Z zy%cbMPx4mXJ{f-zasCfwN?H!~-w^W9S6ffjPmcqF`KRf05V@$!3UiZ@q;`C$aiLLD zXAUOK9Jl4H?E_916CmlZEq-^)CEx9%e!9?WKw(uPBij+Cv0vrmlYfS4eoaj zFfiZ+SJ}UtS4TlQ5#8H>K@RI|zD)Od3`}Q@ezbrJqcOSv_zAu7SMiJ90a+G6hXv*h zc|pR93`t-8)(3BgheI;gA^vJ3qx|llMml+AaFX!?gsM$suMf)t58Mof5vfHAV5m?r zw@kk@((URG`(8#v3+^53B;g;!LqYlGPV}%uu4{pl-o!6G*YGua<5&(>1kc@&+Us9= z_S0B@$y2*c>DA@IBIY9kFundP_!`7N?G11eFt5_$&#&pcqDQPNPu^-j&q(N1GTOZB zvzK8wl_UgAbm7duLo`Y3{tJJ*1Uvc`$yH{e0|xG-lgI11D5liJ_GNc=swc@SX($?v z*Dnbdlv}tp-Sm>ZW+xZ;U4UwV_$HXBaZpXXC?g5~GfmBH)CYSV*SFT^dZS{A8-|(2 z3#XK7sxXjHf}p^`zc>u1%ugnUhQsV~9fcZ|sksF)?x_u?3x+NYZeW0GKv>e_@Z0sc z=Hfbo*oZl!6-)QJlGfRalimHoenWpcFMv{WTQI6jG&M&p5bVX_p*66TIw$t8+79GA;6Unz2{?hA@Jh}W5BQU_$7$gbI zO@SN@E4J$QRi?4XQS}|Ht#VQm@u8J{XY|nz%ivchR4y9yJ4N!~1U67i8Tpi%!$d%1 zc1m{;xHoYF%4$TY6>Fqr+j9e7w^yMhly5lP$*3?IGGp`!I!er80z6H3b?;5%n8C7ppngLiyJtbLw=d-2LjcfKdr)@c zRVtD94sN$?~q$!e`k&yc}#mFIO`G8-u8Oyk9^kT3Pe!6y4aY&s%)U z6?V5P<`>LwzR1W0jf8Z0I=1?+TD(fKIEmS%BC@NKUH?B8d2=+6YPNWZk zxl3fd?pl#d^hd4#+!39*dvuXSeCR$<{xU7;+=2Pr46Xx$cXqaZ*F1ktt!#U;I3=NQ z$-VO;M_39S;zz|z{9mSM4as({bNa4#Gv$Iuuoq*?ZmRH-R(2}P!J|zx$n}03Y{`Y2 z%}<%SuU$M0ipgsk!RHe*~N9mlS>ZWI?5sszhqj(r|tMnqUO2o>^zW)XDx? zOv|vApR}aJtL5$GXws_NNw;5vF|dACJ83}0!H|Pp@eD*6wCdKWc4{kagMQ@S)!Pn; zb5730^>K8L*K1;TO}Z!YWTI@c_Zputhjs2k5&k>&6;tgw>f!(ryrnb;_R8I(sW)KX ztX@T+gD%bcikIAH1r9j zXrmIw&N4H*d$LQ_-Oig|I<71IZJC9}`^r++&osS>UPAJiG2%QuZQKb{S^JDPa$Y$# zRU=3e>_39-_`vKeCbRlzJI1Cd$KflEX4%o|xUKK7{nB8UY|zL_pYl+MW|1ne$&Fgr zpZy_|YWtf05sG6CdIz)`(vBC6$#qF|DcG_TT&+0XtG56yrrXM;nM-$ZGW$*OI$cYjVt9Y z*4}fx`ToHvdnz>NQbA*thi9RtPt(qMC*Vzn?2ysyQA1EBzB3(IrtZxCgK9NI z7wymy7Rg?vkSF`kn5qps0LkH>m_>G@R;nMlk)_FylMmq-i=Gm6(&giDm++c{0}%EB zxbyoISDtcma(=`}aq=gz6m(DCkDGoRzD{@`SfUOqI#y(ne0ZOM8Hetv`HA({#k6Q1 zHR`?Pvtk`dtvs0l+e|EB(misi&75jM@q71WQvFPg`3!c9OiOh0Fu03LOxRD~zfbTV z`ea`PL|{VAc#tUpYb)Yi*VpZNQ>HN4drYJZ%9_w2DXB_XZr)9oL@Dq}{tN9T4P5MY@{yY@|j3t?K zcExi}K`POCJWD_K>g0bvxQ5dif-6)i}2#f3bpYwZ&oJnPk)ksdSSx0W#J zu9ebePFvB`JM{p9)8oEwP$80&*rx&U@)PodZMY6_EFu>B@Fqicm`}na~ZYJ@i<*tG`sY;L1D^I>=b$UZkqH=Hl z^{JZjm=GacFW6_vv7lnSQXFE{{b}#Sh#)z+KM(I)wNe9TPKMYi;vpGdVnj+rv-gM9 zdc7shMnxEiY1kEB=aTTZpQwsi7E9D51~y_kFYGoM@;9o`6~jOsy?(85E9rNX^p3?D zRFnW8u6rRNUf_&|Pb1MU=wUHWpC}9TPfZLsnT6772XQ(6hLQB(>pCTeoOSFzKC}Ps zN?Yy`_~t9e%bF66ze$r)Wq5+D&^(BO;>j>JKv0j&b>}{y(6EAg=d?y@|)66WLPC17IBU4mbGz_pu!?Mzv3vXNEGbRkh+#$ zB^cMI05Z3l3$h~WiK#p_I@363mt_iO@$=~>PRi#S_6|nfHDI7#rLO6aT&)#;?kIZv zjCwn*$Z@;lPd-~;&i6{FUmBV?ya|w9$P^>DuJ3EVFz~@ zUa{(ux&_${_`~7na3`T(sg!SYuGlxc`UIf_pLpdYA0)N|WUTyF``r5*XX znG#;S$?Q-@l}-QrIU9NoVl#tfs{PX^_LJ|#%nKYHHH4_bTBl1PY`W-|VqvKlCz_ca zV|f^7$mK`o6kM?jtj&dcqh>&!X%wTfBo#6(Q7L+^Qr-UiQd@+N=%JhbT+NM9Tlmz- zIOAYrjEezVp7@1AY7z~*+MqtzIf{Ir`}%7!W!={=dSf+vW+QxkP69j3I^_@hJTpqw z#UC;*Y}Wtdk9fRN)XfUWLt3$s9wu6&_q4(R!1F4rm)?ziV9`%~uzM!Q-14zGe8CMl z?{M189mNHtMlvBJ>1XcPurGe2PtM;>Q(t+_PlcR~QAv>}um#{H2)HS5Mm=mUwwMeb zY2Yj0H8(*nv=$ms1Iz7JGFL#-whILn4|O6Ogo#OzR&af%wbh?~Kbwz3-u|!frF*&*wxyp0t+3!JlRidS=ph8;>>59 z8vT28_!-~r&LV>Mt}J746wr$&-Gl_qi%(A-1*j}mJcu9ktGd|!8!)^6W~XIH;7)$;N z*te`66ntW*L!@P|g`*^-OgK6>AvS2iT*0Fs9O=x^~kK%GQ|1Bu)*P)}*r)xK8a#88GM zBAam$ceLE9I2Co6X>^js5Z&AJ@{{a$%0OC!tLNA%9rsG0WnDn}J9+YL8W+I?4{b9a zkqT*E=%5Sf9`jzFpvW!V@G+SPM;qysS?l?s_K?0Y13myqvX@@oHx9X`wX&}b07u05 zG5#&3quD7t(H`=C$`%j6Vf=Lb6AngAgJvFvD7=%PXfIZ^sB89&+d$^oe>%)eSoW_x zSQS(||2P|}xH{x;)_n4y8aSzbz**e(7u4GG>pT^bV#~=YDx#O!Ah|NO7at;9%6rnvJ0HR$yWZ#=!@nKcaB??D{ESE|HL0`G&rp=8=%$&6H_U6?WtFMT z`h#zqHWM+WI|uEpiTIu0mn)`QDE41I0#GMq#cQu%DU&g6q8-oWC_haa#hkmar{yp7 z5sheZ^bx|Y^l*)2S>gR7R_J0Ty9jc@7Fo0=1E4}`DvR$(vZVK0d+jv@EFjTg^A)Fc zNRpt)7PYWvcuNIr7YI)stX|WFC={Z*H5xQqJ_e@Db=pN~%%{`mR5sV#ByvHpF!S*Y z@=i4lFRgy27FmqTVKIyvR@{TPPBQZ(vjF7;vhU~`MX9loOKfC5<&~MnBD@+ut#^NZ zQ)?gSyaB%6aTVnD56``f3)iB<_suSbJ`plOk`k05l7WJS^p+QQHiza*2+)br%Uq&q zTx4ROYVV{z9ZAhyrxxWio~}sFM=T=E;{232upy3Nz0goCe`em{VKQKkrPk*{jWrX? zMklWg-T zSySHMF@|%c^Nb^CG4r%dc?~K2bo=2pNgaL-eCEwaqqQWrp{-HI7(mg>)BaiD^g6>e zCGD13cma-8(Xs5`{Eoo&iiGBugI+QCfd(F(Y zoy7fPSmZx8E5-Uy=4GLL6O56c9_FPoV|BG2S6nH|Kt(h#YRpCtV)q$)Ub7148fdUm zOTCqw9&T+y5qVYcGa~mhqnwum@J68W^PD!9!I4e~fmTuE9&y#xd`smob%4IlR!@;l z!1_G3!Y91qLACS;7YAx@59=e5{jO0CecIzBHDvSKhPQ;k&?uTpgV^>wM|UdvbyO#% zSMVlj;Zw-RH)hDBV{28hJ8=?(-rn*=$2E-Y2o&i1R@(J-Br=elRopU+h)FHl(GXT5 z6PlgKe-Oj~i;NscDV@;AP+iLd)Kj(;K$i&%d1zF?jQHE8P)qv1`X$D23kgo2$!v*d zA4WPuGV9VC^Bl#!VWWUHuVn1Tc|#Old?g-I;ECn`MG#fdzk>{D!;WWTTU$Z4zQT^> z(QpyuBT8Yms7rwyASonIaR$82PautP=u#V*Fw#Q0E(Q)QNjd&2{@Pa%Spu`!GhW!G zy5h42N{PMi;63^RrCWB~#d64HpOAD1ZRazeGlr+JIE;piV7k2Zjxn2O-1_6*K#Z{J z537@0b^5eWTt@>4Z7ZdXt_*Dkdgb8Q89yHpmi}7)3*XePMhGm=*LXYbyhOWozi8*I z-C`p1tM19#1_f!+lp)~0g>(qwA&h4Q)rt?~Pfd$Cj^H1|K=<21x+&y**>ee#0fMFO zHIA$;B@Mx&H~dD)1@s$wpKK+w5+yzrnZpFWKMS`N=C}gVC)yp(e=OAZzFN5%5ml9p zv3?o-S}dII1!$G$z5` zTF^@&E%I7I$>hCmhAKEGhOHIxD@qaw4@RJ?u&mthq?+8JbeakXke)|PbU?DCJ4_u$ zGf#yutO_p|UCZ_707<}iGxBl&&gch1wL9~cR;5=&R}h_46w$@vtaR9mfd$MD3gign zrE#*eumR!(yY%*6BWkG36e*|Tr(eh_pe#f?t&J+V(Rl#x#B-}vC@C9z-AM<5!{-sE z{bE_#J$R=RxwF*G*XhDv8fJmAzln>%t6%e%oQ##=5&BSye66T2By5hpFr=z??Xkr} zGDVmezcu$uULK~!=FM>K#a4YvQH=zX2h$s6H~3w|I4Q2}xH@ocXM2WqVPY!BL-fkH z%xOX7xbs&j=T(<_;|BmJ66p`pmovr21iwrx? zoj+yUBhacQw7Rhzn-lP{%I`?#roW12!PJ7dAg+y)8@5uFpu!%^Z?3dTQ3^cBFGKbu zB!95M?)dgmfBl$gp3+>vwhzBI_)9+nP9Ouj_^geVjJbr7*7}HBW-!P5jt%|r2x_f( zwWuqe3`O#v1WqrtRw(yu>d6ts{!)H*Jj5=7wYGn5T}Ym{HN-pRIFdDmiI&tg(uCl> z*}PnOB#$)|)&0n!aJIk59ls)-$pYp18(TSH7y5QPbGq^)sWrO8;<3M*o=SIWwX^5) z2lks&cj9_i9vO=$U8U;Z5wiimEl%wDBDn%wD8Gcis4nXMR$sYt-70iD7<`Klr*M-ajP!=IXb9Tc4W3StE3lQ%~#+H=GtjsK!!vM4C2Wq+Sb@ z8n@#p&{#Dm3I!Zde{R!{s!H$e*jaIi-=5~(zRLbg^(X4TP6Lt269LPswDrTdR+R9i zF5`;G51@6W|0oM`8>{iRgnUqwaNA4G(#;#fe%Zz03=CTTcb2h)z#ViE-jQiopS9kf zD-ahxvb})#3n4yZTFz7s#l=Qu*4hVhLRp28u5i6i-}iK}D0ct;d7uhQ*n7^iAn6`C zeOoqlaiIWonO#?j&N}-lizoawE9&;kA#L1n$XA7ZG=|Z8OZF@EZPD;%O3;gKs>3W_ zBloF<3`wR<=|jLUCC$^Q_gaMIwBQ3KZ53r9<&oIKrlX`0{QVK`*o$lh>xDUJL_Q^3 zU{u$5U#mcSwZrAMv{sXvr23%I+kg9mIalKEns|N0+E;^ttrAG+Lrz6X^FQc9XbifK z0P|moqZZ!W2i-F|EaHcfLHPX9Th2~U>*8UpZsruR%px!Lp(qW`OgrQ1|P#%qrS_|idXnR5?EB4 zhJjBfcW&@$7=Yg`jaZSH$6sH77|0;dH}AdABHulb$^!&kb$b8ZC6Fe`zL<88J?S_G zMF9W{D4!SExH}aifs8r_WK=*XBgFSoC7$+EWqPIn?L1@w4Cc(nCDT850-d5{vF`iNaop1! ztmY=p%*f=en}`dHW;6h@K>&lM2wUgAs&E=o(R};}TP@t(8V2!JI9PqyA7GSdQ z691=hqxt=!AJrr-O7v!{O68eScUm{+z+Y)jxJ@s17~GM`axKS^+VPmiTF#<>1ekNt z@GtZUi{Omg#h+PGT0HbWO4HPC1JbNiPkB>WfF|Pr4^ooEomK>J!thoNXcfgENZEkK ztM2DxQCP;}1cb913ZD`<7WZl^D*m=0h@_Ar2c@`T>_7B?CpH--Kela7M0Y?3kZA_3 z^=bfTeyG?tcT`7rwxYe*`3u2C1i zN|V$Z(Nu?a@LCuR~*2s|MKiQ z6jVH0JoQK81N~{t4DXclu%j&>82hLRb($j`{ASA9|>B*voZ1 zN)kjADGRjc;(2P+=hGiXtp@!C8ZOX$;B5nJ@$CD$!j0nqZUNbhR4;nPIjpAtoMQ)L ze5+p42z~N8zC2>1RH0$781No2%QQ33sl2ihuA%vCjPBev{25<8{xwAy?kam-iSC=l z^OSZ8`ko6w))tn2=Pa_(a`}kV>#OQt(ZA8^G=TFhoktc6+emum{NIdmdblosM03~v zaqnV|uqjRJ^ZBMW;!)Gy&s7O>b}&9@>9c;QYAS<1@o3L9_NXmzL=J=xT9nrYD87v( zH)HOy9_c&z0o!(u`B_{n{zm+d@sX-56K=@GU|}<~jhbK2nm8ec6~efwRqiZJYjJow zyxe|PPBDIc|ITU-1+!QcdbdE=2LeF-^(TOJRs1e*dwK8qj{T@DJB9ciZy_cFq!#%j zH$!3Y+vNX6Q-v);pNHkfx(#hhAUZ_eFmJ+qr*^Yr@4RTyEW8Ed0gj@Bd(7@`D3~ze zd4!4VYH$hF0g($_E@5ad-yy)RJL@HkuDOo|SoaQ2COnXfAHiz>-D*gOxcQ{YO>z^^ z+tED&fUu-AjWe)PyedI)3!oUFA*1N3)iYeVughC<0W8}A_{=Bl62_~}M1GVT1w1rJ zGi3Cef8+wU0j{Ig2>O|aH&mcMZTbCA3$rwrt};0wds}?9mQIJ!dl8l;V6}~=pwxg3 z{$2qz7TMFH@7%i@MCWFRt>LWw`<^#$* zGFb`bDJlJiRY#*Yz?1{DLj&JPEtTxKNjUG95I!gcPztEbfh=JL`y0MxS#tXCt6`xB zq_?h8z?G(M^o1N?fP8W4waMzgU<6=TRoFje_l-YGKH9g<<1cgW#69ojrl@uH9KpQ% zNzLa@@pa=fiQwtuC7u^(rq59wpMGVPO3&ASVY##WXnCk4&@|Xi3{|=_Kuues!E3hN z>O)Hn=CPn zITWAFWndK)(==Xo3L~hM~AIbd>vmepluovhjVtxO1V$932Nd@rk zCYxw@B$aIRJc$9n%K>ZTpwwfqqy=I|{E!StSGH1pQ*o@K8p?89$iVW!B7ZJasz^S(wTDA>SlbZz~ zB(~|D+pm~qz{#+0kMO{+*{VHc*{X0TWhFx zVlp_4PxJPn<`IHS#!(x8fNp=6DLN3yC?Y^VbSE0`CcZ0FKec8p*G~tS7(kUFs3BL; z1ym=N8>9?X9B$~86#~~WJex<~ROp0&Kt=b-A5fsL{%lh*tYN)v+B?@qom-^Jt<8j` z&-EN)ftb`Kg<@@hf{4$A@AXV{!C^*(FN${k*;W6_OTdpm^YkI{`=Qz;-q`}m!jqB% zETXTz_fVW~2m?1tLE*l>ix?pE+T`v5<_kbg^b-NR{ND?nzko{-y&d^GzkDFyFO4QGC#5V|A^zdh{{#I None: def _setup_knapsack(model: InstaNovo) -> Knapsack: - MASS_SCALE = 10000 residue_masses = model.peptide_mass_calculator.masses residue_masses["$"] = 0 residue_indices = model.decoder._aa2idx diff --git a/instanovo/transformer/train.py b/instanovo/transformer/train.py index 8be1992..6d4c0f1 100644 --- a/instanovo/transformer/train.py +++ b/instanovo/transformer/train.py @@ -149,7 +149,7 @@ def validation_step( ) # targets = self.model.batch_idx_to_aa(peptides) - y = ["".join(x.sequence) if type(x) != list else "" for x in p] + y = ["".join(x.sequence) if not isinstance(x, list) else "" for x in p] targets = peptides aa_prec, aa_recall, pep_recall, _ = self.metrics.compute_precision_recall(targets, y) diff --git a/mkdocs.yml b/mkdocs.yml new file mode 100644 index 0000000..283b6ed --- /dev/null +++ b/mkdocs.yml @@ -0,0 +1,87 @@ +site_name: InstaNovo +site_description: Documentation for InstaNovo +site_author: InstaDeep Developers +site_url: https://instadeepai.github.io/InstaNovo/ +site_dir: public +repo_name: instadeep/dtu-denovo-sequencing +repo_url: https://github.com/instadeepai/InstaNovo +strict: false + +theme: + name: material + language: en + palette: + - scheme: default + primary: white + accent: purple + toggle: + icon: material/weather-sunny + name: Switch to dark mode + - scheme: slate + primary: black + accent: lime + toggle: + icon: material/weather-night + name: Switch to light mode + logo: assets/instadeep-logo.png + favicon: assets/instadeep-logo.png + icon: + repo: fontawesome/brands/github + font: + text: Avenir Next + features: + - navigation.tracking # the URL is automatically updated with the active anchor + - navigation.sections # top-level sections are rendered as groups in the sidebar + - navigation.tabs # horizontal tabs at the top of the page + - navigation.tabs.sticky # navigation tabs will lock below the header and always remain visible when scrolling + - navigation.indexes # documents can be directly attached to sections + - search.highlight # highlight search result + - search.share # share button + - search.suggest # smart suggestion + - toc.integrate + - toc.follow + - content.code.annotate + - navigation.tabs + - navigation.top + +markdown_extensions: + - pymdownx.highlight: + anchor_linenums: true + - pymdownx.inlinehilite + - pymdownx.snippets + - pymdownx.superfences + +plugins: + - search + - autorefs + # - git-revision-date + - include-markdown + - gen-files: + scripts: + - docs/gen_ref_nav.py + - mkdocstrings: + default_handler: python + handlers: + python: + paths: [instanovo] + options: + docstring_style: google + merge_init_into_class: yes + show_submodules: no + selection: + inherited_members: false + rendering: + show_source: false + members_order: source + show_if_no_docstring: true + show_signature: true + show_signature_annotations: true + show_root_full_path: false + show_root_heading: true + merge_init_into_class: true + docstring_section_style: spacy + +nav: + - Home: index.md + - License: LICENSE.md + - Code reference: reference/SUMMARY.md diff --git a/notebooks/getting_started_with_instanovo.ipynb b/notebooks/getting_started_with_instanovo.ipynb index 276e9a3..c6f560c 100644 --- a/notebooks/getting_started_with_instanovo.ipynb +++ b/notebooks/getting_started_with_instanovo.ipynb @@ -92,8 +92,6 @@ }, "outputs": [], "source": [ - "from instanovo.inference.knapsack import Knapsack\n", - "from instanovo.inference.knapsack_beam_search import KnapsackBeamSearchDecoder\n", "from instanovo.transformer.model import InstaNovo\n", "\n", "from tqdm import tqdm\n", @@ -417,8 +415,11 @@ }, "outputs": [], "source": [ + "from instanovo.constants import MASS_SCALE\n", + "from instanovo.inference.knapsack import Knapsack\n", + "from instanovo.inference.knapsack_beam_search import KnapsackBeamSearchDecoder\n", + "\n", "def _setup_knapsack(model: InstaNovo) -> Knapsack:\n", - " MASS_SCALE = 10000\n", " residue_masses = model.peptide_mass_calculator.masses\n", " residue_masses[\"$\"] = 0\n", " residue_indices = model.decoder._aa2idx\n", @@ -481,8 +482,8 @@ " beam_size=config[\"n_beams\"],\n", " max_length=config[\"max_length\"],\n", " )\n", - "preds = [\"\".join(x.sequence) if type(x) != list else \"\" for x in p]\n", - "probs = [x.log_probability if type(x) != list else -1 for x in p]" + "preds = [\"\".join(x.sequence) if not isinstance(x, list) else \"\" for x in p]\n", + "probs = [x.log_probability if not isinstance(x, list) else -1 for x in p]" ] }, { @@ -560,8 +561,8 @@ " max_length=config[\"max_length\"],\n", " )\n", "\n", - " preds += [\"\".join(x.sequence) if type(x) != list else \"\" for x in p]\n", - " probs += [x.log_probability if type(x) != list else -1 for x in p]\n", + " preds += [\"\".join(x.sequence) if not isinstance(x, list) else \"\" for x in p]\n", + " probs += [x.log_probability if not isinstance(x, list) else -1 for x in p]\n", " targs += list(peptides)" ] }, @@ -837,8 +838,9 @@ "provenance": [] }, "kernelspec": { - "display_name": "Python 3", - "name": "python3" + "display_name": "instanovo-py3.8", + "language": "python", + "name": "instanovo-py3.8" }, "language_info": { "codemirror_mode": { @@ -850,7 +852,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.3" + "version": "3.8.9" } }, "nbformat": 4, diff --git a/requirements-dev.txt b/requirements-dev.txt new file mode 100644 index 0000000..cd19794 --- /dev/null +++ b/requirements-dev.txt @@ -0,0 +1,13 @@ +# All of these dependencies are pinned to specific versions to achieve reproducible builds (CI/DEV) +allure-pytest==2.11.1 +coverage==6.4.4 +livereload==2.6.3 +pre-commit==2.20.0 +pytest==7.1.3 +pytest-cov==3.0.0 +pytest-mock==3.8.2 +pytest-parallel==0.1.1 +pytest-random-order==1.0.4 +pytest-xdist==2.5.0 +pytype==2023.10.5 +testfixtures==7.0.0 diff --git a/requirements-docs.txt b/requirements-docs.txt new file mode 100644 index 0000000..4ee1125 --- /dev/null +++ b/requirements-docs.txt @@ -0,0 +1,10 @@ +# All of these dependencies are pinned to specific versions to achieve reproducible builds (CI/DEV) +mkdocs==1.4.0 +mkdocs-gen-files==0.4.0 +mkdocs-git-revision-date-plugin==0.3.2 +mkdocs-include-markdown-plugin==3.9.0 +mkdocs-material==8.5.6 +mkdocs-pymdownx-material-extras==2.2.1 +mkdocstrings==0.23.0 +mkdocstrings-python==1.7.1 +pymdown-extensions==9.5 diff --git a/requirements.txt b/requirements.txt index 819098a..9ada7de 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,10 +5,9 @@ botocore==1.27.59 click==8.1.7 cloudpathlib==0.10.0 datasets==2.14.5 -deepspeed==0.7.2 depthcharge-ms==0.1.0 fastprogress==1.0.3 -jiwer==2.5.1 +jiwer==3.0.3 matchms==0.22.0 matplotlib==3.7.2 numba==0.57.1 @@ -17,17 +16,17 @@ omegaconf==2.2.3 pandas==2.0.3 polars==0.19.7 protobuf==3.19.6 -pyarrow==11.0.0 -pyopenms==2.7.0 +pyarrow==15.0.0 +pyopenms==2.7.0; python_version < '3.10' +pyopenms==3.1.0; python_version >= '3.10' python-dotenv==0.21.0 pytorch_lightning==1.8.6 -scikit-learn==1.1.2 +scikit-learn==1.1.2; python_version < '3.11' +scikit-learn==1.4.1.post1; python_version >= '3.11' seaborn==0.12.0 spectrum_utils==0.4.1 tensorboard==2.10.1 tensorboardX==2.5.1 -torch==1.13.1 -torchaudio==0.13.1 -torchvision==0.14.1 +torch==2.2.1 tqdm==4.65.0 transfusion-asr==0.1.0 diff --git a/setup.cfg b/setup.cfg index 38dc5f7..5a56cc2 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,6 +1,6 @@ [semantic_release] version_variable = setup.py:__version__ -hvcs=gitlab +hvcs=github upload_to_pypi=false major_on_zero=false branch=main diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..7461d40 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,62 @@ +from __future__ import annotations + +import os +import sys + +import pytest +import requests +from datasets import load_dataset +from datasets.arrow_dataset import Dataset +from datasets.dataset_dict import DatasetDict +from datasets.dataset_dict import IterableDatasetDict +from datasets.iterable_dataset import IterableDataset + +# Add the root directory to the PYTHONPATH +# This allows pytest to find the modules for testing + +root_dir = os.path.dirname(os.path.dirname(__file__)) +sys.path.append(root_dir) + + +@pytest.fixture(scope="session") +def checkpoints_dir() -> str: + """A pytest fixture to create and provide the absolute path of a 'checkpoints' directory. + + Ensures the directory exists for storing checkpoint files during the test session. + """ + checkpoints_dir = "checkpoints" + os.makedirs(checkpoints_dir, exist_ok=True) + return os.path.abspath(checkpoints_dir) + + +@pytest.fixture(scope="session") +def instanovo_checkpoint(checkpoints_dir: str) -> str: + """A pytest fixture to download and provide the path of the InstaNovo model checkpoint. + + Downloads from a predefined URL if the checkpoint file doesn't exist locally. + """ + url = "https://github.com/instadeepai/InstaNovo/releases/download/0.1.4/instanovo_yeast.pt" + checkpoint_path = os.path.join(checkpoints_dir, "instanovo_yeast.pt") + + if not os.path.isfile(checkpoint_path): + response = requests.get(url) + with open(checkpoint_path, "wb") as file: + file.write(response.content) + + return os.path.abspath(checkpoint_path) + + +@pytest.fixture(scope="session") +def dataset() -> DatasetDict | Dataset | IterableDatasetDict | IterableDataset: + """A pytest fixture to load and provide a dataset for testing. + + Loads a specific subset (1% of test data) from the 'instanovo_ninespecies_exclude_yeast' dataset. + """ + return load_dataset("InstaDeepAI/instanovo_ninespecies_exclude_yeast", split="test[:1%]") + + +@pytest.fixture(scope="session") +def knapsack_dir(checkpoints_dir: str) -> str: + """A pytest fixture to create and provide the absolute path of a 'knapsack' directory within the checkpoints directory for storing test artifacts.""" + knapsack_dir = os.path.join(checkpoints_dir, "knapsack") + return os.path.abspath(knapsack_dir) diff --git a/tests/integration/__init__.py b/tests/integration/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/integration/model_test.py b/tests/integration/model_test.py new file mode 100644 index 0000000..0cfb0c5 --- /dev/null +++ b/tests/integration/model_test.py @@ -0,0 +1,488 @@ +from __future__ import annotations + +import os + +import numpy as np +import torch +from datasets.arrow_dataset import Dataset +from datasets.dataset_dict import DatasetDict +from datasets.dataset_dict import IterableDatasetDict +from datasets.iterable_dataset import IterableDataset +from torch.utils.data import DataLoader + +from instanovo.constants import MASS_SCALE +from instanovo.inference.knapsack import Knapsack +from instanovo.inference.knapsack_beam_search import KnapsackBeamSearchDecoder +from instanovo.transformer.dataset import collate_batch +from instanovo.transformer.dataset import SpectrumDataset +from instanovo.transformer.model import InstaNovo + + +def _setup_knapsack(model: InstaNovo) -> Knapsack: + residue_masses = model.peptide_mass_calculator.masses + residue_masses["$"] = 0 + residue_indices = model.decoder._aa2idx + return Knapsack.construct_knapsack( + residue_masses=residue_masses, + residue_indices=residue_indices, + max_mass=4000.00, + mass_scale=MASS_SCALE, + ) + + +def test_model( + instanovo_checkpoint: str, + dataset: DatasetDict | Dataset | IterableDatasetDict | IterableDataset, + knapsack_dir: str, +) -> None: + """Test loading an InstaNovo model and doing inference end-to-end.""" + device = "cuda" if torch.cuda.is_available() else "cpu" + model, config = InstaNovo.load(instanovo_checkpoint) + model = model.to(device).eval() + + s2i = {v: k for k, v in model.i2s.items()} + assert s2i == { + "A": 1, + "C(+57.02)": 6, + "D": 10, + "E": 13, + "F": 16, + "G": 0, + "H": 15, + "I": 8, + "K": 12, + "L": 7, + "M": 14, + "M(+15.99)": 20, + "N": 9, + "N(+.98)": 21, + "P": 3, + "Q": 11, + "Q(+.98)": 22, + "R": 17, + "S": 2, + "T": 5, + "V": 4, + "W": 19, + "Y": 18, + } + + n_peaks = config["n_peaks"] + assert n_peaks == 200 + + ds = SpectrumDataset(dataset, s2i, config["n_peaks"], return_str=True) + assert len(ds) == 271 + spectrum, precursor_mz, precursor_charge, peptide = ds[0] + assert torch.allclose( + spectrum, + torch.Tensor( + [ + [1.0096e02, 6.8907e-02], + [1.1006e02, 6.6649e-02], + [1.1646e02, 6.5169e-02], + [1.2910e02, 1.3785e-01], + [1.3009e02, 1.3666e-01], + [1.4711e02, 1.4966e-01], + [1.7309e02, 7.0756e-02], + [1.8612e02, 1.0042e-01], + [2.0413e02, 1.4815e-01], + [2.7303e02, 7.4630e-02], + [2.8318e02, 1.1245e-01], + [3.0119e02, 5.0341e-01], + [3.2845e02, 7.8869e-02], + [3.7222e02, 1.9128e-01], + [4.7877e02, 7.5372e-02], + [5.2873e02, 8.4931e-02], + [5.7176e02, 8.8744e-02], + [5.7975e02, 9.3491e-02], + [6.1527e02, 8.3923e-02], + [6.5630e02, 9.5524e-02], + [7.7837e02, 1.2861e-01], + [7.7887e02, 2.2789e-01], + [7.7938e02, 6.1620e-02], + [7.9137e02, 1.0959e-01], + [7.9189e02, 1.0925e-01], + [1.0365e03, 9.1989e-02], + [1.1015e03, 9.4942e-02], + [1.1395e03, 2.0198e-01], + [1.1895e03, 9.5279e-02], + [1.2285e03, 1.2807e-01], + [1.2556e03, 1.1814e-01], + [1.2565e03, 1.2252e-01], + [1.2716e03, 2.7978e-01], + [1.2996e03, 4.9205e-01], + ] + ), + rtol=1e-04, + ) + assert precursor_mz == 800.38427734375 + assert precursor_charge == 2.0 + assert peptide == "TPGREDAAEETAAPGK" + + dl = DataLoader(ds, batch_size=2, shuffle=False, collate_fn=collate_batch) + batch = next(iter(dl)) + + spectra, precursors, spectra_mask, peptides, _ = batch + assert torch.allclose( + spectra, + torch.Tensor( + [ + [ + [1.0096e02, 6.8907e-02], + [1.1006e02, 6.6649e-02], + [1.1646e02, 6.5169e-02], + [1.2910e02, 1.3785e-01], + [1.3009e02, 1.3666e-01], + [1.4711e02, 1.4966e-01], + [1.7309e02, 7.0756e-02], + [1.8612e02, 1.0042e-01], + [2.0413e02, 1.4815e-01], + [2.7303e02, 7.4630e-02], + [2.8318e02, 1.1245e-01], + [3.0119e02, 5.0341e-01], + [3.2845e02, 7.8869e-02], + [3.7222e02, 1.9128e-01], + [4.7877e02, 7.5372e-02], + [5.2873e02, 8.4931e-02], + [5.7176e02, 8.8744e-02], + [5.7975e02, 9.3491e-02], + [6.1527e02, 8.3923e-02], + [6.5630e02, 9.5524e-02], + [7.7837e02, 1.2861e-01], + [7.7887e02, 2.2789e-01], + [7.7938e02, 6.1620e-02], + [7.9137e02, 1.0959e-01], + [7.9189e02, 1.0925e-01], + [1.0365e03, 9.1989e-02], + [1.1015e03, 9.4942e-02], + [1.1395e03, 2.0198e-01], + [1.1895e03, 9.5279e-02], + [1.2285e03, 1.2807e-01], + [1.2556e03, 1.1814e-01], + [1.2565e03, 1.2252e-01], + [1.2716e03, 2.7978e-01], + [1.2996e03, 4.9205e-01], + [0.0000e00, 0.0000e00], + [0.0000e00, 0.0000e00], + [0.0000e00, 0.0000e00], + [0.0000e00, 0.0000e00], + [0.0000e00, 0.0000e00], + [0.0000e00, 0.0000e00], + [0.0000e00, 0.0000e00], + [0.0000e00, 0.0000e00], + [0.0000e00, 0.0000e00], + [0.0000e00, 0.0000e00], + [0.0000e00, 0.0000e00], + [0.0000e00, 0.0000e00], + [0.0000e00, 0.0000e00], + [0.0000e00, 0.0000e00], + [0.0000e00, 0.0000e00], + [0.0000e00, 0.0000e00], + [0.0000e00, 0.0000e00], + [0.0000e00, 0.0000e00], + [0.0000e00, 0.0000e00], + [0.0000e00, 0.0000e00], + [0.0000e00, 0.0000e00], + [0.0000e00, 0.0000e00], + [0.0000e00, 0.0000e00], + [0.0000e00, 0.0000e00], + [0.0000e00, 0.0000e00], + [0.0000e00, 0.0000e00], + [0.0000e00, 0.0000e00], + [0.0000e00, 0.0000e00], + [0.0000e00, 0.0000e00], + [0.0000e00, 0.0000e00], + [0.0000e00, 0.0000e00], + [0.0000e00, 0.0000e00], + [0.0000e00, 0.0000e00], + [0.0000e00, 0.0000e00], + [0.0000e00, 0.0000e00], + [0.0000e00, 0.0000e00], + [0.0000e00, 0.0000e00], + [0.0000e00, 0.0000e00], + [0.0000e00, 0.0000e00], + [0.0000e00, 0.0000e00], + [0.0000e00, 0.0000e00], + [0.0000e00, 0.0000e00], + [0.0000e00, 0.0000e00], + ], + [ + [1.0206e02, 1.5659e-01], + [1.2910e02, 8.5658e-02], + [1.3009e02, 6.9173e-02], + [1.4711e02, 9.2828e-02], + [1.5508e02, 7.6935e-02], + [1.6910e02, 4.4613e-02], + [1.7111e02, 1.0253e-01], + [1.7309e02, 1.0162e-01], + [1.9508e02, 4.1588e-02], + [1.9911e02, 7.1197e-02], + [2.0109e02, 7.8843e-02], + [2.0413e02, 1.0829e-01], + [2.1309e02, 6.7321e-02], + [2.3110e02, 5.2743e-02], + [2.3909e02, 4.1613e-02], + [2.4108e02, 5.1726e-02], + [2.5909e02, 7.4543e-02], + [2.8318e02, 9.6649e-02], + [2.8412e02, 4.9769e-02], + [3.0119e02, 4.0183e-01], + [3.1365e02, 6.9522e-02], + [3.2216e02, 5.1740e-02], + [3.2816e02, 4.3714e-02], + [3.2865e02, 8.4462e-02], + [3.4213e02, 4.7763e-02], + [3.5017e02, 8.2215e-02], + [3.5421e02, 5.3792e-02], + [3.5517e02, 1.0384e-01], + [3.7222e02, 1.6867e-01], + [3.7669e02, 5.1280e-02], + [3.8569e02, 7.6477e-02], + [4.1317e02, 7.4231e-02], + [4.2092e02, 4.0213e-02], + [4.4084e02, 5.2411e-02], + [4.4326e02, 1.5205e-01], + [4.5521e02, 7.7468e-02], + [4.5819e02, 4.3700e-02], + [4.7821e02, 1.0422e-01], + [4.8857e02, 7.0330e-02], + [5.4431e02, 2.1082e-01], + [5.5525e02, 1.6596e-01], + [5.5725e02, 4.2561e-02], + [5.6125e02, 5.4916e-02], + [6.0027e02, 7.5823e-02], + [6.2629e02, 9.8458e-02], + [6.5534e02, 9.4754e-02], + [6.5629e02, 1.0460e-01], + [6.6933e02, 1.2577e-01], + [6.7335e02, 1.8197e-01], + [6.9733e02, 1.7234e-01], + [7.0133e02, 1.1745e-01], + [7.0182e02, 5.1286e-02], + [7.0933e02, 5.0761e-02], + [7.2733e02, 2.1210e-01], + [7.2932e02, 5.3202e-02], + [7.4085e02, 4.9344e-02], + [7.4136e02, 5.5698e-02], + [7.4987e02, 1.0906e-01], + [7.5434e02, 4.2837e-02], + [7.7037e02, 7.6015e-02], + [7.8037e02, 1.1569e-01], + [7.8438e02, 7.9805e-02], + [7.9837e02, 2.4075e-01], + [8.0239e02, 1.4020e-01], + [8.2636e02, 1.7631e-01], + [8.7343e02, 4.6804e-02], + [9.2741e02, 2.3908e-01], + [9.3891e02, 4.2673e-02], + [9.5541e02, 5.4546e-02], + [1.0385e03, 1.2572e-01], + [1.0564e03, 1.6373e-01], + [1.1395e03, 1.2800e-01], + [1.1985e03, 5.0744e-02], + [1.1996e03, 4.8810e-02], + [1.3836e03, 7.1560e-02], + [1.4807e03, 1.4328e-01], + [1.4997e03, 1.1361e-01], + ], + ] + ), + rtol=1e-04, + ) + assert torch.allclose( + precursors, torch.Tensor([[1598.7540, 2.0000, 800.3843], [1598.7551, 3.0000, 533.9257]]) + ) + assert torch.equal( + spectra_mask, + torch.Tensor( + [ + [ + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + True, + ], + [ + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + ], + ] + ), + ) + assert peptides == ("TPGREDAAEETAAPGK", "TPGREDAAEETAAPGK") + + spectra = spectra.to(device) + precursors = precursors.to(device) + + if not os.path.exists(knapsack_dir): + knapsack = _setup_knapsack(model) + decoder = KnapsackBeamSearchDecoder(model, knapsack) + knapsack.save(knapsack_dir) + else: + decoder = KnapsackBeamSearchDecoder.from_file(model=model, path=knapsack_dir) + + assert os.path.isfile(os.path.join(knapsack_dir, "parameters.pkl")) + assert os.path.isfile(os.path.join(knapsack_dir, "chart.npy")) + assert os.path.isfile(os.path.join(knapsack_dir, "masses.npy")) + assert isinstance(decoder, KnapsackBeamSearchDecoder) + + with torch.no_grad(): + p = decoder.decode( + spectra=spectra, + precursors=precursors, + beam_size=config["n_beams"], + max_length=config["max_length"], + ) + preds = ["".join(x.sequence) if not isinstance(x, list) else "" for x in p] + probs = [x.log_probability if not isinstance(x, list) else -1 for x in p] + + assert preds == ["NRNVGDQNGC(+57.02)LAPGK", "TDRPGEAAEETAAPGK"] + assert np.allclose(probs, [-8.156049728393555, -3.1159517765045166])