From 0129ab0292799c0d614a104d06288664dd02d2c7 Mon Sep 17 00:00:00 2001
From: Rick Ho
Date: Fri, 1 Apr 2022 11:44:48 +0800
Subject: [PATCH 1/6] update readme

---
 README.md | 33 +++++++++++++++++++++++++++++++++
 setup.py  |  3 ++-
 2 files changed, 35 insertions(+), 1 deletion(-)

diff --git a/README.md b/README.md
index ab849d22..bb8823b9 100644
--- a/README.md
+++ b/README.md
@@ -99,8 +99,20 @@ FastMoE's model parallel requires sophisticated parallel strategies that neither P
 Megatron-LM provides. The `fmoe.DistributedGroupedDataParallel` module is
 introduced to replace PyTorch's DDP module.

+#### Faster Performance Features
+
+From the PPoPP'22 paper, _FasterMoE: modeling and optimizing training of
+large-scale dynamic pre-trained models_, we have adopted techniques to make
+FastMoE's model-parallel training much more efficient.
+
+These optimizations are named **Faster Performance Features**, and can be
+enabled via several environment variables. Their usage and constraints are
+detailed in [a separate document](doc/fastermoe).
+
 ## Citation

+For the core FastMoE system:
+
 ```
 @article{he2021fastmoe,
   title={FastMoE: A Fast Mixture-of-Expert Training System},
@@ -110,6 +122,27 @@ introduced to replace PyTorch's DDP module.
 }
 ```

+For the [faster performance features](doc/fastermoe):
+
+```
+@inproceedings{he2022fastermoe,
+  author = {He, Jiaao and Zhai, Jidong and Antunes, Tiago and Wang, Haojie and Luo, Fuwen and Shi, Shangfeng and Li, Qin},
+  title = {FasterMoE: Modeling and Optimizing Training of Large-Scale Dynamic Pre-Trained Models},
+  year = {2022},
+  isbn = {9781450392044},
+  publisher = {Association for Computing Machinery},
+  address = {New York, NY, USA},
+  url = {https://doi.org/10.1145/3503221.3508418},
+  doi = {10.1145/3503221.3508418},
+  booktitle = {Proceedings of the 27th ACM SIGPLAN Symposium on Principles and Practice of Parallel Programming},
+  pages = {120–134},
+  numpages = {15},
+  keywords = {parallelism, distributed deep learning, performance modeling},
+  location = {Seoul, Republic of Korea},
+  series = {PPoPP '22}
+}
+```
+
 ## Troubleshootings / Discussion

 If you have any problem using FastMoE, or you are interested in getting involved in developing FastMoE, feel free to join [our slack channel](https://join.slack.com/t/fastmoe/shared_invite/zt-mz0ai6ol-ggov75D62YsgHfzShw8KYw).
diff --git a/setup.py b/setup.py
index e421c4ed..a380e28c 100644
--- a/setup.py
+++ b/setup.py
@@ -13,6 +13,7 @@
     'Tiago Antunes',
     'Jinjun Peng',
     'Qin Li',
+    'Mingshu Zhai'
 ]

 is_rocm_pytorch = False
@@ -37,7 +38,7 @@
 if __name__ == '__main__':
     setuptools.setup(
         name='fastmoe',
-        version='0.3.0',
+        version='1.0.0',
         description='An efficient Mixture-of-Experts system for PyTorch',
         author=', '.join(authors),
         author_email='hja20@mails.tsinghua.edu.cn',

From 7e6459471e7c28d67ab51ed7989a73e2167e6a37 Mon Sep 17 00:00:00 2001
From: Rick Ho
Date: Fri, 1 Apr 2022 16:45:26 +0800
Subject: [PATCH 2/6] readme for fastermoe

---
 doc/fastermoe/README.md    | 98 ++++++++++++++++++++++++++++++++++++++++++++++
 doc/fastermoe/smartsch.png | Bin 0 -> 9862 bytes
 2 files changed, 98 insertions(+)
 create mode 100644 doc/fastermoe/README.md
 create mode 100644 doc/fastermoe/smartsch.png

diff --git a/doc/fastermoe/README.md b/doc/fastermoe/README.md
new file mode 100644
index 00000000..5fdbf3a9
--- /dev/null
+++ b/doc/fastermoe/README.md
@@ -0,0 +1,98 @@
+Boost Performance with FasterMoE
+===
+
+A Chinese version of this document is available in [this blog post](https://laekov.com.cn/view/181401).
+
+There are three main optimizations in the PPoPP'22 paper _FasterMoE: Modeling
+and Optimizing Training of Large-scale Dynamic Pre-trained Models_. Thanks to
+the contributions of the authors of the paper, their optimizations are now
+integrated into FastMoE, and can be enabled via environment-variable switches.
+These optimizations can greatly increase the training efficiency of FastMoE.
+Example configurations are given at the end of this document.
+
+## Smart Scheduling
+
+Recall that in an MoE layer, two `all-to-all`s are performed, with the
+experts' computation in between. In FasterMoE, the `all-to-all`s are broken
+down using a _group-wise exchange_ algorithm, so an expert can start its
+computation as soon as a part of its input, e.g. the tokens from one other
+worker, is ready.
+
+Its effectiveness is illustrated in the following timeline, where `S` and `R`
+stand for the send and receive parts of the `all-to-all`s, and `C` stands for
+the computation of the expert.
+
+![](smartsch.png)
+
+To enable smart scheduling in FastMoE, set the environment variable
+`FMOE_FASTER_SCHEDULE_ENABLE` to `1` or `ON`. It is currently off by default.
+
+Please note that there are a few constraints on smart scheduling in the
+current version of FastMoE. `num_expert` has to be `1`, which means only one
+expert can reside on each worker, and the input and output features of the
+experts have to be of the same length. This is because the developers of
+FasterMoE only implemented these cases in their prototype, and they are
+looking for the community's efforts to support the other cases.
+
+To fine-tune the performance of smart scheduling, the environment variable
+`FMOE_FASTER_GROUP_SIZE` sets the size of the worker groups in the
+_Group-wise Exchange_ algorithm, i.e. the granularity of the schedule. It
+should be set to a value that balances pipeline bubbles against inefficiently
+small computation granularity.
+
+## Expert Shadowing
+
+According to observations from training real models, when no limitation is
+placed on expert selection, it follows a skewed distribution, which means
+that a few experts are much more popular than the others. This causes a
+significant load-imbalance issue when using FastMoE's model parallel mode.
+
+The solution proposed by the authors of FasterMoE is to broadcast the
+parameters of the hot experts to all workers, as so-called shadows. With the
+shadows, the computation of the hot experts can be performed locally on all
+workers, avoiding the bottleneck of sending so much workload to the few
+workers that hold the hot experts. Besides, a performance predictor, together
+with a shadow selection algorithm, is used to determine which experts to
+shadow before each iteration.
+
+In FastMoE, this feature is enabled by the environment variable
+`FMOE_FASTER_SHADOW_ENABLE`. For simplicity, it is only available when smart
+scheduling is enabled. Besides the constraints of smart scheduling, this
+feature requires the experts to be identical in structure, so that parameters
+can be copied between experts.
+
+A default shadow selection policy is located at
+`fmoe/fastermoe/shadow_policy.py`. If you want to alter the policy, please
+modify that file and re-install FastMoE. The default policy assumes that the
+experts are two-layer MLPs. A few parameters of the policy can be specified
+by the following environment variables for better effectiveness of the
+shadowing mechanism.
+
+* `FMOE_FASTER_GLBPLC_NETBW` is the bandwidth of the interconnection between
+  workers, measured in GBps.
+* `FMOE_FASTER_GLBPLC_GPUTP` is the GeMM throughput of the GPUs, measured in
+  FLOPs, e.g. `13e12` for NVIDIA V100 PCIe GPUs using fp32.
+* `FMOE_FASTER_GLBPLC_ALPHA` is the ratio of the hidden feature length in the
+  middle of the MLP to the input and output feature length, commonly `2` or
+  `4` in transformers.
+* `FMOE_FASTER_GLBPLC_DMODEL` is the feature length of the input and output
+  of the experts. This parameter can be set automatically by FastMoE.
+
+## Topology-aware Gate
+
+The two optimizations above do not change the behavior of the model, while
+this one does. To reduce network congestion when training in a distributed
+system with a hierarchical network topology, e.g. many GPUs in each of many
+nodes, the number of samples transmitted through the slower upper-level
+network is limited. The tokens that exceed this limit select experts within
+the same lower-level network instead, which reduces the communication
+overhead.
+
+An example topology-aware gate is implemented as `FasterGate` among FastMoE's
+gates. However, note that it may influence the accuracy of the model, and
+different topology-aware gates should be designed for different training
+hardware, according to the specific case.
+
+The environment variable `FMOE_TOPO_GPUS_PER_NODE` represents the number of
+GPUs in each local network, e.g. each node, and `FMOE_TOPO_OUTGOING_FRACTION`
+controls the fraction of tokens that are allowed to be sent across the
+upper-level network.
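+
+## Examples
+
+The sketches below are illustrative only: the layer type, its constructor
+arguments, and all numeric values are placeholders rather than recommended
+settings, and should be adapted to the actual model and hardware.
+
+To enable smart scheduling, set its switches before the MoE layers are used,
+either in the launcher script or at the top of the training script:
+
+```python
+import os
+
+os.environ["FMOE_FASTER_SCHEDULE_ENABLE"] = "1"  # smart scheduling on (off by default)
+os.environ["FMOE_FASTER_GROUP_SIZE"] = "4"       # worker-group size of the group-wise exchange
+
+from fmoe.transformer import FMoETransformerMLP  # an illustrative MoE layer
+
+# Smart scheduling currently requires one expert per worker, with input and
+# output features of the same length.
+moe_layer = FMoETransformerMLP(
+    num_expert=1,     # required by smart scheduling
+    d_model=1024,     # input / output feature length
+    d_hidden=4096,    # hidden size of the expert MLP
+    world_size=8,     # number of model-parallel workers
+)
+```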
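+
+The shadow-selection policy is tuned through environment variables as well.
+The numbers below are assumptions for a hypothetical cluster and should be
+replaced with values measured on the actual hardware:
+
+```python
+import os
+
+# Expert shadowing builds on smart scheduling, so both switches are needed.
+os.environ["FMOE_FASTER_SCHEDULE_ENABLE"] = "1"
+os.environ["FMOE_FASTER_SHADOW_ENABLE"] = "1"
+
+# Inputs to the default policy in fmoe/fastermoe/shadow_policy.py.
+os.environ["FMOE_FASTER_GLBPLC_NETBW"] = "12.5"   # ~100 Gbps interconnect, in GBps
+os.environ["FMOE_FASTER_GLBPLC_GPUTP"] = "13e12"  # fp32 GeMM throughput of a V100 PCIe, in FLOPs
+os.environ["FMOE_FASTER_GLBPLC_ALPHA"] = "4"      # d_hidden / d_model ratio of the expert MLP
+os.environ["FMOE_FASTER_GLBPLC_DMODEL"] = "1024"  # optional; FastMoE can set this automatically
+```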
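+
+The topology-aware gate is configured in the same way. The values below
+assume nodes with 8 GPUs each, and the outgoing fraction is an arbitrary
+example rather than a recommendation:
+
+```python
+import os
+
+os.environ["FMOE_TOPO_GPUS_PER_NODE"] = "8"         # GPUs per local (lower-level) network
+os.environ["FMOE_TOPO_OUTGOING_FRACTION"] = "0.14"  # at most ~14% of tokens cross nodes
+
+# The MoE layer is then built with FastMoE's `FasterGate` (or a custom
+# topology-aware gate) as its gate; see fmoe/gates for the exact interface.
+```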
diff --git a/doc/fastermoe/smartsch.png b/doc/fastermoe/smartsch.png
new file mode 100644
index 0000000000000000000000000000000000000000..3cd8ba6800060993d239a791316f42ce5d061975
From baa47d466c54c68cba394e7e80cd848dfe42abc8 Mon Sep 17 00:00:00 2001
From: Rick Ho
Date: Fri, 1 Apr 2022 16:45:36 +0800
Subject: [PATCH 3/6] update chinese readme

---
 doc/readme-cn.md | 9 +++++++++
 1 file changed, 9 insertions(+)

diff --git a/doc/readme-cn.md b/doc/readme-cn.md
index 76f2173d..df00e711 100644
--- a/doc/readme-cn.md
+++ b/doc/readme-cn.md
@@ -95,6 +95,15 @@ FastMoE 的模型并行模式需要专门的并行策略, 而 PyTorch 和 Megatr
 都不支持这样的策略. 因此, 需要使用 `fmoe.DistributedGroupedDataParallel`
 模块来代替 PyTorch 的 DDP 模块.

+### 如何训练得更快
+
+在 PPoPP'22 会议上有一篇论文: _FasterMoE: modeling and optimizing training of
+large-scale dynamic pre-trained models_. 我们将文中的技术集成到了 FastMoE 系统中,
+从而提升其模型并行的效率.
+
+这些新特性被命名为 **Faster Performance Features**, 并通过一些环境变量来控制是否
+启用它们. 详见[这篇单独的文档](fastermoe).
+
 ## 答疑 / 讨论

 如果您在使用 FastMoE 的过程中有任何疑问, 或您有兴趣参与 FastMoE 的相关工作,

From 6557e88e90b91ed9ffbd210ed03724dbe0490d9b Mon Sep 17 00:00:00 2001
From: Rick Ho
Date: Fri, 1 Apr 2022 17:34:08 +0800
Subject: [PATCH 4/6] update chinese url

---
 doc/fastermoe/README.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/doc/fastermoe/README.md b/doc/fastermoe/README.md
index 5fdbf3a9..85d08f2d 100644
--- a/doc/fastermoe/README.md
+++ b/doc/fastermoe/README.md
@@ -1,7 +1,7 @@
 Boost Performance with FasterMoE
 ===

-A Chinese version of this document is available in [this blog post](https://laekov.com.cn/view/181401).
+A Chinese version of this document is available in [this blog post](https://laekov.com.cn/view/181401#howto).

 There are three main optimizations in the PPoPP'22 paper _FasterMoE: Modeling
 and Optimizing Training of Large-scale Dynamic Pre-trained Models_. Thanks to

From 1b8fef31fab317b75321174385e19b84278d1c81 Mon Sep 17 00:00:00 2001
From: Rick Ho
Date: Sat, 2 Apr 2022 09:56:45 +0800
Subject: [PATCH 5/6] some comments on code and test

---
 fmoe/fastermoe/schedule.py | 2 ++
 tests/README.md            | 7 +++++++
 2 files changed, 9 insertions(+)
 create mode 100644 tests/README.md

diff --git a/fmoe/fastermoe/schedule.py b/fmoe/fastermoe/schedule.py
index 3a5fc31f..14290702 100644
--- a/fmoe/fastermoe/schedule.py
+++ b/fmoe/fastermoe/schedule.py
@@ -61,6 +61,8 @@ def stash_fn(params, idx):
     out = _local_gather(local_output_buf, pos_g, out_batch_size,
             maybe_overlap=False)

+    # gib and local_input_buf are necessary because ctx.gibs are created
+    # based on their memory
     variables = (pos_s, pos_g, local_expert_count, global_expert_count,
             stored_models, gib, local_input_buf)

diff --git a/tests/README.md b/tests/README.md
new file mode 100644
index 00000000..0a6f2ad8
--- /dev/null
+++ b/tests/README.md
@@ -0,0 +1,7 @@
+FastMoE tests
+===
+
+To run the unit tests, simply run `pytest` in this directory.
+
+`test.sh` is a wrapper script to execute single tests without pytest, for
+debugging purposes.

From 33895a052c40b97ba39bc014d0367b10d2724943 Mon Sep 17 00:00:00 2001
From: Rick Ho
Date: Sat, 2 Apr 2022 10:15:33 +0800
Subject: [PATCH 6/6] release note v1.0.0

---
 doc/release-note.md | 16 ++++++++++++++++
 1 file changed, 16 insertions(+)

diff --git a/doc/release-note.md b/doc/release-note.md
index ba016b3a..5cf877d5 100644
--- a/doc/release-note.md
+++ b/doc/release-note.md
@@ -1,3 +1,19 @@
+## v1.0.0
+
+### FasterMoE
+
+* The new performance-boosting features from the PPoPP'22 paper _FasterMoE_, detailed in the `doc/fastermoe` document.
+  * Expert Shadowing.
+  * Smart Scheduling.
+  * Topology-aware Gate.
+
+### Bug fixes
+
+* Transformer-XL examples.
+* Compatibility with PyTorch versions.
+* Megatron-LM documents.
+* GShardGate.
+
 ## v0.3.0

 ### FMoE core
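
As a closing illustration of how the v1.0.0 features fit together, the sketch below enables the FasterMoE optimizations, builds a model-parallel MoE layer, and wraps it in `fmoe.DistributedGroupedDataParallel`. It is a sketch under assumptions: the layer type, its constructor arguments, and the launch details are illustrative, and the wrapper's optional arguments are omitted.

```python
import os

import torch
import torch.distributed as dist

# Enable the FasterMoE optimizations; see doc/fastermoe for their constraints.
os.environ["FMOE_FASTER_SCHEDULE_ENABLE"] = "1"
os.environ["FMOE_FASTER_SHADOW_ENABLE"] = "1"

from fmoe import DistributedGroupedDataParallel
from fmoe.transformer import FMoETransformerMLP  # illustrative expert layer

# One process per GPU, e.g. launched with torchrun.
dist.init_process_group(backend="nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

# One expert per worker, as required by smart scheduling.
model = FMoETransformerMLP(
    num_expert=1,
    d_model=1024,
    d_hidden=4096,
    world_size=dist.get_world_size(),
).cuda()

# Use FastMoE's replacement for PyTorch's DistributedDataParallel so that
# expert and non-expert parameters are synchronized correctly during training.
model = DistributedGroupedDataParallel(model)
```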