diff --git a/1pga_faesmfold.pdb b/1pga_faesmfold.pdb new file mode 100644 index 0000000..a6aaa92 --- /dev/null +++ b/1pga_faesmfold.pdb @@ -0,0 +1,438 @@ +PARENT N/A +ATOM 1 N MET A 1 0.183 -2.760 -13.004 1.00 0.90 N +ATOM 2 CA MET A 1 -1.088 -2.401 -12.380 1.00 0.91 C +ATOM 3 C MET A 1 -0.883 -1.346 -11.298 1.00 0.93 C +ATOM 4 CB MET A 1 -1.763 -3.637 -11.783 1.00 0.87 C +ATOM 5 O MET A 1 0.214 -1.212 -10.755 1.00 0.92 O +ATOM 6 CG MET A 1 -2.245 -4.635 -12.824 1.00 0.75 C +ATOM 7 SD MET A 1 -2.844 -6.199 -12.074 1.00 0.70 S +ATOM 8 CE MET A 1 -1.279 -7.107 -11.936 1.00 0.65 C +ATOM 9 N THR A 2 -1.815 -0.529 -11.080 1.00 0.95 N +ATOM 10 CA THR A 2 -1.768 0.488 -10.035 1.00 0.95 C +ATOM 11 C THR A 2 -2.298 -0.067 -8.716 1.00 0.95 C +ATOM 12 CB THR A 2 -2.578 1.735 -10.433 1.00 0.94 C +ATOM 13 O THR A 2 -3.384 -0.649 -8.673 1.00 0.94 O +ATOM 14 CG2 THR A 2 -2.467 2.825 -9.372 1.00 0.86 C +ATOM 15 OG1 THR A 2 -2.081 2.244 -11.677 1.00 0.88 O +ATOM 16 N TYR A 3 -1.578 0.019 -7.715 1.00 0.95 N +ATOM 17 CA TYR A 3 -1.958 -0.377 -6.364 1.00 0.95 C +ATOM 18 C TYR A 3 -2.073 0.838 -5.451 1.00 0.94 C +ATOM 19 CB TYR A 3 -0.942 -1.368 -5.788 1.00 0.95 C +ATOM 20 O TYR A 3 -1.387 1.842 -5.655 1.00 0.93 O +ATOM 21 CG TYR A 3 -0.905 -2.691 -6.515 1.00 0.93 C +ATOM 22 CD1 TYR A 3 -1.670 -3.769 -6.078 1.00 0.91 C +ATOM 23 CD2 TYR A 3 -0.104 -2.864 -7.639 1.00 0.90 C +ATOM 24 CE1 TYR A 3 -1.638 -4.991 -6.743 1.00 0.92 C +ATOM 25 CE2 TYR A 3 -0.064 -4.081 -8.312 1.00 0.92 C +ATOM 26 OH TYR A 3 -0.797 -6.343 -8.520 1.00 0.85 O +ATOM 27 CZ TYR A 3 -0.833 -5.136 -7.858 1.00 0.91 C +ATOM 28 N LYS A 4 -2.877 0.766 -4.502 1.00 0.95 N +ATOM 29 CA LYS A 4 -3.122 1.864 -3.571 1.00 0.95 C +ATOM 30 C LYS A 4 -2.766 1.463 -2.143 1.00 0.94 C +ATOM 31 CB LYS A 4 -4.583 2.311 -3.642 1.00 0.93 C +ATOM 32 O LYS A 4 -2.962 0.312 -1.748 1.00 0.94 O +ATOM 33 CG LYS A 4 -4.911 3.499 -2.749 1.00 0.86 C +ATOM 34 CD LYS A 4 -6.380 3.886 -2.850 1.00 0.82 C +ATOM 35 CE LYS A 4 -6.704 4.504 -4.204 1.00 0.76 C +ATOM 36 NZ LYS A 4 -8.129 4.945 -4.285 1.00 0.67 N +ATOM 37 N LEU A 5 -2.221 2.400 -1.407 1.00 0.93 N +ATOM 38 CA LEU A 5 -2.010 2.264 0.030 1.00 0.93 C +ATOM 39 C LEU A 5 -2.804 3.316 0.798 1.00 0.92 C +ATOM 40 CB LEU A 5 -0.522 2.385 0.367 1.00 0.91 C +ATOM 41 O LEU A 5 -2.715 4.509 0.498 1.00 0.91 O +ATOM 42 CG LEU A 5 -0.161 2.383 1.854 1.00 0.87 C +ATOM 43 CD1 LEU A 5 -0.399 1.002 2.455 1.00 0.80 C +ATOM 44 CD2 LEU A 5 1.287 2.816 2.053 1.00 0.81 C +ATOM 45 N ILE A 6 -3.582 2.887 1.705 1.00 0.92 N +ATOM 46 CA ILE A 6 -4.266 3.754 2.659 1.00 0.91 C +ATOM 47 C ILE A 6 -3.625 3.610 4.037 1.00 0.91 C +ATOM 48 CB ILE A 6 -5.775 3.433 2.734 1.00 0.90 C +ATOM 49 O ILE A 6 -3.626 2.523 4.619 1.00 0.89 O +ATOM 50 CG1 ILE A 6 -6.420 3.574 1.350 1.00 0.86 C +ATOM 51 CG2 ILE A 6 -6.470 4.337 3.756 1.00 0.86 C +ATOM 52 CD1 ILE A 6 -7.869 3.110 1.292 1.00 0.83 C +ATOM 53 N LEU A 7 -3.059 4.669 4.542 1.00 0.87 N +ATOM 54 CA LEU A 7 -2.478 4.725 5.879 1.00 0.87 C +ATOM 55 C LEU A 7 -3.454 5.349 6.871 1.00 0.86 C +ATOM 56 CB LEU A 7 -1.170 5.522 5.862 1.00 0.83 C +ATOM 57 O LEU A 7 -3.857 6.504 6.709 1.00 0.83 O +ATOM 58 CG LEU A 7 -0.019 4.918 5.057 1.00 0.75 C +ATOM 59 CD1 LEU A 7 1.119 5.924 4.923 1.00 0.68 C +ATOM 60 CD2 LEU A 7 0.472 3.630 5.708 1.00 0.68 C +ATOM 61 N ASN A 8 -3.918 4.505 7.775 1.00 0.83 N +ATOM 62 CA ASN A 8 -4.789 4.956 8.855 1.00 0.82 C +ATOM 63 C ASN A 8 -4.047 5.015 10.187 1.00 0.81 C +ATOM 64 CB ASN A 8 -6.016 4.048 8.970 1.00 0.80 C +ATOM 65 O ASN A 8 -4.316 4.219 11.088 1.00 0.78 O +ATOM 66 CG ASN A 8 -6.916 
4.125 7.753 1.00 0.74 C +ATOM 67 ND2 ASN A 8 -7.387 2.974 7.290 1.00 0.71 N +ATOM 68 OD1 ASN A 8 -7.185 5.211 7.232 1.00 0.71 O +ATOM 69 N GLY A 9 -3.046 5.825 10.212 1.00 0.77 N +ATOM 70 CA GLY A 9 -2.276 6.030 11.428 1.00 0.77 C +ATOM 71 C GLY A 9 -2.822 7.145 12.300 1.00 0.78 C +ATOM 72 O GLY A 9 -3.716 7.884 11.883 1.00 0.75 O +ATOM 73 N LYS A 10 -2.439 7.238 13.690 1.00 0.76 N +ATOM 74 CA LYS A 10 -2.861 8.266 14.637 1.00 0.75 C +ATOM 75 C LYS A 10 -2.356 9.643 14.217 1.00 0.75 C +ATOM 76 CB LYS A 10 -2.365 7.933 16.046 1.00 0.72 C +ATOM 77 O LYS A 10 -3.087 10.632 14.308 1.00 0.72 O +ATOM 78 CG LYS A 10 -3.054 6.733 16.678 1.00 0.66 C +ATOM 79 CD LYS A 10 -2.589 6.512 18.111 1.00 0.64 C +ATOM 80 CE LYS A 10 -3.270 5.304 18.741 1.00 0.56 C +ATOM 81 NZ LYS A 10 -2.769 5.045 20.124 1.00 0.49 N +ATOM 82 N THR A 11 -1.035 9.648 13.714 1.00 0.70 N +ATOM 83 CA THR A 11 -0.403 10.921 13.384 1.00 0.70 C +ATOM 84 C THR A 11 -0.329 11.111 11.872 1.00 0.70 C +ATOM 85 CB THR A 11 1.011 11.016 13.986 1.00 0.65 C +ATOM 86 O THR A 11 -0.329 12.243 11.383 1.00 0.66 O +ATOM 87 CG2 THR A 11 0.953 11.222 15.496 1.00 0.56 C +ATOM 88 OG1 THR A 11 1.725 9.805 13.707 1.00 0.60 O +ATOM 89 N LEU A 12 -0.234 10.048 11.113 1.00 0.74 N +ATOM 90 CA LEU A 12 -0.139 10.106 9.658 1.00 0.74 C +ATOM 91 C LEU A 12 -1.285 9.338 9.008 1.00 0.74 C +ATOM 92 CB LEU A 12 1.202 9.541 9.185 1.00 0.70 C +ATOM 93 O LEU A 12 -1.485 8.155 9.290 1.00 0.71 O +ATOM 94 CG LEU A 12 2.326 10.555 8.965 1.00 0.65 C +ATOM 95 CD1 LEU A 12 3.522 10.221 9.849 1.00 0.61 C +ATOM 96 CD2 LEU A 12 2.734 10.592 7.496 1.00 0.62 C +ATOM 97 N LYS A 13 -2.193 10.159 8.261 1.00 0.79 N +ATOM 98 CA LYS A 13 -3.240 9.641 7.385 1.00 0.80 C +ATOM 99 C LYS A 13 -3.007 10.067 5.938 1.00 0.79 C +ATOM 100 CB LYS A 13 -4.616 10.113 7.856 1.00 0.75 C +ATOM 101 O LYS A 13 -2.635 11.212 5.675 1.00 0.75 O +ATOM 102 CG LYS A 13 -5.021 9.578 9.221 1.00 0.68 C +ATOM 103 CD LYS A 13 -6.463 9.935 9.557 1.00 0.66 C +ATOM 104 CE LYS A 13 -6.868 9.402 10.924 1.00 0.58 C +ATOM 105 NZ LYS A 13 -8.293 9.720 11.243 1.00 0.50 N +ATOM 106 N GLY A 14 -3.046 9.068 5.162 1.00 0.84 N +ATOM 107 CA GLY A 14 -2.922 9.439 3.761 1.00 0.85 C +ATOM 108 C GLY A 14 -3.017 8.253 2.819 1.00 0.86 C +ATOM 109 O GLY A 14 -3.286 7.130 3.251 1.00 0.84 O +ATOM 110 N GLU A 15 -3.025 8.612 1.568 1.00 0.88 N +ATOM 111 CA GLU A 15 -3.069 7.610 0.507 1.00 0.89 C +ATOM 112 C GLU A 15 -1.931 7.811 -0.490 1.00 0.89 C +ATOM 113 CB GLU A 15 -4.417 7.652 -0.217 1.00 0.86 C +ATOM 114 O GLU A 15 -1.494 8.940 -0.721 1.00 0.88 O +ATOM 115 CG GLU A 15 -5.615 7.465 0.703 1.00 0.77 C +ATOM 116 CD GLU A 15 -6.942 7.444 -0.038 1.00 0.74 C +ATOM 117 OE1 GLU A 15 -6.943 7.553 -1.286 1.00 0.72 O +ATOM 118 OE2 GLU A 15 -7.990 7.317 0.634 1.00 0.68 O +ATOM 119 N THR A 16 -1.391 6.834 -0.944 1.00 0.91 N +ATOM 120 CA THR A 16 -0.424 6.881 -2.035 1.00 0.91 C +ATOM 121 C THR A 16 -0.618 5.700 -2.980 1.00 0.92 C +ATOM 122 CB THR A 16 1.021 6.883 -1.501 1.00 0.90 C +ATOM 123 O THR A 16 -1.346 4.757 -2.663 1.00 0.92 O +ATOM 124 CG2 THR A 16 1.377 5.541 -0.871 1.00 0.81 C +ATOM 125 OG1 THR A 16 1.923 7.145 -2.583 1.00 0.83 O +ATOM 126 N THR A 17 -0.040 5.771 -4.150 1.00 0.93 N +ATOM 127 CA THR A 17 -0.143 4.687 -5.120 1.00 0.93 C +ATOM 128 C THR A 17 1.233 4.322 -5.670 1.00 0.93 C +ATOM 129 CB THR A 17 -1.079 5.065 -6.283 1.00 0.92 C +ATOM 130 O THR A 17 2.181 5.101 -5.553 1.00 0.92 O +ATOM 131 CG2 THR A 17 -2.480 5.394 -5.777 1.00 0.83 C +ATOM 132 OG1 THR A 17 -0.548 6.209 -6.964 1.00 0.86 O +ATOM 133 N THR A 18 1.293 3.177 -6.259 
1.00 0.93 N +ATOM 134 CA THR A 18 2.477 2.721 -6.980 1.00 0.94 C +ATOM 135 C THR A 18 2.086 1.809 -8.139 1.00 0.94 C +ATOM 136 CB THR A 18 3.449 1.979 -6.045 1.00 0.92 C +ATOM 137 O THR A 18 1.030 1.174 -8.108 1.00 0.94 O +ATOM 138 CG2 THR A 18 2.874 0.637 -5.606 1.00 0.85 C +ATOM 139 OG1 THR A 18 4.687 1.755 -6.732 1.00 0.85 O +ATOM 140 N GLU A 19 2.905 1.820 -9.176 1.00 0.94 N +ATOM 141 CA GLU A 19 2.784 0.820 -10.232 1.00 0.94 C +ATOM 142 C GLU A 19 3.670 -0.390 -9.951 1.00 0.93 C +ATOM 143 CB GLU A 19 3.140 1.427 -11.592 1.00 0.91 C +ATOM 144 O GLU A 19 4.852 -0.241 -9.635 1.00 0.92 O +ATOM 145 CG GLU A 19 2.220 2.563 -12.018 1.00 0.78 C +ATOM 146 CD GLU A 19 0.818 2.097 -12.375 1.00 0.75 C +ATOM 147 OE1 GLU A 19 0.676 1.019 -12.997 1.00 0.76 O +ATOM 148 OE2 GLU A 19 -0.148 2.815 -12.032 1.00 0.75 O +ATOM 149 N ALA A 20 3.097 -1.506 -10.040 1.00 0.93 N +ATOM 150 CA ALA A 20 3.836 -2.744 -9.805 1.00 0.94 C +ATOM 151 C ALA A 20 3.292 -3.880 -10.665 1.00 0.93 C +ATOM 152 CB ALA A 20 3.779 -3.125 -8.327 1.00 0.92 C +ATOM 153 O ALA A 20 2.147 -3.830 -11.119 1.00 0.92 O +ATOM 154 N VAL A 21 4.222 -4.970 -10.858 1.00 0.93 N +ATOM 155 CA VAL A 21 3.830 -6.107 -11.685 1.00 0.93 C +ATOM 156 C VAL A 21 2.848 -6.990 -10.918 1.00 0.93 C +ATOM 157 CB VAL A 21 5.056 -6.934 -12.130 1.00 0.92 C +ATOM 158 O VAL A 21 2.031 -7.689 -11.521 1.00 0.91 O +ATOM 159 CG1 VAL A 21 5.943 -6.122 -13.072 1.00 0.84 C +ATOM 160 CG2 VAL A 21 5.852 -7.404 -10.914 1.00 0.82 C +ATOM 161 N ASP A 22 3.007 -7.005 -9.595 1.00 0.93 N +ATOM 162 CA ASP A 22 2.130 -7.840 -8.780 1.00 0.94 C +ATOM 163 C ASP A 22 1.990 -7.275 -7.369 1.00 0.94 C +ATOM 164 CB ASP A 22 2.656 -9.275 -8.723 1.00 0.93 C +ATOM 165 O ASP A 22 2.657 -6.301 -7.015 1.00 0.93 O +ATOM 166 CG ASP A 22 4.075 -9.366 -8.188 1.00 0.90 C +ATOM 167 OD1 ASP A 22 4.497 -8.472 -7.423 1.00 0.88 O +ATOM 168 OD2 ASP A 22 4.778 -10.339 -8.538 1.00 0.89 O +ATOM 169 N ALA A 23 1.087 -7.818 -6.612 1.00 0.93 N +ATOM 170 CA ALA A 23 0.753 -7.328 -5.277 1.00 0.94 C +ATOM 171 C ALA A 23 1.945 -7.457 -4.332 1.00 0.94 C +ATOM 172 CB ALA A 23 -0.450 -8.084 -4.718 1.00 0.93 C +ATOM 173 O ALA A 23 2.160 -6.599 -3.473 1.00 0.93 O +ATOM 174 N ALA A 24 2.772 -8.517 -4.490 1.00 0.93 N +ATOM 175 CA ALA A 24 3.914 -8.742 -3.608 1.00 0.93 C +ATOM 176 C ALA A 24 4.943 -7.624 -3.746 1.00 0.93 C +ATOM 177 CB ALA A 24 4.558 -10.094 -3.907 1.00 0.92 C +ATOM 178 O ALA A 24 5.483 -7.143 -2.747 1.00 0.92 O +ATOM 179 N THR A 25 5.137 -7.316 -5.003 1.00 0.93 N +ATOM 180 CA THR A 25 6.077 -6.233 -5.268 1.00 0.94 C +ATOM 181 C THR A 25 5.544 -4.910 -4.725 1.00 0.93 C +ATOM 182 CB THR A 25 6.360 -6.095 -6.775 1.00 0.92 C +ATOM 183 O THR A 25 6.289 -4.137 -4.119 1.00 0.92 O +ATOM 184 CG2 THR A 25 7.375 -4.990 -7.046 1.00 0.78 C +ATOM 185 OG1 THR A 25 6.876 -7.335 -7.274 1.00 0.82 O +ATOM 186 N ALA A 26 4.312 -4.645 -4.929 1.00 0.94 N +ATOM 187 CA ALA A 26 3.693 -3.433 -4.401 1.00 0.94 C +ATOM 188 C ALA A 26 3.794 -3.384 -2.879 1.00 0.94 C +ATOM 189 CB ALA A 26 2.232 -3.347 -4.837 1.00 0.94 C +ATOM 190 O ALA A 26 4.098 -2.336 -2.304 1.00 0.93 O +ATOM 191 N GLU A 27 3.520 -4.498 -2.247 1.00 0.93 N +ATOM 192 CA GLU A 27 3.572 -4.602 -0.792 1.00 0.93 C +ATOM 193 C GLU A 27 4.951 -4.224 -0.260 1.00 0.93 C +ATOM 194 CB GLU A 27 3.205 -6.019 -0.341 1.00 0.92 C +ATOM 195 O GLU A 27 5.063 -3.537 0.757 1.00 0.92 O +ATOM 196 CG GLU A 27 3.096 -6.175 1.169 1.00 0.86 C +ATOM 197 CD GLU A 27 2.688 -7.574 1.600 1.00 0.83 C +ATOM 198 OE1 GLU A 27 2.557 -8.464 0.728 1.00 0.82 O +ATOM 199 OE2 GLU A 27 2.497 -7.783 2.819 1.00 
0.81 O +ATOM 200 N LYS A 28 5.995 -4.662 -0.916 1.00 0.93 N +ATOM 201 CA LYS A 28 7.356 -4.350 -0.490 1.00 0.93 C +ATOM 202 C LYS A 28 7.610 -2.846 -0.517 1.00 0.93 C +ATOM 203 CB LYS A 28 8.375 -5.069 -1.376 1.00 0.92 C +ATOM 204 O LYS A 28 8.191 -2.292 0.418 1.00 0.92 O +ATOM 205 CG LYS A 28 8.490 -6.561 -1.100 1.00 0.82 C +ATOM 206 CD LYS A 28 9.565 -7.208 -1.965 1.00 0.78 C +ATOM 207 CE LYS A 28 9.620 -8.715 -1.756 1.00 0.72 C +ATOM 208 NZ LYS A 28 10.610 -9.365 -2.666 1.00 0.65 N +ATOM 209 N VAL A 29 7.210 -2.206 -1.579 1.00 0.93 N +ATOM 210 CA VAL A 29 7.357 -0.763 -1.733 1.00 0.93 C +ATOM 211 C VAL A 29 6.568 -0.043 -0.642 1.00 0.93 C +ATOM 212 CB VAL A 29 6.890 -0.290 -3.128 1.00 0.92 C +ATOM 213 O VAL A 29 7.082 0.874 0.002 1.00 0.92 O +ATOM 214 CG1 VAL A 29 6.849 1.235 -3.196 1.00 0.82 C +ATOM 215 CG2 VAL A 29 7.805 -0.852 -4.215 1.00 0.81 C +ATOM 216 N PHE A 30 5.406 -0.484 -0.415 1.00 0.93 N +ATOM 217 CA PHE A 30 4.532 0.160 0.558 1.00 0.93 C +ATOM 218 C PHE A 30 5.046 -0.060 1.976 1.00 0.92 C +ATOM 219 CB PHE A 30 3.100 -0.370 0.431 1.00 0.93 C +ATOM 220 O PHE A 30 4.999 0.849 2.807 1.00 0.90 O +ATOM 221 CG PHE A 30 2.382 0.111 -0.801 1.00 0.92 C +ATOM 222 CD1 PHE A 30 2.598 1.393 -1.292 1.00 0.90 C +ATOM 223 CD2 PHE A 30 1.491 -0.719 -1.469 1.00 0.90 C +ATOM 224 CE1 PHE A 30 1.935 1.841 -2.432 1.00 0.90 C +ATOM 225 CE2 PHE A 30 0.825 -0.278 -2.608 1.00 0.90 C +ATOM 226 CZ PHE A 30 1.047 1.003 -3.088 1.00 0.90 C +ATOM 227 N LYS A 31 5.516 -1.276 2.296 1.00 0.92 N +ATOM 228 CA LYS A 31 6.086 -1.556 3.610 1.00 0.91 C +ATOM 229 C LYS A 31 7.314 -0.689 3.874 1.00 0.91 C +ATOM 230 CB LYS A 31 6.454 -3.036 3.732 1.00 0.90 C +ATOM 231 O LYS A 31 7.503 -0.194 4.987 1.00 0.90 O +ATOM 232 CG LYS A 31 5.265 -3.950 3.987 1.00 0.84 C +ATOM 233 CD LYS A 31 5.708 -5.383 4.251 1.00 0.82 C +ATOM 234 CE LYS A 31 4.527 -6.281 4.593 1.00 0.76 C +ATOM 235 NZ LYS A 31 4.948 -7.700 4.790 1.00 0.71 N +ATOM 236 N GLN A 32 8.070 -0.504 2.843 1.00 0.91 N +ATOM 237 CA GLN A 32 9.212 0.392 2.996 1.00 0.90 C +ATOM 238 C GLN A 32 8.756 1.823 3.263 1.00 0.90 C +ATOM 239 CB GLN A 32 10.101 0.347 1.751 1.00 0.89 C +ATOM 240 O GLN A 32 9.303 2.502 4.134 1.00 0.89 O +ATOM 241 CG GLN A 32 11.361 1.194 1.865 1.00 0.81 C +ATOM 242 CD GLN A 32 12.309 0.695 2.939 1.00 0.77 C +ATOM 243 NE2 GLN A 32 12.868 1.619 3.713 1.00 0.69 N +ATOM 244 OE1 GLN A 32 12.536 -0.512 3.073 1.00 0.75 O +ATOM 245 N TYR A 33 7.771 2.254 2.531 1.00 0.88 N +ATOM 246 CA TYR A 33 7.212 3.588 2.714 1.00 0.88 C +ATOM 247 C TYR A 33 6.677 3.767 4.129 1.00 0.87 C +ATOM 248 CB TYR A 33 6.096 3.848 1.697 1.00 0.87 C +ATOM 249 O TYR A 33 6.951 4.778 4.781 1.00 0.85 O +ATOM 250 CG TYR A 33 5.467 5.214 1.823 1.00 0.82 C +ATOM 251 CD1 TYR A 33 6.053 6.330 1.230 1.00 0.78 C +ATOM 252 CD2 TYR A 33 4.286 5.392 2.535 1.00 0.79 C +ATOM 253 CE1 TYR A 33 5.478 7.591 1.344 1.00 0.79 C +ATOM 254 CE2 TYR A 33 3.701 6.649 2.656 1.00 0.80 C +ATOM 255 OH TYR A 33 3.729 8.986 2.174 1.00 0.70 O +ATOM 256 CZ TYR A 33 4.303 7.740 2.058 1.00 0.77 C +ATOM 257 N ALA A 34 5.922 2.819 4.647 1.00 0.87 N +ATOM 258 CA ALA A 34 5.360 2.871 5.994 1.00 0.86 C +ATOM 259 C ALA A 34 6.463 2.887 7.049 1.00 0.86 C +ATOM 260 CB ALA A 34 4.423 1.688 6.226 1.00 0.85 C +ATOM 261 O ALA A 34 6.407 3.666 8.003 1.00 0.84 O +ATOM 262 N ASN A 35 7.535 2.038 6.821 1.00 0.85 N +ATOM 263 CA ASN A 35 8.671 1.986 7.735 1.00 0.85 C +ATOM 264 C ASN A 35 9.414 3.318 7.782 1.00 0.85 C +ATOM 265 CB ASN A 35 9.627 0.859 7.341 1.00 0.83 C +ATOM 266 O ASN A 35 9.780 3.793 8.859 1.00 0.84 O +ATOM 267 CG ASN A 
35 9.151 -0.502 7.811 1.00 0.77 C +ATOM 268 ND2 ASN A 35 9.716 -1.558 7.238 1.00 0.72 N +ATOM 269 OD1 ASN A 35 8.283 -0.602 8.682 1.00 0.72 O +ATOM 270 N ASP A 36 9.525 3.931 6.681 1.00 0.85 N +ATOM 271 CA ASP A 36 10.236 5.201 6.571 1.00 0.85 C +ATOM 272 C ASP A 36 9.455 6.328 7.244 1.00 0.84 C +ATOM 273 CB ASP A 36 10.496 5.545 5.103 1.00 0.84 C +ATOM 274 O ASP A 36 10.033 7.347 7.628 1.00 0.82 O +ATOM 275 CG ASP A 36 11.548 4.656 4.462 1.00 0.81 C +ATOM 276 OD1 ASP A 36 12.252 3.923 5.190 1.00 0.80 O +ATOM 277 OD2 ASP A 36 11.676 4.691 3.220 1.00 0.81 O +ATOM 278 N ASN A 37 8.203 6.079 7.495 1.00 0.84 N +ATOM 279 CA ASN A 37 7.356 7.122 8.063 1.00 0.83 C +ATOM 280 C ASN A 37 6.868 6.749 9.460 1.00 0.82 C +ATOM 281 CB ASN A 37 6.166 7.407 7.144 1.00 0.82 C +ATOM 282 O ASN A 37 5.944 7.372 9.986 1.00 0.79 O +ATOM 283 CG ASN A 37 6.556 8.201 5.913 1.00 0.78 C +ATOM 284 ND2 ASN A 37 6.669 7.521 4.779 1.00 0.75 N +ATOM 285 OD1 ASN A 37 6.755 9.417 5.981 1.00 0.75 O +ATOM 286 N GLY A 38 7.453 5.769 9.962 1.00 0.81 N +ATOM 287 CA GLY A 38 7.211 5.428 11.355 1.00 0.81 C +ATOM 288 C GLY A 38 5.825 4.862 11.599 1.00 0.80 C +ATOM 289 O GLY A 38 5.293 4.965 12.706 1.00 0.78 O +ATOM 290 N VAL A 39 5.211 4.356 10.565 1.00 0.81 N +ATOM 291 CA VAL A 39 3.875 3.783 10.690 1.00 0.81 C +ATOM 292 C VAL A 39 3.975 2.266 10.833 1.00 0.80 C +ATOM 293 CB VAL A 39 2.987 4.146 9.479 1.00 0.79 C +ATOM 294 O VAL A 39 4.523 1.588 9.961 1.00 0.78 O +ATOM 295 CG1 VAL A 39 1.615 3.483 9.597 1.00 0.74 C +ATOM 296 CG2 VAL A 39 2.842 5.662 9.359 1.00 0.74 C +ATOM 297 N ASP A 40 3.551 1.718 11.960 1.00 0.80 N +ATOM 298 CA ASP A 40 3.455 0.293 12.266 1.00 0.80 C +ATOM 299 C ASP A 40 2.020 -0.099 12.610 1.00 0.80 C +ATOM 300 CB ASP A 40 4.392 -0.072 13.419 1.00 0.77 C +ATOM 301 O ASP A 40 1.327 0.627 13.326 1.00 0.77 O +ATOM 302 CG ASP A 40 4.666 -1.563 13.510 1.00 0.70 C +ATOM 303 OD1 ASP A 40 4.161 -2.329 12.661 1.00 0.66 O +ATOM 304 OD2 ASP A 40 5.391 -1.976 14.441 1.00 0.69 O +ATOM 305 N GLY A 41 1.590 -1.138 12.053 1.00 0.84 N +ATOM 306 CA GLY A 41 0.215 -1.550 12.287 1.00 0.84 C +ATOM 307 C GLY A 41 -0.170 -2.803 11.523 1.00 0.84 C +ATOM 308 O GLY A 41 0.690 -3.480 10.957 1.00 0.82 O +ATOM 309 N GLU A 42 -1.503 -3.143 11.558 1.00 0.88 N +ATOM 310 CA GLU A 42 -2.049 -4.296 10.848 1.00 0.89 C +ATOM 311 C GLU A 42 -2.256 -3.985 9.368 1.00 0.89 C +ATOM 312 CB GLU A 42 -3.369 -4.743 11.481 1.00 0.86 C +ATOM 313 O GLU A 42 -2.842 -2.958 9.020 1.00 0.89 O +ATOM 314 CG GLU A 42 -3.861 -6.096 10.988 1.00 0.76 C +ATOM 315 CD GLU A 42 -5.196 -6.505 11.590 1.00 0.72 C +ATOM 316 OE1 GLU A 42 -5.685 -5.810 12.510 1.00 0.69 O +ATOM 317 OE2 GLU A 42 -5.757 -7.528 11.138 1.00 0.66 O +ATOM 318 N TRP A 43 -1.835 -4.906 8.549 1.00 0.89 N +ATOM 319 CA TRP A 43 -1.899 -4.797 7.095 1.00 0.90 C +ATOM 320 C TRP A 43 -3.033 -5.649 6.534 1.00 0.89 C +ATOM 321 CB TRP A 43 -0.568 -5.216 6.465 1.00 0.86 C +ATOM 322 O TRP A 43 -3.208 -6.802 6.935 1.00 0.86 O +ATOM 323 CG TRP A 43 0.575 -4.304 6.793 1.00 0.72 C +ATOM 324 CD1 TRP A 43 1.266 -4.238 7.971 1.00 0.63 C +ATOM 325 CD2 TRP A 43 1.159 -3.322 5.931 1.00 0.65 C +ATOM 326 CE2 TRP A 43 2.200 -2.698 6.653 1.00 0.58 C +ATOM 327 CE3 TRP A 43 0.902 -2.910 4.616 1.00 0.64 C +ATOM 328 NE1 TRP A 43 2.244 -3.274 7.893 1.00 0.72 N +ATOM 329 CH2 TRP A 43 2.713 -1.297 4.815 1.00 0.64 C +ATOM 330 CZ2 TRP A 43 2.985 -1.682 6.103 1.00 0.71 C +ATOM 331 CZ3 TRP A 43 1.685 -1.898 4.071 1.00 0.62 C +ATOM 332 N THR A 44 -3.816 -5.080 5.734 1.00 0.92 N +ATOM 333 CA THR A 44 -4.786 -5.846 4.958 1.00 0.92 C +ATOM 334 
C THR A 44 -4.663 -5.525 3.471 1.00 0.93 C +ATOM 335 CB THR A 44 -6.225 -5.565 5.429 1.00 0.91 C +ATOM 336 O THR A 44 -4.140 -4.472 3.099 1.00 0.92 O +ATOM 337 CG2 THR A 44 -6.383 -5.856 6.918 1.00 0.83 C +ATOM 338 OG1 THR A 44 -6.541 -4.189 5.187 1.00 0.85 O +ATOM 339 N TYR A 45 -5.077 -6.452 2.697 1.00 0.92 N +ATOM 340 CA TYR A 45 -5.058 -6.286 1.248 1.00 0.93 C +ATOM 341 C TYR A 45 -6.389 -6.700 0.634 1.00 0.93 C +ATOM 342 CB TYR A 45 -3.920 -7.103 0.628 1.00 0.92 C +ATOM 343 O TYR A 45 -6.905 -7.782 0.924 1.00 0.92 O +ATOM 344 CG TYR A 45 -3.921 -7.097 -0.881 1.00 0.89 C +ATOM 345 CD1 TYR A 45 -3.816 -5.904 -1.593 1.00 0.84 C +ATOM 346 CD2 TYR A 45 -4.029 -8.283 -1.599 1.00 0.85 C +ATOM 347 CE1 TYR A 45 -3.819 -5.893 -2.984 1.00 0.87 C +ATOM 348 CE2 TYR A 45 -4.032 -8.285 -2.990 1.00 0.88 C +ATOM 349 OH TYR A 45 -3.930 -7.082 -5.049 1.00 0.77 O +ATOM 350 CZ TYR A 45 -3.927 -7.087 -3.672 1.00 0.87 C +ATOM 351 N ASP A 46 -6.999 -5.845 -0.037 1.00 0.93 N +ATOM 352 CA ASP A 46 -8.204 -6.119 -0.814 1.00 0.94 C +ATOM 353 C ASP A 46 -7.873 -6.315 -2.292 1.00 0.94 C +ATOM 354 CB ASP A 46 -9.219 -4.986 -0.648 1.00 0.93 C +ATOM 355 O ASP A 46 -7.511 -5.361 -2.984 1.00 0.94 O +ATOM 356 CG ASP A 46 -10.537 -5.263 -1.349 1.00 0.88 C +ATOM 357 OD1 ASP A 46 -10.578 -6.124 -2.254 1.00 0.85 O +ATOM 358 OD2 ASP A 46 -11.544 -4.612 -0.994 1.00 0.86 O +ATOM 359 N ASP A 47 -7.939 -7.452 -2.769 1.00 0.94 N +ATOM 360 CA ASP A 47 -7.551 -7.805 -4.132 1.00 0.94 C +ATOM 361 C ASP A 47 -8.500 -7.181 -5.152 1.00 0.94 C +ATOM 362 CB ASP A 47 -7.519 -9.325 -4.304 1.00 0.92 C +ATOM 363 O ASP A 47 -8.091 -6.848 -6.266 1.00 0.92 O +ATOM 364 CG ASP A 47 -6.905 -9.760 -5.623 1.00 0.84 C +ATOM 365 OD1 ASP A 47 -5.728 -9.431 -5.886 1.00 0.79 O +ATOM 366 OD2 ASP A 47 -7.605 -10.436 -6.408 1.00 0.82 O +ATOM 367 N ALA A 48 -9.764 -6.965 -4.788 1.00 0.94 N +ATOM 368 CA ALA A 48 -10.747 -6.413 -5.716 1.00 0.94 C +ATOM 369 C ALA A 48 -10.404 -4.972 -6.085 1.00 0.93 C +ATOM 370 CB ALA A 48 -12.148 -6.486 -5.113 1.00 0.92 C +ATOM 371 O ALA A 48 -10.521 -4.578 -7.248 1.00 0.92 O +ATOM 372 N THR A 49 -9.903 -4.238 -5.141 1.00 0.94 N +ATOM 373 CA THR A 49 -9.632 -2.822 -5.366 1.00 0.94 C +ATOM 374 C THR A 49 -8.129 -2.565 -5.440 1.00 0.93 C +ATOM 375 CB THR A 49 -10.250 -1.951 -4.257 1.00 0.93 C +ATOM 376 O THR A 49 -7.697 -1.418 -5.574 1.00 0.92 O +ATOM 377 CG2 THR A 49 -11.767 -2.109 -4.216 1.00 0.89 C +ATOM 378 OG1 THR A 49 -9.706 -2.342 -2.991 1.00 0.90 O +ATOM 379 N LYS A 50 -7.338 -3.598 -5.299 1.00 0.95 N +ATOM 380 CA LYS A 50 -5.881 -3.502 -5.296 1.00 0.95 C +ATOM 381 C LYS A 50 -5.397 -2.523 -4.230 1.00 0.95 C +ATOM 382 CB LYS A 50 -5.369 -3.074 -6.672 1.00 0.94 C +ATOM 383 O LYS A 50 -4.558 -1.663 -4.506 1.00 0.93 O +ATOM 384 CG LYS A 50 -5.881 -3.934 -7.818 1.00 0.89 C +ATOM 385 CD LYS A 50 -5.425 -5.381 -7.678 1.00 0.83 C +ATOM 386 CE LYS A 50 -5.814 -6.211 -8.894 1.00 0.81 C +ATOM 387 NZ LYS A 50 -5.729 -7.676 -8.616 1.00 0.72 N +ATOM 388 N THR A 51 -5.966 -2.637 -3.033 1.00 0.94 N +ATOM 389 CA THR A 51 -5.741 -1.638 -1.995 1.00 0.94 C +ATOM 390 C THR A 51 -5.179 -2.286 -0.733 1.00 0.94 C +ATOM 391 CB THR A 51 -7.040 -0.885 -1.653 1.00 0.94 C +ATOM 392 O THR A 51 -5.746 -3.254 -0.220 1.00 0.94 O +ATOM 393 CG2 THR A 51 -6.790 0.199 -0.610 1.00 0.89 C +ATOM 394 OG1 THR A 51 -7.560 -0.277 -2.842 1.00 0.88 O +ATOM 395 N PHE A 52 -4.090 -1.754 -0.338 1.00 0.93 N +ATOM 396 CA PHE A 52 -3.523 -2.066 0.968 1.00 0.94 C +ATOM 397 C PHE A 52 -3.952 -1.034 2.005 1.00 0.92 C +ATOM 398 CB PHE A 52 -1.995 -2.129 0.891 1.00 0.93 C 
+ATOM 399 O PHE A 52 -4.021 0.160 1.707 1.00 0.91 O +ATOM 400 CG PHE A 52 -1.476 -3.252 0.035 1.00 0.90 C +ATOM 401 CD1 PHE A 52 -1.231 -4.506 0.582 1.00 0.87 C +ATOM 402 CD2 PHE A 52 -1.233 -3.054 -1.318 1.00 0.87 C +ATOM 403 CE1 PHE A 52 -0.750 -5.547 -0.209 1.00 0.89 C +ATOM 404 CE2 PHE A 52 -0.753 -4.090 -2.114 1.00 0.89 C +ATOM 405 CZ PHE A 52 -0.511 -5.335 -1.557 1.00 0.88 C +ATOM 406 N THR A 53 -4.236 -1.518 3.175 1.00 0.92 N +ATOM 407 CA THR A 53 -4.570 -0.626 4.279 1.00 0.92 C +ATOM 408 C THR A 53 -3.738 -0.960 5.514 1.00 0.91 C +ATOM 409 CB THR A 53 -6.067 -0.706 4.629 1.00 0.91 C +ATOM 410 O THR A 53 -3.596 -2.130 5.876 1.00 0.90 O +ATOM 411 CG2 THR A 53 -6.427 0.279 5.736 1.00 0.83 C +ATOM 412 OG1 THR A 53 -6.842 -0.400 3.463 1.00 0.85 O +ATOM 413 N VAL A 54 -3.170 0.019 6.075 1.00 0.89 N +ATOM 414 CA VAL A 54 -2.471 -0.138 7.346 1.00 0.89 C +ATOM 415 C VAL A 54 -3.177 0.674 8.429 1.00 0.88 C +ATOM 416 CB VAL A 54 -0.991 0.294 7.237 1.00 0.86 C +ATOM 417 O VAL A 54 -3.410 1.874 8.263 1.00 0.86 O +ATOM 418 CG1 VAL A 54 -0.276 0.115 8.575 1.00 0.71 C +ATOM 419 CG2 VAL A 54 -0.284 -0.499 6.139 1.00 0.71 C +ATOM 420 N THR A 55 -3.541 -0.021 9.385 1.00 0.86 N +ATOM 421 CA THR A 55 -4.177 0.597 10.543 1.00 0.85 C +ATOM 422 C THR A 55 -3.320 0.414 11.793 1.00 0.84 C +ATOM 423 CB THR A 55 -5.580 0.012 10.789 1.00 0.83 C +ATOM 424 O THR A 55 -2.892 -0.701 12.100 1.00 0.82 O +ATOM 425 CG2 THR A 55 -6.298 0.759 11.909 1.00 0.71 C +ATOM 426 OG1 THR A 55 -6.353 0.118 9.587 1.00 0.74 O +ATOM 427 N GLU A 56 -3.008 1.512 12.452 1.00 0.81 N +ATOM 428 CA GLU A 56 -2.285 1.438 13.718 1.00 0.80 C +ATOM 429 C GLU A 56 -3.188 0.933 14.840 1.00 0.79 C +ATOM 430 CB GLU A 56 -1.703 2.805 14.086 1.00 0.78 C +ATOM 431 O GLU A 56 -4.390 1.210 14.849 1.00 0.76 O +ATOM 432 CG GLU A 56 -0.439 3.161 13.318 1.00 0.72 C +ATOM 433 CD GLU A 56 0.196 4.464 13.778 1.00 0.70 C +ATOM 434 OE1 GLU A 56 -0.519 5.318 14.350 1.00 0.68 O +ATOM 435 OE2 GLU A 56 1.417 4.633 13.564 1.00 0.65 O +TER 436 GLU A 56 +END diff --git a/FAProGen2_benchmark.png b/FAProGen2_benchmark.png new file mode 100644 index 0000000..20b46b0 Binary files /dev/null and b/FAProGen2_benchmark.png differ diff --git a/README.md b/README.md index 68fb2e4..1a8f079 100644 --- a/README.md +++ b/README.md @@ -73,6 +73,38 @@ print("Repr shape:", outputs['last_hidden_state'].shape) # (batch_size, sequenc # Step 5: start the repo if the code works for u! ``` +### ProGen2 + +```python +import torch +from faesm.progen2 import ProGenForCausalLM +from transformers import AutoTokenizer +device = 'cuda' if torch.cuda.is_available() else 'cpu' +model = ProGenForCausalLM.from_pretrained("jinyuan22/ProGen2-small").to(torch.float16).to(device).eval() +tokenizer = AutoTokenizer.from_pretrained("jinyuan22/ProGen2-small") + +# sequence = "1" + "ACDEFGHIKLMNPQRSTVWY" * 50 + "2" # 1002 token + +sequence = "2GFLPFRGADEGLAAREAATLAARGTAARAYREDSWAVPVPRGLLGDLTARVAALGAASPPPADPLAVTLDLHHVTAEVALTTVLDAATLVHGQTRVLSAEDAAEAATAAAAATEAYLERLQDFVLFMSASVRVWRRGNAAGATGPEWDQWYTVADRDALGSAPTHLAVLGRQADALCHFVLDRVAWGTCGTPLWSGDEDLGNVVATFAGYADRLATAPRDLIM1" + +inputs = tokenizer(sequence, return_tensors="pt").to(device) + +with torch.no_grad(): + logits = model(inputs.input_ids, labels=inputs.input_ids).logits + +logits = logits[0][:-1, ...] 
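+# Causal LM shift: the logit at position i scores token i+1, so the final logit and the first target token are dropped before computing cross-entropy.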
+target = inputs.input_ids[0, 1:] + +# remove unused logits +first_token, last_token = 5, 29 +logits = logits[:, first_token:(last_token+1)] +target = target - first_token + +ce_eval = torch.nn.functional.cross_entropy(input=logits.view(-1, logits.size(-1)), target=target.view(-1), reduction="mean").item() +print(ce_eval) +assert abs(ce_eval - 2.4) < 0.1 +``` + ### Training \[WIP\] Working on an example training script for MLM training on Uniref50. For now, you can use the same training logic as how you would train the official ESM since the FAESM has no difference in the model architecture. diff --git a/data.ipynb b/data.ipynb new file mode 100644 index 0000000..317c078 --- /dev/null +++ b/data.ipynb @@ -0,0 +1,31 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "from faesm.esmfold import FAEsmForProteinFolding\n", + "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n", + "esmfold = FAEsmForProteinFolding.from_pretrained(\"facebook/esmfold_v1\")\n", + "esmfold.esm.half()\n", + "esmfold = esmfold.to(device).eval()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "torch", + "language": "python", + "name": "python3" + }, + "language_info": { + "name": "python", + "version": "3.11.10" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/esmfold_benchmark.png b/esmfold_benchmark.png new file mode 100644 index 0000000..e690827 Binary files /dev/null and b/esmfold_benchmark.png differ diff --git a/faesm/esm.py b/faesm/esm.py index 2db4b50..32c2ad5 100644 --- a/faesm/esm.py +++ b/faesm/esm.py @@ -589,14 +589,19 @@ def forward( attention_mask=attention_mask, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, + output_hidden_states=output_hidden_states, # For the hidden states ) sequence_output = outputs[0] logits = self.lm_head(sequence_output) - result = { - "logits": logits, - "last_hidden_state": sequence_output, - } + if outputs.hidden_states is not None: + result = { + "logits": logits, + "last_hidden_state": sequence_output, + "hidden_states": [x.unsqueeze(0) for x in outputs.hidden_states], + } + else: + result = {"logits": logits, "last_hidden_state": sequence_output} return result @classmethod diff --git a/faesm/esmfold.py b/faesm/esmfold.py new file mode 100644 index 0000000..57fee35 --- /dev/null +++ b/faesm/esmfold.py @@ -0,0 +1,2421 @@ +import math +import sys +from dataclasses import dataclass +from functools import partial +from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union + +import numpy as np +import torch +import torch.nn as nn +from torch.nn import LayerNorm + +from transformers.models.esm.modeling_esm import ESM_START_DOCSTRING, EsmModel, EsmPreTrainedModel +from faesm.esm import FAEsmModel +from transformers.modeling_outputs import ModelOutput +from transformers.integrations.deepspeed import is_deepspeed_available +from transformers.utils import ( + ContextManagers, + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_scipy_available, + logging, + replace_return_docstrings, +) +from transformers.models.esm.configuration_esm import EsmConfig + +from transformers.models.esm.openfold_utils import ( + OFProtein, + Rigid, + Rotation, + atom14_to_atom37, + chunk_layer, + compute_predicted_aligned_error, + compute_tm, + frames_and_literature_positions_to_atom14_pos, + make_atom14_masks, + residue_constants, + to_pdb, + torsion_angles_to_frames, +) + +logger = 
logging.get_logger(__name__) +_CHECKPOINT_FOR_DOC = "facebook/esmfold_v1" +_CONFIG_FOR_DOC = "EsmConfig" + +@add_start_docstrings( + """ + ESMForProteinFolding is the HuggingFace port of the original ESMFold model. It consists of an ESM-2 "stem" followed + by a protein folding "head", although unlike most other output heads, this "head" is similar in size and runtime to + the rest of the model combined! It outputs a dictionary containing predicted structural information about the input + protein(s). + """, + ESM_START_DOCSTRING, +) + +@dataclass +class EsmForProteinFoldingOutput(ModelOutput): + """ + Output type of [`EsmForProteinFoldingOutput`]. + + Args: + frames (`torch.FloatTensor`): + Output frames. + sidechain_frames (`torch.FloatTensor`): + Output sidechain frames. + unnormalized_angles (`torch.FloatTensor`): + Predicted unnormalized backbone and side chain torsion angles. + angles (`torch.FloatTensor`): + Predicted backbone and side chain torsion angles. + positions (`torch.FloatTensor`): + Predicted positions of the backbone and side chain atoms. + states (`torch.FloatTensor`): + Hidden states from the protein folding trunk. + s_s (`torch.FloatTensor`): + Per-residue embeddings derived by concatenating the hidden states of each layer of the ESM-2 LM stem. + s_z (`torch.FloatTensor`): + Pairwise residue embeddings. + distogram_logits (`torch.FloatTensor`): + Input logits to the distogram used to compute residue distances. + lm_logits (`torch.FloatTensor`): + Logits output by the ESM-2 protein language model stem. + aatype (`torch.FloatTensor`): + Input amino acids (AlphaFold2 indices). + atom14_atom_exists (`torch.FloatTensor`): + Whether each atom exists in the atom14 representation. + residx_atom14_to_atom37 (`torch.FloatTensor`): + Mapping between atoms in the atom14 and atom37 representations. + residx_atom37_to_atom14 (`torch.FloatTensor`): + Mapping between atoms in the atom37 and atom14 representations. + atom37_atom_exists (`torch.FloatTensor`): + Whether each atom exists in the atom37 representation. + residue_index (`torch.FloatTensor`): + The index of each residue in the protein chain. Unless internal padding tokens are used, this will just be + a sequence of integers from 0 to `sequence_length`. + lddt_head (`torch.FloatTensor`): + Raw outputs from the lddt head used to compute plddt. + plddt (`torch.FloatTensor`): + Per-residue confidence scores. Regions of low confidence may indicate areas where the model's prediction is + uncertain, or where the protein structure is disordered. + ptm_logits (`torch.FloatTensor`): + Raw logits used for computing ptm. + ptm (`torch.FloatTensor`): + TM-score output representing the model's high-level confidence in the overall structure. + aligned_confidence_probs (`torch.FloatTensor`): + Per-residue confidence scores for the aligned structure. + predicted_aligned_error (`torch.FloatTensor`): + Predicted error between the model's prediction and the ground truth. + max_predicted_aligned_error (`torch.FloatTensor`): + Per-sample maximum predicted error. 
+ """ + + frames: torch.FloatTensor = None + sidechain_frames: torch.FloatTensor = None + unnormalized_angles: torch.FloatTensor = None + angles: torch.FloatTensor = None + positions: torch.FloatTensor = None + states: torch.FloatTensor = None + s_s: torch.FloatTensor = None + s_z: torch.FloatTensor = None + distogram_logits: torch.FloatTensor = None + lm_logits: torch.FloatTensor = None + aatype: torch.FloatTensor = None + atom14_atom_exists: torch.FloatTensor = None + residx_atom14_to_atom37: torch.FloatTensor = None + residx_atom37_to_atom14: torch.FloatTensor = None + atom37_atom_exists: torch.FloatTensor = None + residue_index: torch.FloatTensor = None + lddt_head: torch.FloatTensor = None + plddt: torch.FloatTensor = None + ptm_logits: torch.FloatTensor = None + ptm: torch.FloatTensor = None + aligned_confidence_probs: torch.FloatTensor = None + predicted_aligned_error: torch.FloatTensor = None + max_predicted_aligned_error: torch.FloatTensor = None + + +ESMFOLD_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `({0})`): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + position_ids (`torch.LongTensor` of shape `({0})`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. + + [What are position IDs?](../glossary#position-ids) + masking_pattern (`torch.LongTensor` of shape `({0})`, *optional*): + Locations of tokens to mask during training as a form of regularization. Mask values selected in `[0, 1]`. + num_recycles (`int`, *optional*, defaults to `None`): + Number of times to recycle the input sequence. If `None`, defaults to `config.num_recycles`. "Recycling" + consists of passing the output of the folding trunk back in as input to the trunk. During training, the + number of recycles should vary with each batch, to ensure that the model learns to output valid predictions + after each recycle. During inference, num_recycles should be set to the highest value that the model was + trained with for maximum accuracy. Accordingly, when this value is set to `None`, config.max_recycles is + used. +""" + + +def is_fp16_enabled(): + # Autocast world + fp16_enabled = torch.get_autocast_gpu_dtype() == torch.float16 + fp16_enabled = fp16_enabled and torch.is_autocast_enabled() + + return fp16_enabled + + +def is_deepspeed_initialized(): + if is_deepspeed_available(): + return False + else: + try: + import deepspeed + + # This is not available in all DeepSpeed versions. 
+ return deepspeed.utils.is_initialized() + except Exception: + return False + + +def collate_dense_tensors(samples: List[torch.Tensor], pad_v: float = 0) -> torch.Tensor: + """ + Takes a list of tensors with the following dimensions: + [(d_11, ..., d_1K), + (d_21, ..., d_2K), ..., (d_N1, ..., d_NK)] + and stack + pads them into a single tensor of: + (N, max_i=1,N { d_i1 }, ..., max_i=1,N {diK}) + """ + if len(samples) == 0: + return torch.Tensor() + if len({x.dim() for x in samples}) != 1: + raise RuntimeError(f"Samples has varying dimensions: {[x.dim() for x in samples]}") + (device,) = tuple({x.device for x in samples}) # assumes all on same device + max_shape = [max(lst) for lst in zip(*[x.shape for x in samples])] + result = torch.empty(len(samples), *max_shape, dtype=samples[0].dtype, device=device) + result.fill_(pad_v) + for i in range(len(samples)): + result_i = result[i] + t = samples[i] + result_i[tuple(slice(0, k) for k in t.shape)] = t + return result + + +def flatten_final_dims(t: torch.Tensor, no_dims: int): + return t.reshape(t.shape[:-no_dims] + (-1,)) + + +def permute_final_dims(tensor: torch.Tensor, inds: List[int]): + zero_index = -1 * len(inds) + first_inds = list(range(len(tensor.shape[:zero_index]))) + return tensor.permute(first_inds + [zero_index + i for i in inds]) + + +def dict_multimap(fn, dicts): + first = dicts[0] + new_dict = {} + for k, v in first.items(): + all_v = [d[k] for d in dicts] + if isinstance(v, dict): + new_dict[k] = dict_multimap(fn, all_v) + else: + new_dict[k] = fn(all_v) + + return new_dict + + +def trunc_normal_init_(weights, scale=1.0, fan="fan_in"): + shape = weights.shape + scale = scale / max(1, shape[1]) + + if not is_scipy_available(): + logger.warning( + "This init requires scipy, but scipy was not found, default to an approximation that might not be" + " equivalent." + ) + std = math.sqrt(scale) + torch.nn.init.normal_(weights, std=std).clamp(min=0.0, max=2.0 * std) + + else: + from scipy.stats import truncnorm + + std = math.sqrt(scale) / truncnorm.std(a=-2, b=2, loc=0, scale=1) + samples = truncnorm.rvs(a=-2, b=2, loc=0, scale=std, size=weights.numel()) + samples = np.reshape(samples, shape) + weights.copy_(torch.tensor(samples, device=weights.device)) + + +def ipa_point_weights_init_(weights): + with torch.no_grad(): + softplus_inverse_1 = 0.541324854612918 + weights.fill_(softplus_inverse_1) + + +class EsmFoldLinear(nn.Linear): + """ + A Linear layer with built-in nonstandard initializations. Called just like torch.nn.Linear. + + Implements the initializers in 1.11.4, plus some additional ones found in the code. + """ + + def __init__( + self, + in_dim: int, + out_dim: int, + bias: bool = True, + init: str = "default", + init_fn: Optional[Callable[[torch.Tensor, torch.Tensor], None]] = None, + ): + """ + Args: + in_dim: + The final dimension of inputs to the layer + out_dim: + The final dimension of layer outputs + bias: + Whether to learn an additive bias. True by default + init: + The initializer to use. Choose from: + + "default": LeCun fan-in truncated normal initialization "relu": He initialization w/ truncated normal + distribution "glorot": Fan-average Glorot uniform initialization "gating": Weights=0, Bias=1 "normal": + Normal initialization with std=1/sqrt(fan_in) "final": Weights=0, Bias=0 + + Overridden by init_fn if the latter is not None. + init_fn: + A custom initializer taking weight and bias as inputs. Overrides init if not None. 
+ """ + super().__init__(in_dim, out_dim, bias=bias) + + if bias: + with torch.no_grad(): + self.bias.fill_(0) + self.init = init + self.init_fn = init_fn + + if init not in ["default", "relu", "glorot", "gating", "normal", "final"]: + raise ValueError("Invalid init string.") + + +class EsmFoldLayerNorm(nn.Module): + def __init__(self, c_in, eps=1e-5): + super().__init__() + + self.c_in = (c_in,) + self.eps = eps + + self.weight = nn.Parameter(torch.ones(c_in)) + self.bias = nn.Parameter(torch.zeros(c_in)) + + def forward(self, x): + d = x.dtype + if d is torch.bfloat16 and not is_deepspeed_initialized(): + with torch.cuda.amp.autocast(enabled=False): + out = nn.functional.layer_norm(x, self.c_in, self.weight.to(dtype=d), self.bias.to(dtype=d), self.eps) + else: + out = nn.functional.layer_norm(x, self.c_in, self.weight, self.bias, self.eps) + + return out + + +@torch.jit.ignore +def softmax_no_cast(t: torch.Tensor, dim: int = -1) -> torch.Tensor: + """ + Softmax, but without automatic casting to fp32 when the input is of type bfloat16 + """ + d = t.dtype + if d is torch.bfloat16 and not is_deepspeed_initialized(): + with torch.cuda.amp.autocast(enabled=False): + s = torch.nn.functional.softmax(t, dim=dim) + else: + s = torch.nn.functional.softmax(t, dim=dim) + + return s + + +class EsmFoldAttention(nn.Module): + """ + Standard multi-head attention using AlphaFold's default layer initialization. Allows multiple bias vectors. + """ + + def __init__( + self, + c_q: int, + c_k: int, + c_v: int, + c_hidden: int, + no_heads: int, + gating: bool = True, + ): + """ + Args: + c_q: + Input dimension of query data + c_k: + Input dimension of key data + c_v: + Input dimension of value data + c_hidden: + Per-head hidden dimension + no_heads: + Number of attention heads + gating: + Whether the output should be gated using query data + """ + super().__init__() + + self.c_q = c_q + self.c_k = c_k + self.c_v = c_v + self.c_hidden = c_hidden + self.no_heads = no_heads + self.gating = gating + + # DISCREPANCY: c_hidden is not the per-head channel dimension, as + # stated in the supplement, but the overall channel dimension. 
+ + self.linear_q = EsmFoldLinear(self.c_q, self.c_hidden * self.no_heads, bias=False, init="glorot") + self.linear_k = EsmFoldLinear(self.c_k, self.c_hidden * self.no_heads, bias=False, init="glorot") + self.linear_v = EsmFoldLinear(self.c_v, self.c_hidden * self.no_heads, bias=False, init="glorot") + self.linear_o = EsmFoldLinear(self.c_hidden * self.no_heads, self.c_q, init="final") + + self.linear_g = None + if self.gating: + self.linear_g = EsmFoldLinear(self.c_q, self.c_hidden * self.no_heads, init="gating") + + self.sigmoid = nn.Sigmoid() + + def _prep_qkv(self, q_x: torch.Tensor, kv_x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + # [*, Q/K/V, H * C_hidden] + q = self.linear_q(q_x) + k = self.linear_k(kv_x) + v = self.linear_v(kv_x) + + # [*, Q/K, H, C_hidden] + q = q.view(q.shape[:-1] + (self.no_heads, -1)) + k = k.view(k.shape[:-1] + (self.no_heads, -1)) + v = v.view(v.shape[:-1] + (self.no_heads, -1)) + + # [*, H, Q/K, C_hidden] + q = q.transpose(-2, -3) + k = k.transpose(-2, -3) + v = v.transpose(-2, -3) + + q /= math.sqrt(self.c_hidden) + + return q, k, v + + def _wrap_up(self, o: torch.Tensor, q_x: torch.Tensor) -> torch.Tensor: + if self.linear_g is not None: + g = self.sigmoid(self.linear_g(q_x)) + + # [*, Q, H, C_hidden] + g = g.view(g.shape[:-1] + (self.no_heads, -1)) + o = o * g + + # [*, Q, H * C_hidden] + o = flatten_final_dims(o, 2) + + # [*, Q, C_q] + o = self.linear_o(o) + + return o + + def forward( + self, + q_x: torch.Tensor, + kv_x: torch.Tensor, + biases: Optional[List[torch.Tensor]] = None, + use_memory_efficient_kernel: bool = False, + use_lma: bool = False, + lma_q_chunk_size: int = 1024, + lma_kv_chunk_size: int = 4096, + use_flash: bool = False, + flash_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """ + Args: + q_x: + [*, Q, C_q] query data + kv_x: + [*, K, C_k] key data + biases: + List of biases that broadcast to [*, H, Q, K] + use_memory_efficient_kernel: + Whether to use a custom memory-efficient attention kernel. This should be the default choice for most. + If none of the "use_<...>" flags are True, a stock PyTorch implementation is used instead + use_lma: + Whether to use low-memory attention (Staats & Rabe 2021). If none of the "use_<...>" flags are True, a + stock PyTorch implementation is used instead + lma_q_chunk_size: + Query chunk size (for LMA) + lma_kv_chunk_size: + Key/Value chunk size (for LMA) + Returns + [*, Q, C_q] attention update + """ + if use_lma and (lma_q_chunk_size is None or lma_kv_chunk_size is None): + raise ValueError("If use_lma is specified, lma_q_chunk_size and lma_kv_chunk_size must be provided") + + if use_flash and biases is not None: + raise ValueError("use_flash is incompatible with the bias option. 
For masking, use flash_mask instead") + + attn_options = [use_memory_efficient_kernel, use_lma, use_flash] + if sum(attn_options) > 1: + raise ValueError("Choose at most one alternative attention algorithm") + + if biases is None: + biases = [] + + # [*, H, Q/K, C_hidden] + query, key, value = self._prep_qkv(q_x, kv_x) + key = permute_final_dims(key, (1, 0)) + + # [*, H, Q, K] + output = torch.matmul(query, key) # q / sqrt(d) * k^T + for b in biases: + output += b + output = softmax_no_cast(output, -1) # softmax(q / sqrt(d) * k^T) + + # [*, H, Q, C_hidden] + output = torch.matmul(output, value) # softmax(q / sqrt(d) * k^T) * v + output = output.transpose(-2, -3) + output = self._wrap_up(output, q_x) + # breakpoint() + return output + +# import math +# from typing import List, Optional, Tuple +# import torch +# import torch.nn as nn +from flash_attn import flash_attn_qkvpacked_func +# from torch.utils.checkpoint import checkpoint +class EsmFoldFlashSelfAttention(nn.Module): + def __init__(self, embed_dim, num_heads, head_width, gated=False): + super().__init__() + assert embed_dim == num_heads * head_width + + self.embed_dim = embed_dim + self.num_heads = num_heads + self.head_width = head_width + + self.proj = nn.Linear(embed_dim, embed_dim * 3, bias=False) + self.o_proj = nn.Linear(embed_dim, embed_dim, bias=True) + self.gated = gated + if gated: + self.g_proj = nn.Linear(embed_dim, embed_dim) + torch.nn.init.zeros_(self.g_proj.weight) + torch.nn.init.ones_(self.g_proj.bias) + + self.rescale_factor = self.head_width**-0.5 + + torch.nn.init.zeros_(self.o_proj.bias) + + def forward(self, x, mask=None, bias=None, indices=None): + """ + Basic self attention with optional mask and external pairwise bias. To handle sequences of different lengths, + use mask. + + Inputs: + x: batch of input sequneces (.. x L x C) mask: batch of boolean masks where 1=valid, 0=padding position (.. + x L_k) bias: batch of scalar pairwise attention biases (.. x Lq x Lk x num_heads) + + Outputs: + sequence projection (B x L x embed_dim), attention maps (B x L x L x num_heads) + """ + + t = self.proj(x).view(*x.shape[:2], self.num_heads, -1) + t = t.permute(0, 2, 1, 3) + # q, k, v = t.chunk(3, dim=-1) + + # q = self.rescale_factor * q + # a = torch.einsum("...qc,...kc->...qk", q, k) + + # Add external attention bias. + # if bias is not None: + # a = a + bias.permute(0, 3, 1, 2) + if mask is not None: + mask = mask[:, None, :, :] + + # Do not attend to padding tokens. 
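+ # flash_attn_qkvpacked_func from flash-attn expects a packed QKV tensor of shape (batch, seqlen, 3, n_heads, head_dim) and fuses the scaling, softmax, and value matmul that the commented-out reference implementation spells out explicitly.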
+ # if mask is not None: + # mask = mask[:, None, None] + # a = a.masked_fill(mask == False, -np.inf) # noqa: E712 + y = flash_attn_qkvpacked_func( + t, + dropout_p=0.0, # Set to 0.0 during evaluation + softmax_scale=self.rescale_factor, + causal=False, # Assuming no causal masking, modify if necessary + window_size=(-1, -1), # Full attention + alibi_slopes=None, # No external bias + deterministic=False + ) + + # a = nn.functional.softmax(a, dim=-1) + + # y = torch.einsum("...hqk,...hkc->...qhc", a, v) + # y = y.reshape(*y.shape[:2], -1) + + if self.gated: + y = self.g_proj(x).sigmoid() * y + y = self.o_proj(y) + breakpoint() + return y, None#, a.permute(0, 3, 1, 2) + +class EsmFoldTriangleAttention(nn.Module): + def __init__(self, c_in, c_hidden, no_heads, starting=True, inf=1e9): + """ + Args: + c_in: + Input channel dimension + c_hidden: + Overall hidden channel dimension (not per-head) + no_heads: + Number of attention heads + """ + super().__init__() + + self.c_in = c_in + self.c_hidden = c_hidden + self.no_heads = no_heads + self.starting = starting + self.inf = inf + + self.layer_norm = LayerNorm(self.c_in) + + self.linear = EsmFoldLinear(c_in, self.no_heads, bias=False, init="normal") + + self.mha = EsmFoldAttention(self.c_in, self.c_in, self.c_in, self.c_hidden, self.no_heads) + # self.mha = EsmFoldFlashAttention(self.c_in, self.c_in, self.c_in, self.c_hidden, self.no_heads) + + @torch.jit.ignore + def _chunk( + self, + x: torch.Tensor, + biases: List[torch.Tensor], + chunk_size: int, + use_memory_efficient_kernel: bool = False, + use_lma: bool = False, + inplace_safe: bool = False, + ) -> torch.Tensor: + "triangle! triangle!" + mha_inputs = { + "q_x": x, + "kv_x": x, + "biases": biases, + } + + return chunk_layer( + partial(self.mha, use_memory_efficient_kernel=use_memory_efficient_kernel, use_lma=use_lma), + mha_inputs, + chunk_size=chunk_size, + no_batch_dims=len(x.shape[:-2]), + _out=x if inplace_safe else None, + ) + + def forward( + self, + x: torch.Tensor, + mask: Optional[torch.Tensor] = None, + chunk_size: Optional[int] = None, + use_memory_efficient_kernel: bool = False, + use_lma: bool = False, + inplace_safe: bool = False, + ) -> torch.Tensor: + """ + Args: + x: + [*, I, J, C_in] input tensor (e.g. the pair representation) + Returns: + [*, I, J, C_in] output tensor + """ + if mask is None: + # [*, I, J] + mask = x.new_ones( + x.shape[:-1], + ) + + if not self.starting: + x = x.transpose(-2, -3) + mask = mask.transpose(-1, -2) + + # [*, I, J, C_in] + x = self.layer_norm(x) + + # [*, I, 1, 1, J] + mask_bias = (self.inf * (mask - 1))[..., :, None, None, :] + + # [*, H, I, J] + triangle_bias = permute_final_dims(self.linear(x), (2, 0, 1)) + + # [*, 1, H, I, J] + triangle_bias = triangle_bias.unsqueeze(-4) + + biases = [mask_bias, triangle_bias] + + if chunk_size is not None: + x = self._chunk( + x, + biases, + chunk_size, + use_memory_efficient_kernel=use_memory_efficient_kernel, + use_lma=use_lma, + inplace_safe=inplace_safe, + ) + breakpoint() + else: + # x = self.mha( + # q_x=x, kv_x=x, biases=biases, use_memory_efficient_kernel=use_memory_efficient_kernel, use_lma=use_lma + # ) + x = self.mha( + q_x=x, kv_x=x, biases=biases + ) + + if not self.starting: + x = x.transpose(-2, -3) + + return x + + +class EsmFoldTriangleMultiplicativeUpdate(nn.Module): + """ + Implements Algorithms 11 and 12. 
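+ These are the triangular multiplicative updates using "outgoing" and "incoming" edges from the AlphaFold2 supplement, selected here via the `_outgoing` flag.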
+ """ + + def __init__(self, config, _outgoing=True): + super().__init__() + c_hidden = config.pairwise_state_dim + self._outgoing = _outgoing + + self.linear_a_p = EsmFoldLinear(c_hidden, c_hidden) + self.linear_a_g = EsmFoldLinear(c_hidden, c_hidden, init="gating") + self.linear_b_p = EsmFoldLinear(c_hidden, c_hidden) + self.linear_b_g = EsmFoldLinear(c_hidden, c_hidden, init="gating") + self.linear_g = EsmFoldLinear(c_hidden, c_hidden, init="gating") + self.linear_z = EsmFoldLinear(c_hidden, c_hidden, init="final") + + self.layer_norm_in = LayerNorm(c_hidden) + self.layer_norm_out = LayerNorm(c_hidden) + + self.sigmoid = nn.Sigmoid() + + def _combine_projections( + self, a: torch.Tensor, b: torch.Tensor, _inplace_chunk_size: Optional[int] = None + ) -> torch.Tensor: + if self._outgoing: + a = permute_final_dims(a, (2, 0, 1)) + b = permute_final_dims(b, (2, 1, 0)) + else: + a = permute_final_dims(a, (2, 1, 0)) + b = permute_final_dims(b, (2, 0, 1)) + + if _inplace_chunk_size is not None: + # To be replaced by torch vmap + for i in range(0, a.shape[-3], _inplace_chunk_size): + a_chunk = a[..., i : i + _inplace_chunk_size, :, :] + b_chunk = b[..., i : i + _inplace_chunk_size, :, :] + a[..., i : i + _inplace_chunk_size, :, :] = torch.matmul( + a_chunk, + b_chunk, + ) + + p = a + else: + p = torch.matmul(a, b) + + return permute_final_dims(p, (1, 2, 0)) + + def _inference_forward( + self, + z: torch.Tensor, + mask: Optional[torch.Tensor] = None, + inplace_chunk_size: Optional[int] = None, + with_add: bool = True, + ): + """ + Args: + z: + A [*, N, N, C_z] pair representation + mask: + A [*, N, N] pair mask + inplace_chunk_size: + Size of chunks used in the main computation. Increase to trade memory for speed. + with_add: + If True, z is overwritten with (z + update). Otherwise, it is overwritten with (update). + Returns: + A reference to the overwritten z + + More memory-efficient, inference-only version of the forward function. Uses in-place operations, fusion of the + addition that happens after this module in the Evoformer, a smidge of recomputation, and a cache of overwritten + values to lower peak memory consumption of this module from 5x the size of the input tensor z to 2.5x its size. + Useful for inference on extremely long sequences. + + It works as follows. We will make reference to variables used in the default forward implementation below. + Naively, triangle multiplication attention requires the manifestation of 5 tensors the size of z: 1) z, the + "square" input tensor, 2) a, the first projection of z, 3) b, the second projection of b, 4) g, a z-sized mask, + and 5) a z-sized tensor for intermediate computations. For large N, this is prohibitively expensive; for + N=4000, for example, z is more than 8GB alone. To avoid this problem, we compute b, g, and all intermediate + tensors in small chunks, noting that the chunks required to compute a chunk of the output depend only on the + tensor a and corresponding vertical and horizontal chunks of z. This suggests an algorithm that loops over + pairs of chunks of z: hereafter "columns" and "rows" of z, even though each "column" and "row" in fact contains + inplace_chunk_size contiguous true columns and rows of z. Writing output chunks to a new tensor would bring + total memory consumption down to 3x the size of z. However, more memory can be saved by writing output chunks + directly to z in-place. WLOG, we choose to write output chunks vertically, overwriting the ith "column" of z at + the end of the ith iteration of the main loop. 
Despite this overwriting, the ith column is always one column + ahead of previously overwritten columns and can be recovered directly from z. After the first iteration, + however, the ith row of z is always at least partially overwritten. For this reason, we introduce the z-cache, + a tensor one-half the size of z. The z-cache initially contains the left half (2nd and 3rd quadrants) of z. For + 0 < i < N/2, the missing left part of the ith row of z is recovered from this cache at the beginning of the ith + iteration. Once i exceeds n/2, the cache is "reoriented" to encompass the 3rd and 4th quadrants of z instead. + Though the 3rd quadrant of the original z is entirely overwritten at this point, it can be recovered from the + z-cache itself. Thereafter, the ith row of z can be recovered in its entirety from the reoriented z-cache. + After the final iteration, z has been completely overwritten and contains the triangular multiplicative update. + If with_add is True, it instead contains the sum of z and the triangular multiplicative update. In either case, + peak memory consumption is just 2.5x the size of z, disregarding memory used for chunks and other small + variables. + """ + if mask is None: + mask = z.new_ones(z.shape[:-1]) + + mask = mask.unsqueeze(-1) + + def compute_projection_helper(pair, mask, a=True): + if a: + linear_g = self.linear_a_g + linear_p = self.linear_a_p + else: + linear_g = self.linear_b_g + linear_p = self.linear_b_p + + pair = self.layer_norm_in(pair) + p = linear_g(pair) + p.sigmoid_() + p *= linear_p(pair) + p *= mask + p = permute_final_dims(p, (2, 0, 1)) + return p + + def compute_projection(pair, mask, a=True, chunked=True): + need_transpose = self._outgoing ^ a + if not chunked: + p = compute_projection_helper(pair, mask, a) + if need_transpose: + p = p.transpose(-1, -2) + else: + # This computation is chunked so as not to exceed our 2.5x + # budget with a large intermediate tensor + linear_g = self.linear_a_g if a else self.linear_b_g + c = linear_g.bias.shape[-1] + out_shape = pair.shape[:-3] + (c,) + pair.shape[-3:-1] + p = pair.new_zeros(out_shape) + for i in range(0, pair.shape[-3], inplace_chunk_size): + pair_chunk = pair[..., i : i + inplace_chunk_size, :, :] + pair_chunk = compute_projection_helper( + pair[..., i : i + inplace_chunk_size, :, :], + mask[..., i : i + inplace_chunk_size, :, :], + a, + ) + if need_transpose: + pair_chunk = pair_chunk.transpose(-1, -2) + p[..., i : i + inplace_chunk_size] = pair_chunk + else: + p[..., i : i + inplace_chunk_size, :] = pair_chunk + + del pair_chunk + + return p + + # We start by fully manifesting a. In addition to the input, this + # brings total memory consumption to 2x z (disregarding size of chunks) + # [*, N, N, c] + a = compute_projection(z, mask, True, chunked=True) + + if inplace_chunk_size is not None: + n = a.shape[-1] + half_n = n // 2 + n % 2 + row_dim = -3 + col_dim = -2 + b_chunk_dim = row_dim if self._outgoing else col_dim + + def empty_slicer(t): + return [slice(None) for _ in t.shape] + + def slice_tensor(t, start, end, dim): + # Slices start:end from the dim dimension of t + s = empty_slicer(t) + s[dim] = slice(start, end) + return t[s] + + def flip_z_cache_(z_cache, z): + # "Reorient" the z_cache (see below), filling it with quadrants + # 3---recovered from the z_cache---and 4---recovered from z--- + # of the input tensor z. 
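+ # quadrant_3 is the lower half of the cached left half of z, i.e. the original bottom-left (3rd) quadrant.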
+ quadrant_3 = slice_tensor(z_cache, half_n, None, row_dim) + z_cache = z_cache.transpose(row_dim, col_dim) + + # If n is odd, we need to shrink the z_cache by one row + z_cache = z_cache[..., : (n // 2), :, :] + + # Move the 3rd quadrant of z into the + first_half_slicer = empty_slicer(z_cache) + first_half_slicer[col_dim] = slice(0, half_n) + z_cache[first_half_slicer] = quadrant_3 + + # Get the fourth quadrant of z + quadrant_4 = slice_tensor(z, half_n, None, row_dim) + quadrant_4 = slice_tensor(quadrant_4, half_n, None, col_dim) + + # Insert said quadrant into the rotated z-cache + quadrant_3_slicer = empty_slicer(z_cache) + quadrant_3_slicer[col_dim] = slice(half_n, None) + + z_cache[quadrant_3_slicer] = quadrant_4 + + return z_cache + + # Initialize the z cache to the left half of z. + z_cache_shape = list(z.shape) + z_cache_shape[col_dim] = half_n + z_cache = z.new_zeros(z_cache_shape) + z_cache_slicer = empty_slicer(z_cache) + z_cache_slicer[col_dim] = slice(0, half_n) + z_cache.copy_(z[z_cache_slicer]) + z_cache_rotated = False + + # We need to reorient the z-cache at the halfway point, and we + # don't want a single chunk to straddle that point. We contract one + # of the chunks in the middle to address that problem. + i_range = list(range(0, half_n, inplace_chunk_size)) + initial_offsets = [i_2 - i_1 for i_1, i_2 in zip(i_range, i_range[1:] + [half_n])] + after_half = list(range(half_n, n, inplace_chunk_size)) + after_half_offsets = [inplace_chunk_size for _ in after_half] + combined_range_with_offsets = zip(i_range + after_half, initial_offsets + after_half_offsets) + for i, offset in combined_range_with_offsets: + if not z_cache_rotated and i >= half_n: + z_cache = flip_z_cache_(z_cache, z) + z_cache_rotated = True + + z_chunk_b = slice_tensor(z, i, i + offset, b_chunk_dim) + mask_chunk = slice_tensor(mask, i, i + offset, b_chunk_dim) + + z_chunk_b = z_chunk_b.clone() + if b_chunk_dim == col_dim: + z_chunk_b = slice_tensor(z, i, i + offset, col_dim) + else: # b_chunk_dim == row_dim + # In this case, the b-dimension (b_chunk_dim) is partially + # overwritten at the end of each iteration. We need to + # restore the missing component from the z-cache. + if not z_cache_rotated: + z_chunk_slicer = empty_slicer(z_chunk_b) + z_chunk_slicer[col_dim] = slice(0, half_n) + z_chunk_b[z_chunk_slicer] = slice_tensor(z_cache, i, i + offset, row_dim) + else: + z_cache_offset = i - half_n + z_chunk_b = slice_tensor(z_cache, z_cache_offset, z_cache_offset + offset, row_dim) + + b_chunk = compute_projection(z_chunk_b, mask_chunk, a=False, chunked=False) + del z_chunk_b + + x_chunk = torch.matmul(a, b_chunk) + x_chunk = permute_final_dims(x_chunk, (1, 2, 0)) + x_chunk = self.layer_norm_out(x_chunk) + x_chunk = self.linear_z(x_chunk) + + # The g dimension (col_dim) is parallel to and ahead of the + # overwrites in z. We can extract the g chunk normally. 
+ z_chunk_g = slice_tensor(z, i, i + offset, col_dim) + g_chunk = self.linear_g(self.layer_norm_in(z_chunk_g)) + g_chunk.sigmoid_() + del z_chunk_g + + x_chunk *= g_chunk + + # Write the columns into z in-place + z_slicer = empty_slicer(z) + z_slicer[col_dim] = slice(i, i + offset) + if with_add: + z[z_slicer] += x_chunk + else: + z[z_slicer] = x_chunk + else: + b = compute_projection(z, mask, False, False) + x = torch.matmul(a, b) + x = self.layer_norm_out(x) + x = self.linear_z(x) + g = self.linear_g(z) + g.sigmoid_() + x *= g + if with_add: + z += x + else: + z = x + + return z + + def forward( + self, + z: torch.Tensor, + mask: Optional[torch.Tensor] = None, + inplace_safe: bool = False, + _add_with_inplace: bool = False, + _inplace_chunk_size: Optional[int] = 256, + ) -> torch.Tensor: + """ + Args: + x: + [*, N_res, N_res, C_z] input tensor + mask: + [*, N_res, N_res] input mask + Returns: + [*, N_res, N_res, C_z] output tensor + """ + if inplace_safe: + x = self._inference_forward( + z, + mask, + inplace_chunk_size=_inplace_chunk_size, + with_add=_add_with_inplace, + ) + return x + + if mask is None: + mask = z.new_ones(z.shape[:-1]) + + mask = mask.unsqueeze(-1) + + z = self.layer_norm_in(z) + a = mask + a = a * self.sigmoid(self.linear_a_g(z)) + a = a * self.linear_a_p(z) + b = mask + b = b * self.sigmoid(self.linear_b_g(z)) + b = b * self.linear_b_p(z) + + if is_fp16_enabled(): + with torch.cuda.amp.autocast(enabled=False): + x = self._combine_projections(a.float(), b.float()) + else: + x = self._combine_projections(a, b) + + del a, b + x = self.layer_norm_out(x) + x = self.linear_z(x) + g = self.sigmoid(self.linear_g(z)) + x = x * g + + return x + + +class EsmFoldPreTrainedModel(EsmPreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. 
+ """ + + # Subclass `EsMPreTrainedModel` to deal with special init + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, EsmFoldLinear): + with torch.no_grad(): + if module.init_fn is not None: + module.init_fn(module.weight, module.bias) + elif module.init == "default": + trunc_normal_init_(module.weight, scale=1.0) + elif module.init == "relu": + trunc_normal_init_(module.weight, scale=2.0) + elif module.init == "glorot": + nn.init.xavier_uniform_(module.weight, gain=1) + elif module.init == "gating": + module.weight.fill_(0.0) + if module.bias: + module.bias.fill_(1.0) + elif module.init == "normal": + torch.nn.init.kaiming_normal_(module.weight, nonlinearity="linear") + elif module.init == "final": + module.weight.fill_(0.0) + elif isinstance(module, EsmFoldInvariantPointAttention): + ipa_point_weights_init_(module.head_weights) + elif isinstance(module, EsmFoldTriangularSelfAttentionBlock): + torch.nn.init.zeros_(module.tri_mul_in.linear_z.weight) + torch.nn.init.zeros_(module.tri_mul_in.linear_z.bias) + torch.nn.init.zeros_(module.tri_mul_out.linear_z.weight) + torch.nn.init.zeros_(module.tri_mul_out.linear_z.bias) + torch.nn.init.zeros_(module.tri_att_start.mha.linear_o.weight) + torch.nn.init.zeros_(module.tri_att_start.mha.linear_o.bias) + torch.nn.init.zeros_(module.tri_att_end.mha.linear_o.weight) + torch.nn.init.zeros_(module.tri_att_end.mha.linear_o.bias) + + torch.nn.init.zeros_(module.sequence_to_pair.o_proj.weight) + torch.nn.init.zeros_(module.sequence_to_pair.o_proj.bias) + torch.nn.init.zeros_(module.pair_to_sequence.linear.weight) + torch.nn.init.zeros_(module.seq_attention.o_proj.weight) + torch.nn.init.zeros_(module.seq_attention.o_proj.bias) + torch.nn.init.zeros_(module.mlp_seq.mlp[-2].weight) + torch.nn.init.zeros_(module.mlp_seq.mlp[-2].bias) + torch.nn.init.zeros_(module.mlp_pair.mlp[-2].weight) + torch.nn.init.zeros_(module.mlp_pair.mlp[-2].bias) + else: + super()._init_weights(module) + + +class EsmFoldSelfAttention(nn.Module): + def __init__(self, embed_dim, num_heads, head_width, gated=False): + super().__init__() + assert embed_dim == num_heads * head_width + + self.embed_dim = embed_dim + self.num_heads = num_heads + self.head_width = head_width + + self.proj = nn.Linear(embed_dim, embed_dim * 3, bias=False) + self.o_proj = nn.Linear(embed_dim, embed_dim, bias=True) + self.gated = gated + if gated: + self.g_proj = nn.Linear(embed_dim, embed_dim) + torch.nn.init.zeros_(self.g_proj.weight) + torch.nn.init.ones_(self.g_proj.bias) + + self.rescale_factor = self.head_width**-0.5 + + torch.nn.init.zeros_(self.o_proj.bias) + + def forward(self, x, mask=None, bias=None, indices=None): + """ + Basic self attention with optional mask and external pairwise bias. To handle sequences of different lengths, + use mask. + + Inputs: + x: batch of input sequneces (.. x L x C) mask: batch of boolean masks where 1=valid, 0=padding position (.. + x L_k) bias: batch of scalar pairwise attention biases (.. x Lq x Lk x num_heads) + + Outputs: + sequence projection (B x L x embed_dim), attention maps (B x L x L x num_heads) + """ + + t = self.proj(x).view(*x.shape[:2], self.num_heads, -1) + t = t.permute(0, 2, 1, 3) + q, k, v = t.chunk(3, dim=-1) + + q = self.rescale_factor * q + a = torch.einsum("...qc,...kc->...qk", q, k) + + # Add external attention bias. + if bias is not None: + a = a + bias.permute(0, 3, 1, 2) + + # Do not attend to padding tokens. 
+ if mask is not None: + mask = mask[:, None, None] + a = a.masked_fill(mask == False, -np.inf) # noqa: E712 + + a = nn.functional.softmax(a, dim=-1) + + y = torch.einsum("...hqk,...hkc->...qhc", a, v) + y = y.reshape(*y.shape[:2], -1) + + if self.gated: + y = self.g_proj(x).sigmoid() * y + y = self.o_proj(y) + breakpoint() + return y, a.permute(0, 3, 1, 2) + + +class EsmFoldDropout(nn.Module): + """ + Implementation of dropout with the ability to share the dropout mask along a particular dimension. + """ + + def __init__(self, r: float, batch_dim: Union[int, List[int]]): + super().__init__() + + self.r = r + if isinstance(batch_dim, int): + batch_dim = [batch_dim] + self.batch_dim = batch_dim + self.dropout = nn.Dropout(self.r) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + shape = list(x.shape) + if self.batch_dim is not None: + for bd in self.batch_dim: + shape[bd] = 1 + return x * self.dropout(x.new_ones(shape)) + + +class EsmFoldSequenceToPair(nn.Module): + def __init__(self, sequence_state_dim, inner_dim, pairwise_state_dim): + super().__init__() + + self.layernorm = nn.LayerNorm(sequence_state_dim) + self.proj = nn.Linear(sequence_state_dim, inner_dim * 2, bias=True) + self.o_proj = nn.Linear(2 * inner_dim, pairwise_state_dim, bias=True) + + torch.nn.init.zeros_(self.proj.bias) + torch.nn.init.zeros_(self.o_proj.bias) + + def forward(self, sequence_state): + """ + Inputs: + sequence_state: B x L x sequence_state_dim + + Output: + pairwise_state: B x L x L x pairwise_state_dim + + Intermediate state: + B x L x L x 2*inner_dim + """ + + assert len(sequence_state.shape) == 3 + + s = self.layernorm(sequence_state) + s = self.proj(s) + q, k = s.chunk(2, dim=-1) + + prod = q[:, None, :, :] * k[:, :, None, :] + diff = q[:, None, :, :] - k[:, :, None, :] + + x = torch.cat([prod, diff], dim=-1) + x = self.o_proj(x) + + return x + + +class EsmFoldPairToSequence(nn.Module): + def __init__(self, pairwise_state_dim, num_heads): + super().__init__() + + self.layernorm = nn.LayerNorm(pairwise_state_dim) + self.linear = nn.Linear(pairwise_state_dim, num_heads, bias=False) + + def forward(self, pairwise_state): + """ + Inputs: + pairwise_state: B x L x L x pairwise_state_dim + + Output: + pairwise_bias: B x L x L x num_heads + """ + assert len(pairwise_state.shape) == 4 + z = self.layernorm(pairwise_state) + pairwise_bias = self.linear(z) + return pairwise_bias + + +class EsmFoldResidueMLP(nn.Module): + def __init__(self, embed_dim, inner_dim, dropout=0): + super().__init__() + + self.mlp = nn.Sequential( + nn.LayerNorm(embed_dim), + nn.Linear(embed_dim, inner_dim), + nn.ReLU(), + nn.Linear(inner_dim, embed_dim), + nn.Dropout(dropout), + ) + + def forward(self, x): + return x + self.mlp(x) + + +class EsmFoldTriangularSelfAttentionBlock(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + + sequence_state_dim = config.sequence_state_dim + pairwise_state_dim = config.pairwise_state_dim + sequence_num_heads = sequence_state_dim // config.sequence_head_width + pairwise_num_heads = pairwise_state_dim // config.pairwise_head_width + + self.layernorm_1 = nn.LayerNorm(sequence_state_dim) + + self.sequence_to_pair = EsmFoldSequenceToPair(sequence_state_dim, pairwise_state_dim // 2, pairwise_state_dim) + self.pair_to_sequence = EsmFoldPairToSequence(pairwise_state_dim, sequence_num_heads) + + self.seq_attention = EsmFoldSelfAttention( + sequence_state_dim, sequence_num_heads, config.sequence_head_width, gated=True + ) + self.tri_mul_out = 
EsmFoldTriangleMultiplicativeUpdate(config, _outgoing=True) + self.tri_mul_in = EsmFoldTriangleMultiplicativeUpdate(config, _outgoing=False) + + self.tri_att_start = EsmFoldTriangleAttention( + pairwise_state_dim, config.pairwise_head_width, pairwise_num_heads, inf=1e9, starting=True + ) + self.tri_att_end = EsmFoldTriangleAttention( + pairwise_state_dim, config.pairwise_head_width, pairwise_num_heads, inf=1e9, starting=False + ) + + self.mlp_seq = EsmFoldResidueMLP(sequence_state_dim, 4 * sequence_state_dim, dropout=config.dropout) + self.mlp_pair = EsmFoldResidueMLP(pairwise_state_dim, 4 * pairwise_state_dim, dropout=config.dropout) + + self.drop = nn.Dropout(config.dropout) + self.row_drop = EsmFoldDropout(config.dropout * 2, 2) + self.col_drop = EsmFoldDropout(config.dropout * 2, 1) + + def forward(self, sequence_state, pairwise_state, mask=None, chunk_size=None, **__kwargs): + """ + Inputs: + sequence_state: B x L x sequence_state_dim pairwise_state: B x L x L x pairwise_state_dim mask: B x L boolean + tensor of valid positions + + Output: + sequence_state: B x L x sequence_state_dim pairwise_state: B x L x L x pairwise_state_dim + """ + if len(sequence_state.shape) != 3: + raise ValueError(f"`sequence_state` should be a 3d-tensor, got {len(sequence_state.shape)} dims.") + if len(pairwise_state.shape) != 4: + raise ValueError(f"`pairwise_state` should be a 4d-tensor, got {len(pairwise_state.shape)} dims.") + if mask is not None and len(mask.shape) != 2: + raise ValueError(f"`mask` should be a 2d-tensor, got {len(mask.shape)} dims.") + + batch_dim, seq_dim, sequence_state_dim = sequence_state.shape + pairwise_state_dim = pairwise_state.shape[3] + + if sequence_state_dim != self.config.sequence_state_dim: + raise ValueError( + "`sequence_state` last dimension should be equal to `self.sequence_state_dim`. Got " + f"{sequence_state_dim} != {self.config.sequence_state_dim}." + ) + if pairwise_state_dim != self.config.pairwise_state_dim: + raise ValueError( + "`pairwise_state` last dimension should be equal to `self.pairwise_state_dim`. Got " + f"{pairwise_state_dim} != {self.config.pairwise_state_dim}." + ) + if batch_dim != pairwise_state.shape[0]: + raise ValueError( + f"`sequence_state` and `pairwise_state` have inconsistent batch size: {batch_dim} != " + f"{pairwise_state.shape[0]}." + ) + if seq_dim != pairwise_state.shape[1] or seq_dim != pairwise_state.shape[2]: + raise ValueError( + f"`sequence_state` and `pairwise_state` have inconsistent sequence length: {seq_dim} != " + f"{pairwise_state.shape[1]} or {pairwise_state.shape[2]}." + ) + + # Update sequence state + bias = self.pair_to_sequence(pairwise_state) + + # Self attention with bias + mlp. + y = self.layernorm_1(sequence_state) + y, _ = self.seq_attention(y, mask=mask, bias=bias) + sequence_state = sequence_state + self.drop(y) + sequence_state = self.mlp_seq(sequence_state) + + # Update pairwise state + pairwise_state = pairwise_state + self.sequence_to_pair(sequence_state) + + # Axial attention with triangular bias. 
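+        # The four residual updates below follow the AlphaFold-style pair stack order:
+        # triangular multiplicative update "outgoing" then "incoming", then triangular
+        # attention around the starting node (rows) and the ending node (columns).
+        # row_drop/col_drop share their dropout mask along one of the two length
+        # dimensions, and tri_mask[b, i, j] = mask[b, i] * mask[b, j] marks valid pairs.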
+ tri_mask = mask.unsqueeze(2) * mask.unsqueeze(1) if mask is not None else None + pairwise_state = pairwise_state + self.row_drop(self.tri_mul_out(pairwise_state, mask=tri_mask)) + pairwise_state = pairwise_state + self.col_drop(self.tri_mul_in(pairwise_state, mask=tri_mask)) + pairwise_state = pairwise_state + self.row_drop( + self.tri_att_start(pairwise_state, mask=tri_mask, chunk_size=chunk_size) + ) + pairwise_state = pairwise_state + self.col_drop( + self.tri_att_end(pairwise_state, mask=tri_mask, chunk_size=chunk_size) + ) + + # MLP over pairs. + pairwise_state = self.mlp_pair(pairwise_state) + + return sequence_state, pairwise_state + + +class EsmCategoricalMixture: + def __init__(self, param, bins=50, start=0, end=1): + # All tensors are of shape ..., bins. + self.logits = param + bins = torch.linspace(start, end, bins + 1, device=self.logits.device, dtype=self.logits.dtype) + self.v_bins = (bins[:-1] + bins[1:]) / 2 + + def log_prob(self, true): + # Shapes are: + # self.probs: ... x bins + # true : ... + true_index = (true.unsqueeze(-1) - self.v_bins[[None] * true.ndim]).abs().argmin(-1) + nll = self.logits.log_softmax(-1) + return torch.take_along_dim(nll, true_index.unsqueeze(-1), dim=-1).squeeze(-1) + + def mean(self): + return (self.logits.softmax(-1) @ self.v_bins.unsqueeze(1)).squeeze(-1) + + +def categorical_lddt(logits, bins=50): + # Logits are ..., 37, bins. + return EsmCategoricalMixture(logits, bins=bins).mean() + + +def get_axial_mask(mask): + """ + Helper to convert B x L mask of valid positions to axial mask used in row column attentions. + + Input: + mask: B x L tensor of booleans + + Output: + mask: B x L x L tensor of booleans + """ + + if mask is None: + return None + + if len(mask.shape) != 2: + raise ValueError(f"`mask` should be a 2d-tensor, got {len(mask.shape)} dims.") + batch_dim, seq_dim = mask.shape + m = mask.unsqueeze(1).expand(batch_dim, seq_dim, seq_dim) + m = m.reshape(batch_dim * seq_dim, seq_dim) + return m + + +class EsmFoldRelativePosition(nn.Module): + def __init__(self, config): + super().__init__() + self.bins = config.position_bins + + # Note an additional offset is used so that the 0th position + # is reserved for masked pairs. + self.embedding = torch.nn.Embedding(2 * self.bins + 2, config.pairwise_state_dim) + + def forward(self, residue_index, mask=None): + """ + Input: + residue_index: B x L tensor of indices (dtype=torch.long) mask: B x L tensor of booleans + + Output: + pairwise_state: B x L x L x pairwise_state_dim tensor of embeddings + """ + if residue_index.dtype != torch.long: + raise ValueError(f"`residue_index` has dtype {residue_index.dtype}, it should be `torch.long`.") + if mask is not None and residue_index.shape != mask.shape: + raise ValueError( + f"`residue_index` and `mask` have inconsistent shapes: {residue_index.shape} != {mask.shape}." + ) + + diff = residue_index[:, None, :] - residue_index[:, :, None] + diff = diff.clamp(-self.bins, self.bins) + diff = diff + self.bins + 1 # Add 1 to adjust for padding index. 
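+        # Worked example (bin count is illustrative): with config.position_bins = 32,
+        # a raw offset of -100 clamps to -32 and maps to index 1, an offset of 0 maps
+        # to index 33, and +32 or more maps to index 65. Index 0 stays reserved for
+        # masked pairs, which is why the embedding table above has 2 * 32 + 2 = 66 rows.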
+ + if mask is not None: + mask = mask[:, None, :] * mask[:, :, None] + diff[mask == False] = 0 # noqa: E712 + + output = self.embedding(diff) + return output + + +class EsmFoldAngleResnetBlock(nn.Module): + def __init__(self, config): + super().__init__() + + self.linear_1 = EsmFoldLinear(config.resnet_dim, config.resnet_dim, init="relu") + self.linear_2 = EsmFoldLinear(config.resnet_dim, config.resnet_dim, init="final") + + self.relu = nn.ReLU() + + def forward(self, a: torch.Tensor) -> torch.Tensor: + s_initial = a + + a = self.relu(a) + a = self.linear_1(a) + a = self.relu(a) + a = self.linear_2(a) + + return a + s_initial + + +class EsmFoldAngleResnet(nn.Module): + """ + Implements Algorithm 20, lines 11-14 + """ + + def __init__(self, config): + super().__init__() + self.config = config + + self.linear_in = EsmFoldLinear(config.sequence_dim, config.resnet_dim) + self.linear_initial = EsmFoldLinear(config.sequence_dim, config.resnet_dim) + + self.layers = nn.ModuleList() + for _ in range(config.num_resnet_blocks): + layer = EsmFoldAngleResnetBlock(config) + self.layers.append(layer) + + self.linear_out = EsmFoldLinear(config.resnet_dim, config.num_angles * 2) + + self.relu = nn.ReLU() + + def forward(self, s: torch.Tensor, s_initial: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + s: + [*, C_hidden] single embedding + s_initial: + [*, C_hidden] single embedding as of the start of the StructureModule + Returns: + [*, no_angles, 2] predicted angles + """ + # NOTE: The ReLU's applied to the inputs are absent from the supplement + # pseudocode but present in the source. For maximal compatibility with + # the pretrained weights, I'm going with the source. + + # [*, C_hidden] + s_initial = self.relu(s_initial) + s_initial = self.linear_initial(s_initial) + s = self.relu(s) + s = self.linear_in(s) + s = s + s_initial + + for l in self.layers: + s = l(s) + + s = self.relu(s) + + # [*, no_angles * 2] + s = self.linear_out(s) + + # [*, no_angles, 2] + s = s.view(s.shape[:-1] + (-1, 2)) + + unnormalized_s = s + norm_denom = torch.sqrt( + torch.clamp( + torch.sum(s**2, dim=-1, keepdim=True), + min=self.config.epsilon, + ) + ) + s = s / norm_denom + + return unnormalized_s, s + + +class EsmFoldInvariantPointAttention(nn.Module): + """ + Implements Algorithm 22. + """ + + def __init__(self, config): + super().__init__() + self.config = config + + c_s = config.sequence_dim + c_z = config.pairwise_dim + self.hidden_dim = config.ipa_dim + self.num_heads = config.num_heads_ipa + self.num_qk_points = config.num_qk_points + self.num_v_points = config.num_v_points + + # These linear layers differ from their specifications in the + # supplement. There, they lack bias and use Glorot initialization. + # Here as in the official source, they have bias and use the default + # Lecun initialization. 
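+        # Dimension sketch under the AlphaFold-style defaults (ipa_dim = 16,
+        # num_heads_ipa = 12, num_qk_points = 4, num_v_points = 8 -- an assumption
+        # here; the real values come from `config`):
+        #   hc   = 16 * 12          = 192   flattened scalar q/k/v width
+        #   hpq  = 12 * 4 * 3       = 144   flattened query-point coordinates
+        #   hpkv = 12 * (4 + 8) * 3 = 432   flattened key/value-point coordinates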
+ hc = config.ipa_dim * config.num_heads_ipa + self.linear_q = EsmFoldLinear(c_s, hc) + self.linear_kv = EsmFoldLinear(c_s, 2 * hc) + + hpq = config.num_heads_ipa * config.num_qk_points * 3 + self.linear_q_points = EsmFoldLinear(c_s, hpq) + + hpkv = config.num_heads_ipa * (config.num_qk_points + config.num_v_points) * 3 + self.linear_kv_points = EsmFoldLinear(c_s, hpkv) + + self.linear_b = EsmFoldLinear(c_z, config.num_heads_ipa) + + self.head_weights = nn.Parameter(torch.zeros((config.num_heads_ipa))) + + concat_out_dim = config.num_heads_ipa * (c_z + config.ipa_dim + config.num_v_points * 4) + self.linear_out = EsmFoldLinear(concat_out_dim, c_s, init="final") + + self.softmax = nn.Softmax(dim=-1) + self.softplus = nn.Softplus() + + def forward( + self, + s: torch.Tensor, + z: Optional[torch.Tensor], + r: Rigid, + mask: torch.Tensor, + _offload_inference: bool = False, + _z_reference_list: Optional[Sequence[torch.Tensor]] = None, + ) -> torch.Tensor: + """ + Args: + s: + [*, N_res, C_s] single representation + z: + [*, N_res, N_res, C_z] pair representation + r: + [*, N_res] transformation object + mask: + [*, N_res] mask + Returns: + [*, N_res, C_s] single representation update + """ + z = [z] + + ####################################### + # Generate scalar and point activations + ####################################### + # [*, N_res, H * C_hidden] + q = self.linear_q(s) + kv = self.linear_kv(s) + + # [*, N_res, H, C_hidden] + q = q.view(q.shape[:-1] + (self.num_heads, -1)) + + # [*, N_res, H, 2 * C_hidden] + kv = kv.view(kv.shape[:-1] + (self.num_heads, -1)) + + # [*, N_res, H, C_hidden] + k, v = torch.split(kv, self.hidden_dim, dim=-1) + + # [*, N_res, H * P_q * 3] + q_pts = self.linear_q_points(s) + + # This is kind of clunky, but it's how the original does it + # [*, N_res, H * P_q, 3] + q_pts = torch.split(q_pts, q_pts.shape[-1] // 3, dim=-1) + q_pts = torch.stack(q_pts, dim=-1) + q_pts = r[..., None].apply(q_pts) + + # [*, N_res, H, P_q, 3] + q_pts = q_pts.view(q_pts.shape[:-2] + (self.num_heads, self.num_qk_points, 3)) + + # [*, N_res, H * (P_q + P_v) * 3] + kv_pts = self.linear_kv_points(s) + + # [*, N_res, H * (P_q + P_v), 3] + kv_pts = torch.split(kv_pts, kv_pts.shape[-1] // 3, dim=-1) + kv_pts = torch.stack(kv_pts, dim=-1) + kv_pts = r[..., None].apply(kv_pts) + + # [*, N_res, H, (P_q + P_v), 3] + kv_pts = kv_pts.view(kv_pts.shape[:-2] + (self.num_heads, -1, 3)) + + # [*, N_res, H, P_q/P_v, 3] + k_pts, v_pts = torch.split(kv_pts, [self.num_qk_points, self.num_v_points], dim=-2) + + ########################## + # Compute attention scores + ########################## + # [*, N_res, N_res, H] + b = self.linear_b(z[0]) + + if _offload_inference: + assert sys.getrefcount(z[0]) == 2 + z[0] = z[0].cpu() + + # [*, H, N_res, N_res] + if is_fp16_enabled(): + with torch.cuda.amp.autocast(enabled=False): + a = torch.matmul( + permute_final_dims(q.float(), (1, 0, 2)), # [*, H, N_res, C_hidden] + permute_final_dims(k.float(), (1, 2, 0)), # [*, H, C_hidden, N_res] + ) + else: + a = torch.matmul( + permute_final_dims(q, (1, 0, 2)), # [*, H, N_res, C_hidden] + permute_final_dims(k, (1, 2, 0)), # [*, H, C_hidden, N_res] + ) + + a *= math.sqrt(1.0 / (3 * self.hidden_dim)) + a += math.sqrt(1.0 / 3) * permute_final_dims(b, (2, 0, 1)) + + # [*, N_res, N_res, H, P_q, 3] + pt_att = q_pts.unsqueeze(-4) - k_pts.unsqueeze(-5) + pt_att = pt_att**2 + + # [*, N_res, N_res, H, P_q] + pt_att = sum(torch.unbind(pt_att, dim=-1)) + head_weights = self.softplus(self.head_weights).view(*((1,) * 
len(pt_att.shape[:-2]) + (-1, 1))) + head_weights = head_weights * math.sqrt(1.0 / (3 * (self.num_qk_points * 9.0 / 2))) + pt_att = pt_att * head_weights + + # [*, N_res, N_res, H] + pt_att = torch.sum(pt_att, dim=-1) * (-0.5) + # [*, N_res, N_res] + square_mask = mask.unsqueeze(-1) * mask.unsqueeze(-2) + square_mask = self.config.inf * (square_mask - 1) + + # [*, H, N_res, N_res] + pt_att = permute_final_dims(pt_att, (2, 0, 1)) + + a = a + pt_att + a = a + square_mask.unsqueeze(-3) + a = self.softmax(a) + + ################ + # Compute output + ################ + # [*, N_res, H, C_hidden] + o = torch.matmul(a, v.transpose(-2, -3).to(dtype=a.dtype)).transpose(-2, -3) + + # [*, N_res, H * C_hidden] + o = flatten_final_dims(o, 2) + + # [*, H, 3, N_res, P_v] + o_pt = torch.sum( + (a[..., None, :, :, None] * permute_final_dims(v_pts, (1, 3, 0, 2))[..., None, :, :]), + dim=-2, + ) + + # [*, N_res, H, P_v, 3] + o_pt = permute_final_dims(o_pt, (2, 0, 3, 1)) + o_pt = r[..., None, None].invert_apply(o_pt) + + # [*, N_res, H * P_v] + o_pt_norm = flatten_final_dims(torch.sqrt(torch.sum(o_pt**2, dim=-1) + self.config.epsilon), 2) + + # [*, N_res, H * P_v, 3] + o_pt = o_pt.reshape(*o_pt.shape[:-3], -1, 3) + + if _offload_inference: + z[0] = z[0].to(o_pt.device) + + # [*, N_res, H, C_z] + o_pair = torch.matmul(a.transpose(-2, -3), z[0].to(dtype=a.dtype)) + + # [*, N_res, H * C_z] + o_pair = flatten_final_dims(o_pair, 2) + + # [*, N_res, C_s] + s = self.linear_out( + torch.cat((o, *torch.unbind(o_pt, dim=-1), o_pt_norm, o_pair), dim=-1).to(dtype=z[0].dtype) + ) + + return s + + +class EsmFoldBackboneUpdate(nn.Module): + """ + Implements part of Algorithm 23. + """ + + def __init__(self, config): + super().__init__() + + self.linear = EsmFoldLinear(config.sequence_dim, 6, init="final") + + def forward(self, s: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + [*, N_res, C_s] single representation + Returns: + [*, N_res, 6] update vector + """ + # [*, 6] + update = self.linear(s) + + return update + + +class EsmFoldStructureModuleTransitionLayer(nn.Module): + def __init__(self, config): + super().__init__() + + self.linear_1 = EsmFoldLinear(config.sequence_dim, config.sequence_dim, init="relu") + self.linear_2 = EsmFoldLinear(config.sequence_dim, config.sequence_dim, init="relu") + self.linear_3 = EsmFoldLinear(config.sequence_dim, config.sequence_dim, init="final") + + self.relu = nn.ReLU() + + def forward(self, s): + s_initial = s + s = self.linear_1(s) + s = self.relu(s) + s = self.linear_2(s) + s = self.relu(s) + s = self.linear_3(s) + + s = s + s_initial + + return s + + +class EsmFoldStructureModuleTransition(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + + self.layers = nn.ModuleList() + for _ in range(config.num_transition_layers): + l = EsmFoldStructureModuleTransitionLayer(config) + self.layers.append(l) + + self.dropout = nn.Dropout(config.dropout_rate) + self.layer_norm = LayerNorm(config.sequence_dim) + + def forward(self, s): + for l in self.layers: + s = l(s) + + s = self.dropout(s) + s = self.layer_norm(s) + + return s + + +class EsmFoldStructureModule(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + + # Buffers to be lazily initialized later + # self.default_frames + # self.group_idx + # self.atom_mask + # self.lit_positions + + self.layer_norm_s = LayerNorm(config.sequence_dim) + self.layer_norm_z = LayerNorm(config.pairwise_dim) + + self.linear_in = EsmFoldLinear(config.sequence_dim, 
config.sequence_dim) + + self.ipa = EsmFoldInvariantPointAttention(config) + + self.ipa_dropout = nn.Dropout(config.dropout_rate) + self.layer_norm_ipa = LayerNorm(config.sequence_dim) + + self.transition = EsmFoldStructureModuleTransition(config) + self.bb_update = EsmFoldBackboneUpdate(config) + self.angle_resnet = EsmFoldAngleResnet(config) + + def forward( + self, + evoformer_output_dict, + aatype, + mask=None, + _offload_inference=False, + ): + """ + Args: + evoformer_output_dict: + Dictionary containing: + "single": + [*, N_res, C_s] single representation + "pair": + [*, N_res, N_res, C_z] pair representation + aatype: + [*, N_res] amino acid indices + mask: + Optional [*, N_res] sequence mask + Returns: + A dictionary of outputs + """ + s = evoformer_output_dict["single"] + + if mask is None: + # [*, N] + mask = s.new_ones(s.shape[:-1]) + + # [*, N, C_s] + s = self.layer_norm_s(s) + + # [*, N, N, C_z] + z = self.layer_norm_z(evoformer_output_dict["pair"]) + + z_reference_list = None + if _offload_inference: + assert sys.getrefcount(evoformer_output_dict["pair"]) == 2 + evoformer_output_dict["pair"] = evoformer_output_dict["pair"].cpu() + z_reference_list = [z] + z = None + + # [*, N, C_s] + s_initial = s + s = self.linear_in(s) + + # [*, N] + rigids = Rigid.identity( + s.shape[:-1], + s.dtype, + s.device, + self.training, + fmt="quat", + ) + outputs = [] + for i in range(self.config.num_blocks): + # [*, N, C_s] + s = s + self.ipa( + s, + z, + rigids, + mask, + _offload_inference=_offload_inference, + _z_reference_list=z_reference_list, + ) + s = self.ipa_dropout(s) + s = self.layer_norm_ipa(s) + s = self.transition(s) + + # [*, N] + rigids = rigids.compose_q_update_vec(self.bb_update(s)) + + # To hew as closely as possible to AlphaFold, we convert our + # quaternion-based transformations to rotation-matrix ones + # here + backb_to_global = Rigid( + Rotation(rot_mats=rigids.get_rots().get_rot_mats(), quats=None), + rigids.get_trans(), + ) + + backb_to_global = backb_to_global.scale_translation(self.config.trans_scale_factor) + + # [*, N, 7, 2] + unnormalized_angles, angles = self.angle_resnet(s, s_initial) + + all_frames_to_global = self.torsion_angles_to_frames(backb_to_global, angles, aatype) + + pred_xyz = self.frames_and_literature_positions_to_atom14_pos(all_frames_to_global, aatype) + + scaled_rigids = rigids.scale_translation(self.config.trans_scale_factor) + + preds = { + "frames": scaled_rigids.to_tensor_7(), + "sidechain_frames": all_frames_to_global.to_tensor_4x4(), + "unnormalized_angles": unnormalized_angles, + "angles": angles, + "positions": pred_xyz, + "states": s, + } + + outputs.append(preds) + + rigids = rigids.stop_rot_gradient() + + del z, z_reference_list + + if _offload_inference: + evoformer_output_dict["pair"] = evoformer_output_dict["pair"].to(s.device) + + outputs = dict_multimap(torch.stack, outputs) + outputs["single"] = s + + return outputs + + def _init_residue_constants(self, float_dtype, device): + if not hasattr(self, "default_frames"): + self.register_buffer( + "default_frames", + torch.tensor( + residue_constants.restype_rigid_group_default_frame, + dtype=float_dtype, + device=device, + requires_grad=False, + ), + persistent=False, + ) + if not hasattr(self, "group_idx"): + self.register_buffer( + "group_idx", + torch.tensor( + residue_constants.restype_atom14_to_rigid_group, + device=device, + requires_grad=False, + ), + persistent=False, + ) + if not hasattr(self, "atom_mask"): + self.register_buffer( + "atom_mask", + torch.tensor( + 
residue_constants.restype_atom14_mask, + dtype=float_dtype, + device=device, + requires_grad=False, + ), + persistent=False, + ) + if not hasattr(self, "lit_positions"): + self.register_buffer( + "lit_positions", + torch.tensor( + residue_constants.restype_atom14_rigid_group_positions, + dtype=float_dtype, + device=device, + requires_grad=False, + ), + persistent=False, + ) + + def torsion_angles_to_frames(self, r, alpha, f): + # Lazily initialize the residue constants on the correct device + self._init_residue_constants(alpha.dtype, alpha.device) + # Separated purely to make testing less annoying + return torsion_angles_to_frames(r, alpha, f, self.default_frames) + + def frames_and_literature_positions_to_atom14_pos(self, r, f): # [*, N, 8] # [*, N] + # Lazily initialize the residue constants on the correct device + self._init_residue_constants(r.get_rots().dtype, r.get_rots().device) + return frames_and_literature_positions_to_atom14_pos( + r, + f, + self.default_frames, + self.group_idx, + self.atom_mask, + self.lit_positions, + ) + + +class EsmFoldingTrunk(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + + c_s = config.sequence_state_dim + c_z = config.pairwise_state_dim + + self.pairwise_positional_embedding = EsmFoldRelativePosition(config) + + self.blocks = nn.ModuleList([EsmFoldTriangularSelfAttentionBlock(config) for _ in range(config.num_blocks)]) + + self.recycle_bins = 15 + self.recycle_s_norm = nn.LayerNorm(c_s) + self.recycle_z_norm = nn.LayerNorm(c_z) + self.recycle_disto = nn.Embedding(self.recycle_bins, c_z) + self.recycle_disto.weight[0].detach().zero_() + + self.structure_module = EsmFoldStructureModule(config.structure_module) + self.trunk2sm_s = nn.Linear(c_s, config.structure_module.sequence_dim) + self.trunk2sm_z = nn.Linear(c_z, config.structure_module.pairwise_dim) + + self.chunk_size = config.chunk_size + + def set_chunk_size(self, chunk_size): + # This parameter means the axial attention will be computed + # in a chunked manner. This should make the memory used more or less O(L) instead of O(L^2). + # It's equivalent to running a for loop over chunks of the dimension we're iterative over, + # where the chunk_size is the size of the chunks, so 128 would mean to parse 128-length chunks. + self.chunk_size = chunk_size + + def forward(self, seq_feats, pair_feats, true_aa, residx, mask, no_recycles): + """ + Inputs: + seq_feats: B x L x C tensor of sequence features pair_feats: B x L x L x C tensor of pair features residx: B + x L long tensor giving the position in the sequence mask: B x L boolean tensor indicating valid residues + + Output: + predicted_structure: B x L x (num_atoms_per_residue * 3) tensor wrapped in a Coordinates object + """ + + device = seq_feats.device + s_s_0 = seq_feats + s_z_0 = pair_feats + + if no_recycles is None: + no_recycles = self.config.max_recycles + else: + if no_recycles < 0: + raise ValueError("Number of recycles must not be negative.") + no_recycles += 1 # First 'recycle' is just the standard forward pass through the model. 
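+        # From here on `no_recycles` counts total passes through the trunk: an input
+        # of no_recycles = 3 means four passes, and (see the loop below) only the
+        # final pass runs outside torch.no_grad().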
+ + def trunk_iter(s, z, residx, mask): + z = z + self.pairwise_positional_embedding(residx, mask=mask) + + for block in self.blocks: + s, z = block(s, z, mask=mask, residue_index=residx, chunk_size=self.chunk_size) + return s, z + + s_s = s_s_0 + s_z = s_z_0 + recycle_s = torch.zeros_like(s_s) + recycle_z = torch.zeros_like(s_z) + recycle_bins = torch.zeros(*s_z.shape[:-1], device=device, dtype=torch.int64) + + for recycle_idx in range(no_recycles): + with ContextManagers([] if recycle_idx == no_recycles - 1 else [torch.no_grad()]): + # === Recycling === + recycle_s = self.recycle_s_norm(recycle_s.detach()).to(device) + recycle_z = self.recycle_z_norm(recycle_z.detach()).to(device) + recycle_z += self.recycle_disto(recycle_bins.detach()).to(device) + + s_s, s_z = trunk_iter(s_s_0 + recycle_s, s_z_0 + recycle_z, residx, mask) + + # === Structure module === + structure = self.structure_module( + {"single": self.trunk2sm_s(s_s), "pair": self.trunk2sm_z(s_z)}, + true_aa, + mask.float(), + ) + + recycle_s = s_s + recycle_z = s_z + # Distogram needs the N, CA, C coordinates, and bin constants same as alphafold. + recycle_bins = EsmFoldingTrunk.distogram( + structure["positions"][-1][:, :, :3], + 3.375, + 21.375, + self.recycle_bins, + ) + + structure["s_s"] = s_s + structure["s_z"] = s_z + + return structure + + @staticmethod + def distogram(coords, min_bin, max_bin, num_bins): + # Coords are [... L x 3 x 3], where it's [N, CA, C] x 3 coordinates. + boundaries = torch.linspace( + min_bin, + max_bin, + num_bins - 1, + device=coords.device, + ) + boundaries = boundaries**2 + N, CA, C = [x.squeeze(-2) for x in coords.chunk(3, dim=-2)] + # Infer CB coordinates. + b = CA - N + c = C - CA + a = b.cross(c, dim=-1) + CB = -0.58273431 * a + 0.56802827 * b - 0.54067466 * c + CA + dists = (CB[..., None, :, :] - CB[..., :, None, :]).pow(2).sum(dim=-1, keepdims=True) + bins = torch.sum(dists > boundaries, dim=-1) # [..., L, L] + return bins + + +# TODO Add information to the docstring about any methods that convert to PDB format, or otherwise prepare +# the outputs for downstream use. + +class FAESMConfig(EsmConfig): + def __init__(self, use_fa=False, **kwargs): + super().__init__(**kwargs) + self.use_fa = use_fa + + def to_dict(self): + """ + Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`]. + + Returns: + `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance, + """ + output = super().to_dict() + output["use_fa"] = self.use_fa + return output + + +@add_start_docstrings( + """ + ESMForProteinFolding is the HuggingFace port of the original ESMFold model. It consists of an ESM-2 "stem" followed + by a protein folding "head", although unlike most other output heads, this "head" is similar in size and runtime to + the rest of the model combined! It outputs a dictionary containing predicted structural information about the input + protein(s). 
+ """, + ESM_START_DOCSTRING, +) +class FAEsmForProteinFolding(EsmPreTrainedModel): + _no_split_modules = ["EsmFoldStructureModule", "EsmFoldTriangularSelfAttentionBlock"] + + def __init__(self, config): + super().__init__(config) + + self.config = FAESMConfig(use_fa=True, **config.to_dict()) + + self.distogram_bins = 64 + + # self.esm = EsmModel(config, add_pooling_layer=False) + # config_dict = config.to_dict() + # config.use_fa = True + + self.esm = FAEsmModel(self.config, add_pooling_layer=False) + self.esm.requires_grad_(False) + # if self.config.esmfold_config.fp16_esm: + # self.esm.half() + self.esm.to(torch.float16) + + self.esm_feats = self.config.hidden_size + self.esm_attns = self.config.num_hidden_layers * self.config.num_attention_heads + self.esm_layers = self.config.num_hidden_layers + self.register_buffer("af2_to_esm", self._af2_to_esm_from_vocab_list(config.vocab_list)) + self.esm_s_combine = nn.Parameter(torch.zeros(self.esm_layers + 1)) + + trunk_config = self.config.esmfold_config.trunk + c_s = trunk_config.sequence_state_dim + c_z = trunk_config.pairwise_state_dim + self.esm_s_mlp = nn.Sequential( + LayerNorm(self.esm_feats), + nn.Linear(self.esm_feats, c_s), + nn.ReLU(), + nn.Linear(c_s, c_s), + ) + + # 0 is padding, N is unknown residues, N + 1 is mask. + self.n_tokens_embed = residue_constants.restype_num + 3 + self.pad_idx = 0 + self.unk_idx = self.n_tokens_embed - 2 + self.mask_idx = self.n_tokens_embed - 1 + self.esm_dict_cls_idx = self.config.vocab_list.index("") + self.esm_dict_mask_idx = self.config.vocab_list.index("") + self.esm_dict_eos_idx = self.config.vocab_list.index("") + self.esm_dict_padding_idx = self.config.vocab_list.index("") + if self.config.esmfold_config.embed_aa: + self.embedding = nn.Embedding(self.n_tokens_embed, c_s, padding_idx=0) + + self.trunk = EsmFoldingTrunk(trunk_config) + + self.distogram_head = nn.Linear(c_z, self.distogram_bins) + self.ptm_head = nn.Linear(c_z, self.distogram_bins) + self.lm_head = nn.Linear(c_s, self.n_tokens_embed) + self.lddt_bins = 50 + structure_module_config = trunk_config.structure_module + self.lddt_head = nn.Sequential( + nn.LayerNorm(structure_module_config.sequence_dim), + nn.Linear(structure_module_config.sequence_dim, self.config.esmfold_config.lddt_head_hid_dim), + nn.Linear(self.config.esmfold_config.lddt_head_hid_dim, self.config.esmfold_config.lddt_head_hid_dim), + nn.Linear(self.config.esmfold_config.lddt_head_hid_dim, 37 * self.lddt_bins), + ) + + @staticmethod + def _af2_to_esm_from_vocab_list(vocab_list: List[str]) -> torch.Tensor: + # Remember that t is shifted from residue_constants by 1 (0 is padding). 
+ esm_reorder = [vocab_list.index("")] + [vocab_list.index(v) for v in residue_constants.restypes_with_x] + return torch.tensor(esm_reorder) + + @add_start_docstrings_to_model_forward(ESMFOLD_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @replace_return_docstrings(output_type=EsmForProteinFoldingOutput, config_class=EsmConfig) + def forward( + self, + input_ids: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + masking_pattern: Optional[torch.Tensor] = None, + num_recycles: Optional[int] = None, + ) -> EsmForProteinFoldingOutput: + r""" + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, EsmForProteinFolding + + >>> model = EsmForProteinFolding.from_pretrained("facebook/esmfold_v1") + >>> tokenizer = AutoTokenizer.from_pretrained("facebook/esmfold_v1") + >>> inputs = tokenizer(["MLKNVQVQLV"], return_tensors="pt", add_special_tokens=False) # A tiny random peptide + >>> outputs = model(**inputs) + >>> folded_positions = outputs.positions + ``` + + """ + cfg = self.config.esmfold_config + + aa = input_ids # B x L + B = aa.shape[0] + L = aa.shape[1] + device = input_ids.device + if attention_mask is None: + attention_mask = torch.ones_like(aa, device=device) + if position_ids is None: + position_ids = torch.arange(L, device=device).expand_as(input_ids) + + # === ESM === + esmaa = self.af2_idx_to_esm_idx(aa, attention_mask) + + if masking_pattern is not None: + masked_aa, esmaa, mlm_targets = self.bert_mask(aa, esmaa, attention_mask, masking_pattern) + else: + masked_aa = aa + mlm_targets = None + + # We get sequence and pair representations from whatever version of ESM / + # configuration we are using. The sequence representation esm_s is always + # present. The pair embedding esm_z may be present depending on the + # configuration of the model. If esm_z is not used by the model then it + # is returned as None here. + esm_s = self.compute_language_model_representations(esmaa) + + # Convert esm_s and esm_z, if present, to the precision used by the trunk and + # the structure module. These tensors may be a lower precision if, for example, + # we're running the language model in fp16 precision. + esm_s = esm_s.to(self.esm_s_combine.dtype) + + if cfg.esm_ablate_sequence: + esm_s = esm_s * 0 + + esm_s = esm_s.detach() + + # === preprocessing === + esm_s = (self.esm_s_combine.softmax(0).unsqueeze(0) @ esm_s).squeeze(2) + s_s_0 = self.esm_s_mlp(esm_s) + + s_z_0 = s_s_0.new_zeros(B, L, L, cfg.trunk.pairwise_state_dim) + + if self.config.esmfold_config.embed_aa: + s_s_0 += self.embedding(masked_aa) + + structure: dict = self.trunk(s_s_0, s_z_0, aa, position_ids, attention_mask, no_recycles=num_recycles) + # Documenting what we expect: + structure = { + k: v + for k, v in structure.items() + if k + in [ + "s_z", + "s_s", + "frames", + "sidechain_frames", + "unnormalized_angles", + "angles", + "positions", + "states", + ] + } + + # Add BERT mask for the loss to use, if available. + if mlm_targets: + structure["mlm_targets"] = mlm_targets + + disto_logits = self.distogram_head(structure["s_z"]) + disto_logits = (disto_logits + disto_logits.transpose(1, 2)) / 2 + structure["distogram_logits"] = disto_logits + + lm_logits = self.lm_head(structure["s_s"]) + structure["lm_logits"] = lm_logits + + structure["aatype"] = aa + make_atom14_masks(structure) + # Of course, this doesn't respect the true mask because it doesn't know about it... 
+ # We're not going to properly mask change of index tensors: + # "residx_atom14_to_atom37", + # "residx_atom37_to_atom14", + for k in [ + "atom14_atom_exists", + "atom37_atom_exists", + ]: + structure[k] *= attention_mask.unsqueeze(-1) + structure["residue_index"] = position_ids + + lddt_head = self.lddt_head(structure["states"]).reshape(structure["states"].shape[0], B, L, -1, self.lddt_bins) + structure["lddt_head"] = lddt_head + plddt = categorical_lddt(lddt_head[-1], bins=self.lddt_bins) + structure["plddt"] = plddt + + ptm_logits = self.ptm_head(structure["s_z"]) + structure["ptm_logits"] = ptm_logits + structure["ptm"] = compute_tm(ptm_logits, max_bin=31, no_bins=self.distogram_bins) + structure.update(compute_predicted_aligned_error(ptm_logits, max_bin=31, no_bins=self.distogram_bins)) + + return EsmForProteinFoldingOutput(**structure) + + def af2_idx_to_esm_idx(self, aa, mask): + # avoid indexing on different devices + if self.af2_to_esm.device != aa.device: + self.af2_to_esm = self.af2_to_esm.to(aa.device) + aa = (aa + 1).masked_fill(mask != 1, 0) + return self.af2_to_esm[aa] + + def compute_language_model_representations(self, esmaa: torch.Tensor) -> torch.Tensor: + device = next(self.parameters()).device + B, L = esmaa.shape # B = batch size, L = sequence length. + + if self.config.esmfold_config.bypass_lm: + esm_s = torch.zeros(B, L, self.esm_s_combine.size[0], -1, self.esm_feats, device=device) + return esm_s + + bosi, eosi = self.esm_dict_cls_idx, self.esm_dict_eos_idx + bos = esmaa.new_full((B, 1), bosi) + eos = esmaa.new_full((B, 1), self.esm_dict_padding_idx) + esmaa = torch.cat([bos, esmaa, eos], dim=1) + # Use the first padding index as eos during inference. + esmaa[range(B), (esmaa != 1).sum(1)] = eosi + + # _, esm_z, esm_s = self.esm(esmaa, return_pairs=self.config.esmfold_config.use_esm_attn_map) + # Because we do not support use_esm_attn_map in the HF port as it is not used in any public models, + # esm_z is always None + esm_hidden_states = [x.unsqueeze(0) for x in self.esm(esmaa, attention_mask=esmaa != 1, output_hidden_states=True)["hidden_states"]] + # breakpoint() + esm_s = torch.stack(esm_hidden_states, dim=2) + + esm_s = esm_s[:, 1:-1] # B, L, nLayers, C + + return esm_s + + def bert_mask(self, aa, esmaa, mask, pattern): + new_aa = aa.clone() + target = aa.clone() + new_esmaa = esmaa.clone() + new_aa[pattern == 1] = self.mask_idx + target[pattern != 1] = 0 + new_esmaa[pattern == 1] = self.esm_dict_mask_idx + return new_aa, new_esmaa, target + + @torch.no_grad() + def infer( + self, + seqs: Union[str, List[str]], + position_ids=None, + ): + if isinstance(seqs, str): + lst = [seqs] + else: + lst = seqs + # Returns the raw outputs of the model given an input sequence. 
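+        # Rough usage sketch; the checkpoint name is taken from the forward() docstring
+        # above and is an assumption for this FAESM port:
+        #   model = FAEsmForProteinFolding.from_pretrained("facebook/esmfold_v1")
+        #   pdb_string = model.infer_pdb("MLKNVQVQLV")  # tiny random peptide
+        #   with open("prediction.pdb", "w") as f:
+        #       f.write(pdb_string)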
+ device = next(self.parameters()).device + aatype = collate_dense_tensors( + [ + torch.from_numpy( + residue_constants.sequence_to_onehot( + sequence=seq, + mapping=residue_constants.restype_order_with_x, + map_unknown_to_x=True, + ) + ) + .to(device) + .argmax(dim=1) + for seq in lst + ] + ) # B=1 x L + mask = collate_dense_tensors([aatype.new_ones(len(seq)) for seq in lst]) + position_ids = ( + torch.arange(aatype.shape[1], device=device).expand(len(lst), -1) + if position_ids is None + else position_ids.to(device) + ) + if position_ids.ndim == 1: + position_ids = position_ids.unsqueeze(0) + return self.forward( + aatype, + mask, + position_ids=position_ids, + ) + + @staticmethod + def output_to_pdb(output: Dict) -> List[str]: + """Returns the pbd (file) string from the model given the model output.""" + output = {k: v.to("cpu").numpy() for k, v in output.items()} + pdbs = [] + final_atom_positions = atom14_to_atom37(output["positions"][-1], output) + final_atom_mask = output["atom37_atom_exists"] + for i in range(output["aatype"].shape[0]): + aa = output["aatype"][i] + pred_pos = final_atom_positions[i] + mask = final_atom_mask[i] + resid = output["residue_index"][i] + 1 + pred = OFProtein( + aatype=aa, + atom_positions=pred_pos, + atom_mask=mask, + residue_index=resid, + b_factors=output["plddt"][i], + ) + pdbs.append(to_pdb(pred)) + return pdbs + + def infer_pdb(self, seqs, *args, **kwargs) -> str: + """Returns the pdb (file) string from the model given an input sequence.""" + assert isinstance(seqs, str) + output = self.infer(seqs, *args, **kwargs) + return self.output_to_pdb(output)[0] + + def infer_pdbs(self, seqs: List[str], *args, **kwargs) -> List[str]: + """Returns the pdb (file) string from the model given an input sequence.""" + output = self.infer(seqs, *args, **kwargs) + return self.output_to_pdb(output) diff --git a/faesm/progen2.py b/faesm/progen2.py new file mode 100644 index 0000000..34bbf40 --- /dev/null +++ b/faesm/progen2.py @@ -0,0 +1,865 @@ +# coding=utf-8 +# Copyright 2021 The EleutherAI and HuggingFace Teams. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Modified forward-pass implementation based on https://github.com/huggingface/transformers/blob/main/src/transformers/models/gptj/modeling_gptj.py + +from typing import Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import nn +from torch.nn import CrossEntropyLoss +from transformers.configuration_utils import PretrainedConfig +from transformers.activations import ACT2FN +from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import logging +from transformers.utils.model_parallel_utils import assert_device_map, get_device_map + +try: + from flash_attn import flash_attn_func#, flash_attn_qkvpacked_func + FLASH_ATTN_AVAILABLE = True +except ImportError: + Warning("Flash Attention is not available. 
Falling back to standard attention.") + FLASH_ATTN_AVAILABLE = False + +logger = logging.get_logger(__name__) + +class ProGenConfig(PretrainedConfig): + model_type = "progen" + + def __init__( + self, + vocab_size=50400, + n_positions=2048, + n_ctx=2048, + n_embd=4096, + n_layer=28, + n_head=16, + rotary_dim=64, + n_inner=None, + activation_function="gelu_new", + resid_pdrop=0.0, + embd_pdrop=0.0, + attn_pdrop=0.0, + layer_norm_epsilon=1e-5, + initializer_range=0.02, + scale_attn_weights=True, + gradient_checkpointing=False, + use_cache=True, + bos_token_id=50256, + eos_token_id=50256, + **kwargs + ): + super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs) + + self.vocab_size = vocab_size + self.n_ctx = n_ctx + self.n_positions = n_positions + self.n_embd = n_embd + self.n_layer = n_layer + self.n_head = n_head + self.n_inner = n_inner + self.rotary_dim = rotary_dim + self.activation_function = activation_function + self.resid_pdrop = resid_pdrop + self.embd_pdrop = embd_pdrop + self.attn_pdrop = attn_pdrop + self.layer_norm_epsilon = layer_norm_epsilon + self.initializer_range = initializer_range + self.gradient_checkpointing = gradient_checkpointing + self.scale_attn_weights = scale_attn_weights + self.use_cache = use_cache + + self.bos_token_id = bos_token_id + self.eos_token_id = eos_token_id + + @property + def max_position_embeddings(self): + return self.n_positions + + @property + def hidden_size(self): + return self.n_embd + + @property + def num_attention_heads(self): + return self.n_head + + @property + def num_hidden_layers(self): + return self.n_layer + + +def fixed_pos_embedding(x, seq_dim=1, seq_len=None): + dim = x.shape[-1] + if seq_len is None: + seq_len = x.shape[seq_dim] + inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2) / dim)) + sinusoid_inp = torch.einsum("i , j -> i j", torch.arange(seq_len), inv_freq).to(x.device).float() + return torch.sin(sinusoid_inp), torch.cos(sinusoid_inp) + +def rotate_every_two(x): + x1 = x[:, :, :, ::2] + x2 = x[:, :, :, 1::2] + x = torch.stack((-x2, x1), axis=-1) + return x.flatten(-2) # in einsum notation: rearrange(x, '... d j -> ... (d j)') + + +def apply_rotary_pos_emb(x, sincos, offset=0): + sin, cos = map(lambda t: t[None, offset : x.shape[1] + offset, None, :].repeat_interleave(2, 3), sincos) + # einsum notation for lambda t: repeat(t[offset:x.shape[1]+offset,:], "n d -> () n () (d j)", j=2) + return (x * cos) + (rotate_every_two(x) * sin) + + +class ProGenAttention(nn.Module): + def __init__(self, config): + super().__init__() + + max_positions = config.max_position_embeddings + self.register_buffer( + "bias", + torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view( + 1, 1, max_positions, max_positions + ), + ) + self.register_buffer("masked_bias", torch.tensor(-1e9)) + + self.attn_dropout = nn.Dropout(config.attn_pdrop) + self.resid_dropout = nn.Dropout(config.resid_pdrop) + + self.embed_dim = config.hidden_size + self.num_attention_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_attention_heads + if self.head_dim * self.num_attention_heads != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_attention_heads (got `embed_dim`: {self.embed_dim} and `num_attention_heads`: {self.num_attention_heads})." 
+ ) + self.scale_attn = torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float16)).to(torch.get_default_dtype()) + self.qkv_proj = nn.Linear(self.embed_dim, self.embed_dim * 3, bias=False) + + self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False) + self.rotary_dim = None + if config.rotary_dim is not None: + self.rotary_dim = config.rotary_dim + + def _split_heads(self, x, n_head, dim_head, mp_num=4): + reshaped = x.reshape(x.shape[:-1] + (n_head//mp_num, dim_head)) + reshaped = reshaped.reshape(x.shape[:-2] + (-1, ) + reshaped.shape[-1:]) + return reshaped + + def _merge_heads(self, tensor, num_attention_heads, attn_head_size): + """ + Merges attn_head_size dim and num_attn_heads dim into n_ctx + """ + if len(tensor.shape) == 5: + tensor = tensor.permute(0, 1, 3, 2, 4).contiguous() + elif len(tensor.shape) == 4: + tensor = tensor.permute(0, 2, 1, 3).contiguous() + else: + raise ValueError(f"Input tensor rank should be one of [4, 5], but is: {len(tensor.shape)}") + new_shape = tensor.size()[:-2] + (num_attention_heads * attn_head_size,) + return tensor.view(new_shape) + + def _attn( + self, + query, + key, + value, + attention_mask=None, + head_mask=None, + ): + + # compute causal mask from causal mask buffer + query_length, key_length = query.size(-2), key.size(-2) + causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length] + + # Keep the attention weights computation in fp32 to avoid overflow issues + query = query.to(torch.float16) + key = key.to(torch.float16) + + attn_weights = torch.matmul(query, key.transpose(-1, -2)) + + attn_weights = attn_weights / self.scale_attn + attn_weights = torch.where(causal_mask, attn_weights, self.masked_bias.to(attn_weights.dtype)) + + if attention_mask is not None: + # Apply the attention mask + attn_weights = attn_weights + attention_mask + + attn_weights = nn.Softmax(dim=-1)(attn_weights) + attn_weights = attn_weights.to(value.dtype) + attn_weights = self.attn_dropout(attn_weights) + + # Mask heads if we want to + if head_mask is not None: + attn_weights = attn_weights * head_mask + + attn_output = torch.matmul(attn_weights, value) + + return attn_output, attn_weights + + def forward( + self, + hidden_states, + attention_mask=None, + layer_past=None, + head_mask=None, + use_cache=False, + output_attentions=False, + ): + + qkv = self.qkv_proj(hidden_states) + # TODO(enijkamp): factor out number of logical TPU-v3/v4 cores or make forward pass agnostic + # mp_num = 4 + mp_num = 8 + qkv_split = qkv.reshape(qkv.shape[:-1] + (mp_num, -1)) + + local_dim = self.head_dim * self.num_attention_heads // mp_num + query, value, key = torch.split(qkv_split, local_dim, dim=-1) + query = self._split_heads(query, self.num_attention_heads, self.head_dim, mp_num=mp_num) + key = self._split_heads(key, self.num_attention_heads, self.head_dim, mp_num=mp_num) + + value = self._split_heads(value, self.num_attention_heads, self.head_dim, mp_num=mp_num) + value = value.permute(0, 2, 1, 3) + + seq_len = key.shape[1] + offset = 0 + + if layer_past is not None: + offset = layer_past[0].shape[-2] + seq_len += offset + + if self.rotary_dim is not None: + k_rot = key[:, :, :, : self.rotary_dim] + k_pass = key[:, :, :, self.rotary_dim :] + + q_rot = query[:, :, :, : self.rotary_dim] + q_pass = query[:, :, :, self.rotary_dim :] + + sincos = fixed_pos_embedding(k_rot, 1, seq_len=seq_len) + k_rot = apply_rotary_pos_emb(k_rot, sincos, offset=offset) + q_rot = apply_rotary_pos_emb(q_rot, sincos, offset=offset) + + key = torch.cat([k_rot, 
k_pass], dim=-1) + query = torch.cat([q_rot, q_pass], dim=-1) + else: + sincos = fixed_pos_embedding(key, 1, seq_len=seq_len) + key = apply_rotary_pos_emb(key, sincos, offset=offset) + query = apply_rotary_pos_emb(query, sincos, offset=offset) + + key = key.permute(0, 2, 1, 3) + query = query.permute(0, 2, 1, 3) + + if layer_past is not None: + past_key = layer_past[0] + past_value = layer_past[1] + key = torch.cat((past_key, key), dim=-2) + value = torch.cat((past_value, value), dim=-2) + + if use_cache is True: + present = (key, value) + else: + present = None + + # compute self-attention: V x Softmax(QK^T) + attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask) + + attn_output = self._merge_heads(attn_output, self.num_attention_heads, self.head_dim) + + attn_output = self.out_proj(attn_output) + attn_output = self.resid_dropout(attn_output) + + outputs = (attn_output, present) + if output_attentions: + outputs += (attn_weights,) + + return outputs # a, present, (attentions) + + + +class ProGenFlashAttention2(ProGenAttention): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward( + self, + hidden_states: Optional[Tuple[torch.FloatTensor]], + layer_past: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]: + hidden_states = hidden_states.to(torch.float16) + B, T, C = hidden_states.size() # batch size, sequence length, embedding dimensionality (n_embd) + + qkv = self.qkv_proj(hidden_states) + mp_num = 8 + qkv_split = qkv.reshape(qkv.shape[:-1] + (mp_num, -1)) + + local_dim = self.head_dim * self.num_attention_heads // mp_num + query, value, key = torch.split(qkv_split, local_dim, dim=-1) + + query = self._split_heads(query, self.num_attention_heads, self.head_dim, mp_num=mp_num) + key = self._split_heads(key, self.num_attention_heads, self.head_dim, mp_num=mp_num) + + value = self._split_heads(value, self.num_attention_heads, self.head_dim, mp_num=mp_num) + value = value.permute(0, 2, 1, 3) + + seq_len = key.shape[1] + offset = 0 + + if layer_past is not None: + offset = layer_past[0].shape[-2] + seq_len += offset + + if self.rotary_dim is not None: + k_rot = key[:, :, :, : self.rotary_dim] + k_pass = key[:, :, :, self.rotary_dim :] + + q_rot = query[:, :, :, : self.rotary_dim] + q_pass = query[:, :, :, self.rotary_dim :] + + sincos = fixed_pos_embedding(k_rot, 1, seq_len=seq_len) + k_rot = apply_rotary_pos_emb(k_rot, sincos, offset=offset) + q_rot = apply_rotary_pos_emb(q_rot, sincos, offset=offset) + + key = torch.cat([k_rot, k_pass], dim=-1) + query = torch.cat([q_rot, q_pass], dim=-1) + else: + sincos = fixed_pos_embedding(key, 1, seq_len=seq_len) + key = apply_rotary_pos_emb(key, sincos, offset=offset) + query = apply_rotary_pos_emb(query, sincos, offset=offset) + + key = key.permute(0, 2, 1, 3) + query = query.permute(0, 2, 1, 3) + + if layer_past is not None: + past_key = layer_past[0] + past_value = layer_past[1] + key = torch.cat((past_key, key), dim=-2) + value = torch.cat((past_value, value), dim=-2) + + if use_cache is True: + present = (key, value) + else: + present = None + + # breakpoint() + # B, H, L, D + # attn_output = _flash_attention_forward( + # 
query.half().permute(0, 2, 1, 3), + # key.half().permute(0, 2, 1, 3), + # value.half().permute(0, 2, 1, 3), + # attention_mask, + # seq_len, + # dropout=self.attn_dropout.p, + # is_causal=True, + # use_top_left_mask=self._flash_attn_uses_top_left_mask, + # ) + + # qkv = torch.stack([query, key, value], dim=3).permute(0, 2, 3, 1, 4).half() + # attn_output = flash_attn_qkvpacked_func( + # qkv, + # dropout_p=self.attn_dropout.p, + # causal=True) + + attn_output = flash_attn_func( + query.half().permute(0, 2, 1, 3), + key.half().permute(0, 2, 1, 3), + value.half().permute(0, 2, 1, 3), + dropout_p=self.attn_dropout.p, + causal=True, + ) + + attn_output = attn_output.permute(0, 2, 1, 3) + attn_output = self._merge_heads(attn_output, self.num_attention_heads, self.head_dim) + attn_output = self.out_proj(attn_output) + attn_output = self.resid_dropout(attn_output) + + outputs = (attn_output, present) + if output_attentions: + outputs += (None,) + + return outputs + + +class ProGenMLP(nn.Module): + def __init__(self, intermediate_size, config): # in MLP: intermediate_size= 4 * embed_dim + super().__init__() + embed_dim = config.n_embd + + self.fc_in = nn.Linear(embed_dim, intermediate_size) + self.fc_out = nn.Linear(intermediate_size, embed_dim) + + self.act = ACT2FN[config.activation_function] + self.dropout = nn.Dropout(config.resid_pdrop) + + def forward(self, hidden_states): + hidden_states = self.fc_in(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.fc_out(hidden_states) + hidden_states = self.dropout(hidden_states) + return hidden_states + + +class ProGenBlock(nn.Module): + def __init__(self, config): + super().__init__() + inner_dim = config.n_inner if config.n_inner is not None else 4 * config.n_embd + self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon) + self.attn = ProGenFlashAttention2(config) if FLASH_ATTN_AVAILABLE else ProGenAttention(config) + self.mlp = ProGenMLP(inner_dim, config) + + def forward( + self, + hidden_states, + layer_past=None, + attention_mask=None, + head_mask=None, + use_cache=False, + output_attentions=False, + ): + residual = hidden_states + hidden_states = self.ln_1(hidden_states) + attn_outputs = self.attn( + hidden_states, + layer_past=layer_past, + attention_mask=attention_mask, + head_mask=head_mask, + use_cache=use_cache, + output_attentions=output_attentions, + ) + attn_output = attn_outputs[0] # output_attn: a, present, (attentions) + outputs = attn_outputs[1:] + + feed_forward_hidden_states = self.mlp(hidden_states) + hidden_states = attn_output + feed_forward_hidden_states + residual + + if use_cache: + outputs = (hidden_states,) + outputs + else: + outputs = (hidden_states,) + outputs[1:] + + return outputs # hidden_states, present, (attentions) + + +class ProGenPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. 
+ """ + + config_class = ProGenConfig + base_model_prefix = "transformer" + is_parallelizable = True + + def __init__(self, *inputs, **kwargs): + super().__init__(*inputs, **kwargs) + + def _init_weights(self, module): + """Initialize the weights.""" + if isinstance(module, (nn.Linear,)): + # Slightly different from Mesh Transformer JAX which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + +class ProGenModel(ProGenPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.embed_dim = config.n_embd + self.vocab_size = config.vocab_size + self.wte = nn.Embedding(config.vocab_size, self.embed_dim) + self.drop = nn.Dropout(config.embd_pdrop) + self.h = nn.ModuleList([ProGenBlock(config) for _ in range(config.n_layer)]) + self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) + self.rotary_dim = min(config.rotary_dim, config.n_ctx // config.num_attention_heads) + self.init_weights() + + # Model parallel + self.model_parallel = False + self.device_map = None + + + def parallelize(self, device_map=None): + # Check validity of device_map + self.device_map = ( + get_device_map(len(self.h), range(torch.cuda.device_count())) if device_map is None else device_map + ) + assert_device_map(self.device_map, len(self.h)) + self.model_parallel = True + self.first_device = "cpu" if "cpu" in self.device_map.keys() else "cuda:" + str(min(self.device_map.keys())) + self.last_device = "cuda:" + str(max(self.device_map.keys())) + self.wte = self.wte.to(self.first_device) + # Load onto devices + for k, v in self.device_map.items(): + for block in v: + cuda_device = "cuda:" + str(k) + self.h[block] = self.h[block].to(cuda_device) + # ln_f to last + self.ln_f = self.ln_f.to(self.last_device) + + + def deparallelize(self): + self.model_parallel = False + self.device_map = None + self.first_device = "cpu" + self.last_device = "cpu" + self.wte = self.wte.to("cpu") + for index in range(len(self.h)): + self.h[index] = self.h[index].to("cpu") + self.ln_f = self.ln_f.to("cpu") + torch.cuda.empty_cache() + + def get_input_embeddings(self): + return self.wte + + def set_input_embeddings(self, new_embeddings): + self.wte = new_embeddings + + def forward( + self, + input_ids=None, + past_key_values=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + input_shape = input_ids.size() + 
input_ids = input_ids.view(-1, input_shape[-1]) + batch_size = input_ids.shape[0] + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + batch_size = inputs_embeds.shape[0] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + device = input_ids.device if input_ids is not None else inputs_embeds.device + + if token_type_ids is not None: + token_type_ids = token_type_ids.view(-1, input_shape[-1]) + + if position_ids is not None: + position_ids = position_ids.view(-1, input_shape[-1]) + + if past_key_values is None: + past_length = 0 + past_key_values = tuple([None] * len(self.h)) + else: + past_length = past_key_values[0][0].size(-2) + + if position_ids is None: + position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device) + position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1]) + + # Attention mask. + if attention_mask is not None: + assert batch_size > 0, "batch_size has to be defined and > 0" + attention_mask = attention_mask.view(batch_size, -1) + # We create a 3D attention mask from a 2D tensor mask. + # Sizes are [batch_size, 1, 1, to_seq_length] + # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] + # this attention mask is more simple than the triangular masking of causal attention + # used in OpenAI GPT, we just need to prepare the broadcast dimension here. + attention_mask = attention_mask[:, None, None, :] + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and -10000.0 for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. 
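+            # e.g. a padding mask of [1, 1, 0] becomes the additive biases [0.0, 0.0, -10000.0]
+            # after the two lines below, broadcast over heads and query positions.
+            # Note: only the eager ProGenAttention path consumes this additive mask; the
+            # flash-attention path calls flash_attn_func with causal=True and does not apply it.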
+ attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility + attention_mask = (1.0 - attention_mask) * -10000.0 + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x num_attention_heads x N x N + # head_mask has shape n_layer x batch x num_attention_heads x N x N + head_mask = self.get_head_mask(head_mask, self.config.n_layer) + + if inputs_embeds is None: + inputs_embeds = self.wte(input_ids) + + hidden_states = inputs_embeds + + if token_type_ids is not None: + token_type_embeds = self.wte(token_type_ids) + hidden_states = hidden_states + token_type_embeds + + hidden_states = self.drop(hidden_states) + + output_shape = input_shape + (hidden_states.size(-1),) + + presents = () if use_cache else None + all_self_attentions = () if output_attentions else None + all_hidden_states = () if output_hidden_states else None + for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): + + # Model parallel + if self.model_parallel: + torch.cuda.set_device(hidden_states.device) + # Ensure layer_past is on same device as hidden_states (might not be correct) + if layer_past is not None: + layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past) + # Ensure that attention_mask is always on the same device as hidden_states + if attention_mask is not None: + attention_mask = attention_mask.to(hidden_states.device) + if isinstance(head_mask, torch.Tensor): + head_mask = head_mask.to(hidden_states.device) + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if getattr(self.config, "gradient_checkpointing", False) and self.training: + + if use_cache: + logger.warning( + "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting " + "`use_cache=False`..." 
+ ) + use_cache = False + + def create_custom_forward(module): + def custom_forward(*inputs): + # None for past_key_value + return module(*inputs, use_cache, output_attentions) + + return custom_forward + + outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + None, + attention_mask, + head_mask[i], + ) + else: + outputs = block( + hidden_states, + layer_past=layer_past, + attention_mask=attention_mask, + head_mask=head_mask[i], + use_cache=use_cache, + output_attentions=output_attentions, + ) + + hidden_states = outputs[0] + if use_cache is True: + presents = presents + (outputs[1],) + + if output_attentions: + all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) + + # Model Parallel: If it's the last layer for that device, put things on the next device + if self.model_parallel: + for k, v in self.device_map.items(): + if i == v[-1] and "cuda:" + str(k) != self.last_device: + hidden_states = hidden_states.to("cuda:" + str(k + 1)) + + hidden_states = self.ln_f(hidden_states) + + hidden_states = hidden_states.view(*output_shape) + # Add last hidden state + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=presents, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + +class ProGenForCausalLM(ProGenPreTrainedModel): + _keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.masked_bias", r"h\.\d+\.attn\.bias", r"lm_head\.weight"] + + def __init__(self, config): + super().__init__(config) + self.transformer = ProGenModel(config) + self.lm_head = nn.Linear(config.n_embd, config.vocab_size) + self.init_weights() + + # Model parallel + self.model_parallel = False + self.device_map = None + + def parallelize(self, device_map=None): + self.device_map = ( + get_device_map(len(self.transformer.h), range(torch.cuda.device_count())) + if device_map is None + else device_map + ) + assert_device_map(self.device_map, len(self.transformer.h)) + self.transformer.parallelize(self.device_map) + self.lm_head = self.lm_head.to(self.transformer.first_device) + self.model_parallel = True + + def deparallelize(self): + self.transformer.deparallelize() + self.transformer = self.transformer.to("cpu") + self.lm_head = self.lm_head.to("cpu") + self.model_parallel = False + torch.cuda.empty_cache() + + def get_output_embeddings(self): + return None + + def set_output_embeddings(self, new_embeddings): + return + + def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs): + token_type_ids = kwargs.get("token_type_ids", None) + # only last token for inputs_ids if past is defined in kwargs + if past: + input_ids = input_ids[:, -1].unsqueeze(-1) + if token_type_ids is not None: + token_type_ids = token_type_ids[:, -1].unsqueeze(-1) + + attention_mask = kwargs.get("attention_mask", None) + position_ids = kwargs.get("position_ids", None) + + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past: + position_ids = position_ids[:, -1].unsqueeze(-1) + else: + position_ids = None + return { + "input_ids": input_ids, + "past_key_values": past, + "use_cache": kwargs.get("use_cache"), + "position_ids": 
position_ids, + "attention_mask": attention_mask, + "token_type_ids": token_type_ids, + } + + def forward( + self, + input_ids=None, + past_key_values=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + labels=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set + ``labels = input_ids`` Indices are selected in ``[-100, 0, ..., config.vocab_size]`` All labels set to + ``-100`` are ignored (masked), the loss is only computed for labels in ``[0, ..., config.vocab_size]`` + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.transformer( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + + # Set device for model parallelism + if self.model_parallel: + torch.cuda.set_device(self.transformer.first_device) + hidden_states = hidden_states.to(self.lm_head.weight.device) + + # make sure sampling in fp16 works correctly and + # compute loss in fp32 to match with mesh-tf version + # https://github.com/EleutherAI/gpt-neo/blob/89ce74164da2fb16179106f54e2269b5da8db333/models/gpt2/gpt2.py#L179 + lm_logits = self.lm_head(hidden_states).to(torch.float16) + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = lm_logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) + + loss = loss.to(hidden_states.dtype) + + if not return_dict: + output = (lm_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=lm_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + @staticmethod + def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) -> Tuple[Tuple[torch.Tensor]]: + """ + This function is used to re-order the :obj:`past_key_values` cache if + :meth:`~transformers.PretrainedModel.beam_search` or :meth:`~transformers.PretrainedModel.beam_sample` is + called. This is required to match :obj:`past_key_values` with the correct beam_idx at every generation step. 
+ """ + return tuple( + tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past) + for layer_past in past + ) + \ No newline at end of file diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/benchmark.py b/tests/benchmark.py index 9f5e1b2..ed0310f 100644 --- a/tests/benchmark.py +++ b/tests/benchmark.py @@ -8,7 +8,7 @@ from transformers import EsmForMaskedLM, EsmTokenizer from faesm.esm import FAEsmForMaskedLM -from tests.utils import generate_random_esm2_inputs +# from tests.utils import generate_random_esm2_inputs # Set Seaborn theme and professional settings sns.set_theme(style="white") # Remove grid by using "white" @@ -27,6 +27,27 @@ } ) +def generate_random_esm2_inputs( + tokenizer, batch_size=3, min_seq_length=5, max_seq_length=10, device="cuda" +): + """Generate random ESM2 model inputs.""" + random_lengths = torch.randint( + min_seq_length, max_seq_length + 1, (batch_size,), device=device + ) + random_tokens = [ + torch.randint(low=4, high=29, size=(length,), device=device).tolist() + for length in random_lengths + ] + sequences = ["".join(tokenizer.convert_ids_to_tokens(seq)) for seq in random_tokens] + esm_input = tokenizer.batch_encode_plus( + sequences, + add_special_tokens=True, + padding=True, + truncation=True, + return_tensors="pt", + ) + esm_input = {k: v.to(device) for k, v in esm_input.items()} + return esm_input def benchmark_torch_memory(f, *args, **kwargs): torch.cuda.reset_peak_memory_stats() @@ -51,7 +72,7 @@ def benchmark_inference_time(f, *args, **kwargs): "facebook/esm2_t33_650M_UR50D", 8, torch.float16, - [100, 200, 300, 400, 500, 600, 700, 800, 1000], + [100, 200, 300, 400, 500, 600, 700], 10, ) ], diff --git a/tests/benchmark_faesmfold_vs_esmfold.py b/tests/benchmark_faesmfold_vs_esmfold.py new file mode 100644 index 0000000..820b736 --- /dev/null +++ b/tests/benchmark_faesmfold_vs_esmfold.py @@ -0,0 +1,221 @@ + +import time + +import matplotlib.pyplot as plt +import numpy as np +import pytest +import seaborn as sns +import torch +from transformers import EsmForMaskedLM, EsmTokenizer, EsmForProteinFolding +from tqdm import tqdm +from faesm.esm import FAEsmForMaskedLM +from faesm.esmfold import FAEsmForProteinFolding +import random + + +def generate_random_protein_sequences(mini_length, max_length): + """Generate random protein sequences.""" + length = random.randint(mini_length, max_length) + return "".join( + [ + random.choice("ACDEFGHIKLMNPQRSTVWY") + for _ in range(length) + ] + ) + +# Set Seaborn theme and professional settings +sns.set_theme(style="white") # Remove grid by using "white" +color_palette = sns.color_palette("Set2") # Professional color palette + +# Matplotlib font and size settings +plt.rcParams.update( + { + "font.family": "serif", # Use serif fonts for a professional look + "font.size": 14, # Larger font size for better readability + "axes.titlesize": 18, # Larger titles + "axes.labelsize": 16, # Larger axis labels + "xtick.labelsize": 14, # Larger x-tick labels + "ytick.labelsize": 14, # Larger y-tick labels + "legend.fontsize": 14, # Larger legend font + } +) + + +def benchmark_torch_memory(f, *args, **kwargs): + torch.cuda.reset_peak_memory_stats() + torch.cuda.synchronize() + f(*args, **kwargs) + torch.cuda.synchronize() + return torch.cuda.max_memory_allocated() / 1024 / 1024 / 1024 # Convert to GB + + +def benchmark_inference_time(f, *args, **kwargs): + torch.cuda.synchronize() + start_time = time.time() + f(*args, **kwargs) + 
torch.cuda.synchronize() + return time.time() - start_time + +@pytest.mark.parametrize( + "dtype,max_seq_lengths,repeats", + [ + ( + torch.float16, + [100,200,300,400,500], + 3, + ) + ], +) +def test_esmfold_vs_faesmfold_benchmark(dtype, max_seq_lengths, repeats): + device = "cuda" if torch.cuda.is_available() else "cpu" + + + esm_memory_usage, fa_esm_memory_usage = [], [] + esm_inference_times, fa_esm_inference_times = [], [] + + for seq_length in max_seq_lengths: + inputs = generate_random_protein_sequences(mini_length=seq_length-50, max_length=seq_length) + + esm_memory_fold, fa_esm_memory_fold = [], [] + esm_time_fold, fa_esm_time_fold = [], [] + + for _ in tqdm(range(repeats)): + esmfold = EsmForProteinFolding.from_pretrained("facebook/esmfold_v1").to(device).eval() + esmfold.esm = esmfold.esm.to(dtype) + def esm_forward(): + esmfold.infer_pdb(inputs) + esmfold.to(device) + esm_memory_fold.append(benchmark_torch_memory(esm_forward)) + esm_time_fold.append(benchmark_inference_time(esm_forward)) + esmfold.to("cpu") + torch.cuda.empty_cache() + + fa_esmfold = FAEsmForProteinFolding.from_pretrained("facebook/esmfold_v1").to(device).eval() + def fa_esm_forward(): + fa_esmfold.esm.half() + fa_esmfold.infer_pdb(inputs) + + fa_esmfold.to(device) + fa_esm_memory_fold.append(benchmark_torch_memory(fa_esm_forward)) + fa_esm_time_fold.append(benchmark_inference_time(fa_esm_forward)) + fa_esmfold.to("cpu") + torch.cuda.empty_cache() + + esm_memory_usage.append(np.mean(esm_memory_fold)) + fa_esm_memory_usage.append(np.mean(fa_esm_memory_fold)) + esm_inference_times.append(np.mean(esm_time_fold)) + fa_esm_inference_times.append(np.mean(fa_esm_time_fold)) + + print( + f"Seq Len: {seq_length}, Avg ESMFold Mem: {esm_memory_usage[-1]:.3f} GB, Avg FAESMFold Mem: {fa_esm_memory_usage[-1]:.3f} GB" + ) + print( + f"Seq Len: {seq_length}, Avg ESMFold Time: {esm_inference_times[-1]:.3f} s, Avg FAESMFold Time: {fa_esm_inference_times[-1]:.3f} s" + ) + + max_seq_lengths_filtered = max_seq_lengths[1:] + esm_inference_times = esm_inference_times[1:] + fa_esm_inference_times = fa_esm_inference_times[1:] + + memory_reduction = [ + (1 - (fa / esm)) * 100 for fa, esm in zip(fa_esm_memory_usage, esm_memory_usage) + ] + time_reduction = [ + (1 - (fa / esm)) * 100 for fa, esm in zip(fa_esm_inference_times, esm_inference_times) + ] + + fig, axes = plt.subplots(1, 2, figsize=(20, 8)) # Larger figure for better resolution + + # Left Plot: Memory Benchmark + ax1 = axes[0] + ax1.plot( + max_seq_lengths, + esm_memory_usage, + label="ESMFold Memory Usage (GB)", + marker="o", + color=color_palette[0], + ) + ax1.plot( + max_seq_lengths, + fa_esm_memory_usage, + label="FAESMFold Memory Usage (GB)", + marker="o", + color=color_palette[1], + ) + ax1.set_xlabel("Sequence Length") + ax1.set_ylabel("Memory Usage (GB)", color=color_palette[0]) + ax1.tick_params(axis="y", labelcolor=color_palette[0]) + ax1.legend(loc="upper left") + + ax1_twin = ax1.twinx() + ax1_twin.plot( + max_seq_lengths, + memory_reduction, + label="Memory Reduction (%)", + marker="o", + linestyle="--", + color=color_palette[2], + ) + ax1_twin.set_ylabel("Memory Reduction (%)", color=color_palette[2]) + ax1_twin.tick_params(axis="y", labelcolor=color_palette[2]) + ax1_twin.legend(loc="upper right") + + ax1.set_title("Memory Benchmark") + + # Right Plot: Time Benchmark + ax2 = axes[1] + ax2.plot( + max_seq_lengths_filtered, + esm_inference_times, + label="ESMFold Inference Time (s)", + marker="o", + color=color_palette[0], + ) + ax2.plot( + max_seq_lengths_filtered, 
+ fa_esm_inference_times, + label="FAESMFold Inference Time (s)", + marker="o", + color=color_palette[1], + ) + ax2.set_xlabel("Sequence Length") + ax2.set_ylabel("Inference Time (s)", color=color_palette[0]) + ax2.tick_params(axis="y", labelcolor=color_palette[0]) + ax2.legend(loc="upper left") + + ax2_twin = ax2.twinx() + ax2_twin.plot( + max_seq_lengths_filtered, + time_reduction, + label="Time Reduction (%)", + marker="o", + linestyle="--", + color=color_palette[2], + ) + ax2_twin.set_ylabel("Time Reduction (%)", color=color_palette[2]) + ax2_twin.tick_params(axis="y", labelcolor=color_palette[2]) + ax2_twin.legend(loc="upper right") + + ax2.set_title("Inference Time Benchmark") + + plt.suptitle( + f"Data Type: {dtype}, Averaged over {repeats} runs", + fontsize=20, + ) + plt.tight_layout(rect=[0, 0, 1, 0.95]) + plt.savefig("esmfold_benchmark.png", dpi=300) # High resolution + plt.close() + + for seq_length, fa_mem, esm_mem, fa_time, esm_time in zip( + max_seq_lengths, + fa_esm_memory_usage, + esm_memory_usage, + fa_esm_inference_times, + esm_inference_times, + ): + assert ( + fa_mem <= esm_mem + ), f"Seq {seq_length}: FAESM {fa_mem:.3f} GB > ESM {esm_mem:.3f} GB!" + assert ( + fa_time <= esm_time + ), f"Seq {seq_length}: FAESM {fa_time:.3f} s > ESM {esm_time:.3f} s!" diff --git a/tests/benchmark_faprogen2.py b/tests/benchmark_faprogen2.py new file mode 100644 index 0000000..6a76a64 --- /dev/null +++ b/tests/benchmark_faprogen2.py @@ -0,0 +1,221 @@ +import time + +import matplotlib.pyplot as plt +import numpy as np +import pytest +import seaborn as sns +import torch +from transformers import AutoTokenizer +from faesm.progen2 import ProGenForCausalLM as FAProGenForCausalLM +from tests.progen2 import ProGenForCausalLM + + + +# Set Seaborn theme and professional settings +sns.set_theme(style="white") # Remove grid by using "white" +color_palette = sns.color_palette("Set2") # Professional color palette + +# Matplotlib font and size settings +plt.rcParams.update( + { + "font.family": "serif", # Use serif fonts for a professional look + "font.size": 14, # Larger font size for better readability + "axes.titlesize": 18, # Larger titles + "axes.labelsize": 16, # Larger axis labels + "xtick.labelsize": 14, # Larger x-tick labels + "ytick.labelsize": 14, # Larger y-tick labels + "legend.fontsize": 14, # Larger legend font + } +) + +def generate_random_protein_sequences(mini_length, max_length): + import random + """Generate random protein sequences.""" + length = random.randint(mini_length, max_length) + return "".join( + [ + random.choice("ACDEFGHIKLMNPQRSTVWY") + for _ in range(length) + ] + ) + +def benchmark_torch_memory(f, *args, **kwargs): + torch.cuda.reset_peak_memory_stats() + torch.cuda.synchronize() + f(*args, **kwargs) + torch.cuda.synchronize() + return torch.cuda.max_memory_allocated() / 1024 / 1024 / 1024 # Convert to GB + + +def benchmark_inference_time(f, *args, **kwargs): + torch.cuda.synchronize() + start_time = time.time() + f(*args, **kwargs) + torch.cuda.synchronize() + return time.time() - start_time + + +@pytest.mark.parametrize( + "model_version,dtype,max_seq_lengths,repeats", + [ + ( + "jinyuan22/ProGen2-xlarge", + torch.float16, + [200, 300, 400, 500, 600, 700, 800, 1000], + 3, + ) + ], +) +def test_progen2_vs_faprogen2_benchmark(model_version, dtype, max_seq_lengths, repeats): + tokenizer = AutoTokenizer.from_pretrained(model_version) + device = "cuda" if torch.cuda.is_available() else "cpu" + + progen2 = 
ProGenForCausalLM.from_pretrained(model_version).to(dtype).to("cpu").eval() + fa_progen2 = ( + FAProGenForCausalLM.from_pretrained(model_version).to(dtype).to("cpu").eval() + ) + + progen2_memory_usage, fa_progen2_memory_usage = [], [] + progen2_inference_times, fa_progen2_inference_times = [], [] + + for seq_length in max_seq_lengths: + sequence = generate_random_protein_sequences(mini_length=seq_length-50, max_length=seq_length) + inputs = tokenizer(sequence, return_tensors="pt").to(device) + + progen2_memory_fold, fa_progen2_memory_fold = [], [] + progen2_time_fold, fa_progen2_time_fold = [], [] + + for _ in range(repeats): + + def progen2_forward(): + progen2(inputs.input_ids) + progen2.to(device) + progen2_memory_fold.append(benchmark_torch_memory(progen2_forward)) + progen2_time_fold.append(benchmark_inference_time(progen2_forward)) + progen2.to("cpu") + + def fa_progen2_forward(): + fa_progen2(inputs.input_ids) + fa_progen2.to(device) + fa_progen2_memory_fold.append(benchmark_torch_memory(fa_progen2_forward)) + fa_progen2_time_fold.append(benchmark_inference_time(fa_progen2_forward)) + fa_progen2.to("cpu") + + progen2_memory_usage.append(np.mean(progen2_memory_fold)) + fa_progen2_memory_usage.append(np.mean(fa_progen2_memory_fold)) + progen2_inference_times.append(np.mean(progen2_time_fold)) + fa_progen2_inference_times.append(np.mean(fa_progen2_time_fold)) + + print( + f"Seq Len: {seq_length}, Avg progen2 Mem: {progen2_memory_usage[-1]:.3f} GB, Avg FAprogen2 Mem: {fa_progen2_memory_usage[-1]:.3f} GB" + ) + print( + f"Seq Len: {seq_length}, Avg progen2 Time: {progen2_inference_times[-1]:.3f} s, Avg FAprogen2 Time: {fa_progen2_inference_times[-1]:.3f} s" + ) + + max_seq_lengths_filtered = max_seq_lengths[1:] + progen2_inference_times = progen2_inference_times[1:] + fa_progen2_inference_times = fa_progen2_inference_times[1:] + + memory_reduction = [ + (1 - (fa / progen2)) * 100 for fa, progen2 in zip(fa_progen2_memory_usage, progen2_memory_usage) + ] + time_reduction = [ + (1 - (fa / progen2)) * 100 for fa, progen2 in zip(fa_progen2_inference_times, progen2_inference_times) + ] + + fig, axes = plt.subplots(1, 2, figsize=(20, 8)) # Larger figure for better resolution + + # Left Plot: Memory Benchmark + ax1 = axes[0] + ax1.plot( + max_seq_lengths, + progen2_memory_usage, + label="progen2 Memory Usage (GB)", + marker="o", + color=color_palette[0], + ) + ax1.plot( + max_seq_lengths, + fa_progen2_memory_usage, + label="FAprogen2 Memory Usage (GB)", + marker="o", + color=color_palette[1], + ) + ax1.set_xlabel("Sequence Length") + ax1.set_ylabel("Memory Usage (GB)", color=color_palette[0]) + ax1.tick_params(axis="y", labelcolor=color_palette[0]) + ax1.legend(loc="upper left") + + ax1_twin = ax1.twinx() + ax1_twin.plot( + max_seq_lengths, + memory_reduction, + label="Memory Reduction (%)", + marker="o", + linestyle="--", + color=color_palette[2], + ) + ax1_twin.set_ylabel("Memory Reduction (%)", color=color_palette[2]) + ax1_twin.tick_params(axis="y", labelcolor=color_palette[2]) + ax1_twin.legend(loc="upper right") + + ax1.set_title("Memory Benchmark") + + # Right Plot: Time Benchmark + ax2 = axes[1] + ax2.plot( + max_seq_lengths_filtered, + progen2_inference_times, + label="progen2 Inference Time (s)", + marker="o", + color=color_palette[0], + ) + ax2.plot( + max_seq_lengths_filtered, + fa_progen2_inference_times, + label="FAprogen2 Inference Time (s)", + marker="o", + color=color_palette[1], + ) + ax2.set_xlabel("Sequence Length") + ax2.set_ylabel("Inference Time (s)", 
color=color_palette[0]) + ax2.tick_params(axis="y", labelcolor=color_palette[0]) + ax2.legend(loc="upper left") + + ax2_twin = ax2.twinx() + ax2_twin.plot( + max_seq_lengths_filtered, + time_reduction, + label="Time Reduction (%)", + marker="o", + linestyle="--", + color=color_palette[2], + ) + ax2_twin.set_ylabel("Time Reduction (%)", color=color_palette[2]) + ax2_twin.tick_params(axis="y", labelcolor=color_palette[2]) + ax2_twin.legend(loc="upper right") + + ax2.set_title("Inference Time Benchmark") + + plt.suptitle( + f"Model Version: {model_version}\nBatch Size: 1, Data Type: {dtype}, Averaged over {repeats} runs", + fontsize=20, + ) + plt.tight_layout(rect=[0, 0, 1, 0.95]) + plt.savefig("FAProGen2_benchmark.png", dpi=300) # High resolution + plt.close() + + for seq_length, fa_mem, progen2_mem, fa_time, progen2_time in zip( + max_seq_lengths, + fa_progen2_memory_usage, + progen2_memory_usage, + fa_progen2_inference_times, + progen2_inference_times, + ): + assert ( + fa_mem <= progen2_mem + ), f"Seq {seq_length}: FAprogen2 {fa_mem:.3f} GB > progen2 {progen2_mem:.3f} GB!" + assert ( + fa_time <= progen2_time + ), f"Seq {seq_length}: FAprogen2 {fa_time:.3f} s > progen2 {progen2_time:.3f} s!" diff --git a/tests/progen2.py b/tests/progen2.py new file mode 100644 index 0000000..0b6a679 --- /dev/null +++ b/tests/progen2.py @@ -0,0 +1,754 @@ +# coding=utf-8 +# Copyright 2021 The EleutherAI and HuggingFace Teams. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# Modified forward-pass implementation based on https://github.com/huggingface/transformers/blob/main/src/transformers/models/gptj/modeling_gptj.py + +from typing import Tuple + +import numpy as np + +import torch +import torch.utils.checkpoint +from torch import nn +from torch.nn import CrossEntropyLoss + +from transformers.activations import ACT2FN +from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import logging +from transformers.utils.model_parallel_utils import assert_device_map, get_device_map + + +from transformers.configuration_utils import PretrainedConfig + + +logger = logging.get_logger(__name__) + + +class ProGenConfig(PretrainedConfig): + model_type = "progen" + + def __init__( + self, + vocab_size=50400, + n_positions=2048, + n_ctx=2048, + n_embd=4096, + n_layer=28, + n_head=16, + rotary_dim=64, + n_inner=None, + activation_function="gelu_new", + resid_pdrop=0.0, + embd_pdrop=0.0, + attn_pdrop=0.0, + layer_norm_epsilon=1e-5, + initializer_range=0.02, + scale_attn_weights=True, + gradient_checkpointing=False, + use_cache=True, + bos_token_id=50256, + eos_token_id=50256, + **kwargs + ): + super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs) + + self.vocab_size = vocab_size + self.n_ctx = n_ctx + self.n_positions = n_positions + self.n_embd = n_embd + self.n_layer = n_layer + self.n_head = n_head + self.n_inner = n_inner + self.rotary_dim = rotary_dim + self.activation_function = activation_function + self.resid_pdrop = resid_pdrop + self.embd_pdrop = embd_pdrop + self.attn_pdrop = attn_pdrop + self.layer_norm_epsilon = layer_norm_epsilon + self.initializer_range = initializer_range + self.gradient_checkpointing = gradient_checkpointing + self.scale_attn_weights = scale_attn_weights + self.use_cache = use_cache + + self.bos_token_id = bos_token_id + self.eos_token_id = eos_token_id + + @property + def max_position_embeddings(self): + return self.n_positions + + @property + def hidden_size(self): + return self.n_embd + + @property + def num_attention_heads(self): + return self.n_head + + @property + def num_hidden_layers(self): + return self.n_layer + +def fixed_pos_embedding(x, seq_dim=1, seq_len=None): + dim = x.shape[-1] + if seq_len is None: + seq_len = x.shape[seq_dim] + inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2) / dim)) + sinusoid_inp = torch.einsum("i , j -> i j", torch.arange(seq_len), inv_freq).to(x.device).float() + return torch.sin(sinusoid_inp), torch.cos(sinusoid_inp) + + +def rotate_every_two(x): + x1 = x[:, :, :, ::2] + x2 = x[:, :, :, 1::2] + x = torch.stack((-x2, x1), axis=-1) + return x.flatten(-2) # in einsum notation: rearrange(x, '... d j -> ... 
(d j)') + + +def apply_rotary_pos_emb(x, sincos, offset=0): + sin, cos = map(lambda t: t[None, offset : x.shape[1] + offset, None, :].repeat_interleave(2, 3), sincos) + # einsum notation for lambda t: repeat(t[offset:x.shape[1]+offset,:], "n d -> () n () (d j)", j=2) + return (x * cos) + (rotate_every_two(x) * sin) + + +class ProGenAttention(nn.Module): + def __init__(self, config): + super().__init__() + + max_positions = config.max_position_embeddings + self.register_buffer( + "bias", + torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view( + 1, 1, max_positions, max_positions + ), + ) + self.register_buffer("masked_bias", torch.tensor(-1e9)) + + self.attn_dropout = nn.Dropout(config.attn_pdrop) + self.resid_dropout = nn.Dropout(config.resid_pdrop) + + self.embed_dim = config.hidden_size + self.num_attention_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_attention_heads + if self.head_dim * self.num_attention_heads != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_attention_heads (got `embed_dim`: {self.embed_dim} and `num_attention_heads`: {self.num_attention_heads})." + ) + self.scale_attn = torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float32)).to(torch.get_default_dtype()) + self.qkv_proj = nn.Linear(self.embed_dim, self.embed_dim * 3, bias=False) + + self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False) + self.rotary_dim = None + if config.rotary_dim is not None: + self.rotary_dim = config.rotary_dim + + def _split_heads(self, x, n_head, dim_head, mp_num): + reshaped = x.reshape(x.shape[:-1] + (n_head//mp_num, dim_head)) + reshaped = reshaped.reshape(x.shape[:-2] + (-1, ) + reshaped.shape[-1:]) + return reshaped + + def _merge_heads(self, tensor, num_attention_heads, attn_head_size): + """ + Merges attn_head_size dim and num_attn_heads dim into n_ctx + """ + if len(tensor.shape) == 5: + tensor = tensor.permute(0, 1, 3, 2, 4).contiguous() + elif len(tensor.shape) == 4: + tensor = tensor.permute(0, 2, 1, 3).contiguous() + else: + raise ValueError(f"Input tensor rank should be one of [4, 5], but is: {len(tensor.shape)}") + new_shape = tensor.size()[:-2] + (num_attention_heads * attn_head_size,) + return tensor.view(new_shape) + + def _attn( + self, + query, + key, + value, + attention_mask=None, + head_mask=None, + ): + + # compute causal mask from causal mask buffer + query_length, key_length = query.size(-2), key.size(-2) + causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length] + + # Keep the attention weights computation in fp32 to avoid overflow issues + query = query.to(torch.float32) + key = key.to(torch.float32) + + attn_weights = torch.matmul(query, key.transpose(-1, -2)) + + attn_weights = attn_weights / self.scale_attn + attn_weights = torch.where(causal_mask, attn_weights, self.masked_bias.to(attn_weights.dtype)) + + if attention_mask is not None: + # Apply the attention mask + attn_weights = attn_weights + attention_mask + + attn_weights = nn.Softmax(dim=-1)(attn_weights) + attn_weights = attn_weights.to(value.dtype) + attn_weights = self.attn_dropout(attn_weights) + + # Mask heads if we want to + if head_mask is not None: + attn_weights = attn_weights * head_mask + + attn_output = torch.matmul(attn_weights, value) + + return attn_output, attn_weights + + def forward( + self, + hidden_states, + attention_mask=None, + layer_past=None, + head_mask=None, + use_cache=False, + output_attentions=False, + ): + + qkv = self.qkv_proj(hidden_states) + 
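+        # The fused qkv weight is laid out in mp_num tensor-parallel shards; within each shard
+        # the slices are ordered [query, value, key], hence the reshape and split order below.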
# TODO(enijkamp): factor out number of logical TPU-v3/v4 cores or make forward pass agnostic + # mp_num = 4 + mp_num = 8 + qkv_split = qkv.reshape(qkv.shape[:-1] + (mp_num, -1)) + + local_dim = self.head_dim * self.num_attention_heads // mp_num + query, value, key = torch.split(qkv_split, local_dim, dim=-1) + query = self._split_heads(query, self.num_attention_heads, self.head_dim, mp_num=mp_num) + key = self._split_heads(key, self.num_attention_heads, self.head_dim, mp_num=mp_num) + + value = self._split_heads(value, self.num_attention_heads, self.head_dim, mp_num=mp_num) + value = value.permute(0, 2, 1, 3) + + seq_len = key.shape[1] + offset = 0 + + if layer_past is not None: + offset = layer_past[0].shape[-2] + seq_len += offset + + if self.rotary_dim is not None: + k_rot = key[:, :, :, : self.rotary_dim] + k_pass = key[:, :, :, self.rotary_dim :] + + q_rot = query[:, :, :, : self.rotary_dim] + q_pass = query[:, :, :, self.rotary_dim :] + + sincos = fixed_pos_embedding(k_rot, 1, seq_len=seq_len) + k_rot = apply_rotary_pos_emb(k_rot, sincos, offset=offset) + q_rot = apply_rotary_pos_emb(q_rot, sincos, offset=offset) + + key = torch.cat([k_rot, k_pass], dim=-1) + query = torch.cat([q_rot, q_pass], dim=-1) + else: + sincos = fixed_pos_embedding(key, 1, seq_len=seq_len) + key = apply_rotary_pos_emb(key, sincos, offset=offset) + query = apply_rotary_pos_emb(query, sincos, offset=offset) + + key = key.permute(0, 2, 1, 3) + query = query.permute(0, 2, 1, 3) + + if layer_past is not None: + past_key = layer_past[0] + past_value = layer_past[1] + key = torch.cat((past_key, key), dim=-2) + value = torch.cat((past_value, value), dim=-2) + + if use_cache is True: + present = (key, value) + else: + present = None + + # compute self-attention: V x Softmax(QK^T) + attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask) + + attn_output = self._merge_heads(attn_output, self.num_attention_heads, self.head_dim) + + attn_output = self.out_proj(attn_output) + attn_output = self.resid_dropout(attn_output) + + outputs = (attn_output, present) + if output_attentions: + outputs += (attn_weights,) + + return outputs # a, present, (attentions) + + +class ProGenMLP(nn.Module): + def __init__(self, intermediate_size, config): # in MLP: intermediate_size= 4 * embed_dim + super().__init__() + embed_dim = config.n_embd + + self.fc_in = nn.Linear(embed_dim, intermediate_size) + self.fc_out = nn.Linear(intermediate_size, embed_dim) + + self.act = ACT2FN[config.activation_function] + self.dropout = nn.Dropout(config.resid_pdrop) + + def forward(self, hidden_states): + hidden_states = self.fc_in(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.fc_out(hidden_states) + hidden_states = self.dropout(hidden_states) + return hidden_states + + +class ProGenBlock(nn.Module): + def __init__(self, config): + super().__init__() + inner_dim = config.n_inner if config.n_inner is not None else 4 * config.n_embd + self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon) + self.attn = ProGenAttention(config) + self.mlp = ProGenMLP(inner_dim, config) + + def forward( + self, + hidden_states, + layer_past=None, + attention_mask=None, + head_mask=None, + use_cache=False, + output_attentions=False, + ): + residual = hidden_states + hidden_states = self.ln_1(hidden_states) + attn_outputs = self.attn( + hidden_states, + layer_past=layer_past, + attention_mask=attention_mask, + head_mask=head_mask, + use_cache=use_cache, + output_attentions=output_attentions, + ) + 
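+        # GPT-J-style parallel block: the MLP below consumes the same ln_1 output as attention,
+        # and attention output, MLP output and the residual are summed together.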
attn_output = attn_outputs[0] # output_attn: a, present, (attentions) + outputs = attn_outputs[1:] + + feed_forward_hidden_states = self.mlp(hidden_states) + hidden_states = attn_output + feed_forward_hidden_states + residual + + if use_cache: + outputs = (hidden_states,) + outputs + else: + outputs = (hidden_states,) + outputs[1:] + + return outputs # hidden_states, present, (attentions) + + +class ProGenPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = ProGenConfig + base_model_prefix = "transformer" + is_parallelizable = True + + def __init__(self, *inputs, **kwargs): + super().__init__(*inputs, **kwargs) + + def _init_weights(self, module): + """Initialize the weights.""" + if isinstance(module, (nn.Linear,)): + # Slightly different from Mesh Transformer JAX which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + +class ProGenModel(ProGenPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.embed_dim = config.n_embd + self.vocab_size = config.vocab_size + self.wte = nn.Embedding(config.vocab_size, self.embed_dim) + self.drop = nn.Dropout(config.embd_pdrop) + self.h = nn.ModuleList([ProGenBlock(config) for _ in range(config.n_layer)]) + self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) + self.rotary_dim = min(config.rotary_dim, config.n_ctx // config.num_attention_heads) + self.init_weights() + + # Model parallel + self.model_parallel = False + self.device_map = None + + + def parallelize(self, device_map=None): + # Check validity of device_map + self.device_map = ( + get_device_map(len(self.h), range(torch.cuda.device_count())) if device_map is None else device_map + ) + assert_device_map(self.device_map, len(self.h)) + self.model_parallel = True + self.first_device = "cpu" if "cpu" in self.device_map.keys() else "cuda:" + str(min(self.device_map.keys())) + self.last_device = "cuda:" + str(max(self.device_map.keys())) + self.wte = self.wte.to(self.first_device) + # Load onto devices + for k, v in self.device_map.items(): + for block in v: + cuda_device = "cuda:" + str(k) + self.h[block] = self.h[block].to(cuda_device) + # ln_f to last + self.ln_f = self.ln_f.to(self.last_device) + + + def deparallelize(self): + self.model_parallel = False + self.device_map = None + self.first_device = "cpu" + self.last_device = "cpu" + self.wte = self.wte.to("cpu") + for index in range(len(self.h)): + self.h[index] = self.h[index].to("cpu") + self.ln_f = self.ln_f.to("cpu") + torch.cuda.empty_cache() + + def get_input_embeddings(self): + return self.wte + + def set_input_embeddings(self, new_embeddings): + self.wte = new_embeddings + + def forward( + self, + input_ids=None, + past_key_values=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + output_attentions = output_attentions if 
output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + batch_size = input_ids.shape[0] + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + batch_size = inputs_embeds.shape[0] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + device = input_ids.device if input_ids is not None else inputs_embeds.device + + if token_type_ids is not None: + token_type_ids = token_type_ids.view(-1, input_shape[-1]) + + if position_ids is not None: + position_ids = position_ids.view(-1, input_shape[-1]) + + if past_key_values is None: + past_length = 0 + past_key_values = tuple([None] * len(self.h)) + else: + past_length = past_key_values[0][0].size(-2) + + if position_ids is None: + position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device) + position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1]) + + # Attention mask. + if attention_mask is not None: + assert batch_size > 0, "batch_size has to be defined and > 0" + attention_mask = attention_mask.view(batch_size, -1) + # We create a 3D attention mask from a 2D tensor mask. + # Sizes are [batch_size, 1, 1, to_seq_length] + # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] + # this attention mask is more simple than the triangular masking of causal attention + # used in OpenAI GPT, we just need to prepare the broadcast dimension here. + attention_mask = attention_mask[:, None, None, :] + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and -10000.0 for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. 
+ attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility + attention_mask = (1.0 - attention_mask) * -10000.0 + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x num_attention_heads x N x N + # head_mask has shape n_layer x batch x num_attention_heads x N x N + head_mask = self.get_head_mask(head_mask, self.config.n_layer) + + if inputs_embeds is None: + inputs_embeds = self.wte(input_ids) + + hidden_states = inputs_embeds + + if token_type_ids is not None: + token_type_embeds = self.wte(token_type_ids) + hidden_states = hidden_states + token_type_embeds + + hidden_states = self.drop(hidden_states) + + output_shape = input_shape + (hidden_states.size(-1),) + + presents = () if use_cache else None + all_self_attentions = () if output_attentions else None + all_hidden_states = () if output_hidden_states else None + for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): + + # Model parallel + if self.model_parallel: + torch.cuda.set_device(hidden_states.device) + # Ensure layer_past is on same device as hidden_states (might not be correct) + if layer_past is not None: + layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past) + # Ensure that attention_mask is always on the same device as hidden_states + if attention_mask is not None: + attention_mask = attention_mask.to(hidden_states.device) + if isinstance(head_mask, torch.Tensor): + head_mask = head_mask.to(hidden_states.device) + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if getattr(self.config, "gradient_checkpointing", False) and self.training: + + if use_cache: + logger.warning( + "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting " + "`use_cache=False`..." 
+ ) + use_cache = False + + def create_custom_forward(module): + def custom_forward(*inputs): + # None for past_key_value + return module(*inputs, use_cache, output_attentions) + + return custom_forward + + outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + None, + attention_mask, + head_mask[i], + ) + else: + outputs = block( + hidden_states, + layer_past=layer_past, + attention_mask=attention_mask, + head_mask=head_mask[i], + use_cache=use_cache, + output_attentions=output_attentions, + ) + + hidden_states = outputs[0] + if use_cache is True: + presents = presents + (outputs[1],) + + if output_attentions: + all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) + + # Model Parallel: If it's the last layer for that device, put things on the next device + if self.model_parallel: + for k, v in self.device_map.items(): + if i == v[-1] and "cuda:" + str(k) != self.last_device: + hidden_states = hidden_states.to("cuda:" + str(k + 1)) + + hidden_states = self.ln_f(hidden_states) + + hidden_states = hidden_states.view(*output_shape) + # Add last hidden state + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=presents, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + +class ProGenForCausalLM(ProGenPreTrainedModel): + _keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.masked_bias", r"h\.\d+\.attn\.bias", r"lm_head\.weight"] + + def __init__(self, config): + super().__init__(config) + self.transformer = ProGenModel(config) + self.lm_head = nn.Linear(config.n_embd, config.vocab_size) + self.init_weights() + + # Model parallel + self.model_parallel = False + self.device_map = None + + def parallelize(self, device_map=None): + self.device_map = ( + get_device_map(len(self.transformer.h), range(torch.cuda.device_count())) + if device_map is None + else device_map + ) + assert_device_map(self.device_map, len(self.transformer.h)) + self.transformer.parallelize(self.device_map) + self.lm_head = self.lm_head.to(self.transformer.first_device) + self.model_parallel = True + + def deparallelize(self): + self.transformer.deparallelize() + self.transformer = self.transformer.to("cpu") + self.lm_head = self.lm_head.to("cpu") + self.model_parallel = False + torch.cuda.empty_cache() + + def get_output_embeddings(self): + return None + + def set_output_embeddings(self, new_embeddings): + return + + def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs): + token_type_ids = kwargs.get("token_type_ids", None) + # only last token for inputs_ids if past is defined in kwargs + if past: + input_ids = input_ids[:, -1].unsqueeze(-1) + if token_type_ids is not None: + token_type_ids = token_type_ids[:, -1].unsqueeze(-1) + + attention_mask = kwargs.get("attention_mask", None) + position_ids = kwargs.get("position_ids", None) + + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past: + position_ids = position_ids[:, -1].unsqueeze(-1) + else: + position_ids = None + return { + "input_ids": input_ids, + "past_key_values": past, + "use_cache": kwargs.get("use_cache"), + "position_ids": 
position_ids, + "attention_mask": attention_mask, + "token_type_ids": token_type_ids, + } + + def forward( + self, + input_ids=None, + past_key_values=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + labels=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set + ``labels = input_ids`` Indices are selected in ``[-100, 0, ..., config.vocab_size]`` All labels set to + ``-100`` are ignored (masked), the loss is only computed for labels in ``[0, ..., config.vocab_size]`` + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.transformer( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + + # Set device for model parallelism + if self.model_parallel: + torch.cuda.set_device(self.transformer.first_device) + hidden_states = hidden_states.to(self.lm_head.weight.device) + + # make sure sampling in fp16 works correctly and + # compute loss in fp32 to match with mesh-tf version + # https://github.com/EleutherAI/gpt-neo/blob/89ce74164da2fb16179106f54e2269b5da8db333/models/gpt2/gpt2.py#L179 + lm_logits = self.lm_head(hidden_states).to(torch.float32) + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = lm_logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) + + loss = loss.to(hidden_states.dtype) + + if not return_dict: + output = (lm_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=lm_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + @staticmethod + def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) -> Tuple[Tuple[torch.Tensor]]: + """ + This function is used to re-order the :obj:`past_key_values` cache if + :meth:`~transformers.PretrainedModel.beam_search` or :meth:`~transformers.PretrainedModel.beam_sample` is + called. This is required to match :obj:`past_key_values` with the correct beam_idx at every generation step. + """ + return tuple( + tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past) + for layer_past in past + ) \ No newline at end of file
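The eager reference above (tests/progen2.py) and the flash-attention port in faesm/progen2.py can be sanity-checked against each other with a quick logits comparison. A minimal sketch, assuming a CUDA GPU with flash-attn installed; it reuses the checkpoint name from tests/benchmark_faprogen2.py, and the fp16 port is not expected to match the fp32 reference bit-for-bit:

import torch
from transformers import AutoTokenizer

from faesm.progen2 import ProGenForCausalLM as FAProGenForCausalLM
from tests.progen2 import ProGenForCausalLM

model_id = "jinyuan22/ProGen2-xlarge"  # checkpoint used in tests/benchmark_faprogen2.py
device = "cuda"                        # flash attention requires a CUDA device

tokenizer = AutoTokenizer.from_pretrained(model_id)
sequence = "MTYKLILNGKTLKGETTTEAVDAATAEKVFKQYANDNGVDGEWTYDDATKTFTVTE"
input_ids = tokenizer(sequence, return_tensors="pt").input_ids.to(device)

# fp32 eager reference vs. fp16 flash-attention port
ref = ProGenForCausalLM.from_pretrained(model_id).to(device).eval()
fa = FAProGenForCausalLM.from_pretrained(model_id).to(torch.float16).to(device).eval()

with torch.no_grad():
    ref_logits = ref(input_ids).logits
    fa_logits = fa(input_ids).logits

# report the largest absolute deviation; small fp16 differences are expected
print((ref_logits.float() - fa_logits.float()).abs().max())

A smaller ProGen2 checkpoint can be substituted for the xlarge one to keep the fp32 reference within a single GPU's memory.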