diff --git a/compiler/sam-kernels.sh b/compiler/sam-kernels.sh index 5cc75fe0..6caf0a21 100755 --- a/compiler/sam-kernels.sh +++ b/compiler/sam-kernels.sh @@ -47,6 +47,8 @@ GEN_KERNEL_NAMES=( mat_vecmul_iter mat_vecmul_iter_short tensor3_website_expr + mat_mask_tri_DA3sum_final + mat_mask_tri_DA3_final ) HAND_KERNEL_NAMES=( @@ -95,6 +97,8 @@ TACO_ARGS=( "x(i)=B(i,j)*C(j,k)*D(k,l)*E(l,m)*f(m) -f=x:s -f=B:ss -f=C:ss -f=D:ss -f=E:ss -f=f:s -s=reorder(i,j,k,l,m)" "x(i)=B(i,j)*C(j,k)*d(k) -f=x:s -f=B:ss -f=C:ss -f=d:s -s=reorder(i,j,k)" "x=B(i)*C(j)*D(i,j,k)*E(j,l)*F(l,m,n) -f=B:s -f=C:s -f=D:sss -f=E:ss -f=F:sss -s=reorder(i,j,k,l,m,n)" + "x=Diag(i,l)*(B1(i,j)*B2(j,k)*B3(k,l)) -f=Diag:ss -f=B1:ss -f=B2:ss -f=B3:ss:1,0 -s=reorder(i,l,j,k)" + "X(i,l)=Diag(i,l)*(B1(i,j)*B2(j,k)*B3(k,l)) -f=X:ss -f=Diag:ss -f=B1:ss -f=B2:ss -f=B3:ss:1,0 -s=reorder(i,l,j,k)" ) mkdir -p $dir diff --git a/compiler/sam-outputs/dot/mat_mask_tri_DA3_final.gv b/compiler/sam-outputs/dot/mat_mask_tri_DA3_final.gv new file mode 100644 index 00000000..1bac63f1 --- /dev/null +++ b/compiler/sam-outputs/dot/mat_mask_tri_DA3_final.gv @@ -0,0 +1,119 @@ +digraph SAM { + comment="X=ss01,Diag=ss01,B1=ss01,B2=ss01,B3=ss10" + 47 [comment="type=fiberlookup,index=i,tensor=Diag,mode=0,format=compressed,src=true,root=true" label="FiberLookup i: Diag0\ncompressed" color=green4 shape=box style=filled type="fiberlookup" index="i" tensor="Diag" mode="0" format="compressed" src="true" root="true"] + 46 [comment="type=intersect,index=i" label="intersect i" color=purple shape=box style=filled type="intersect" index="i"] + 45 [comment="type=broadcast" shape=point style=invis type="broadcast"] + 14 [comment="type=crddrop,outer=i,inner=l" label="CrdDrop i,l" color=orange shape=box style=filled type="crddrop" outer="i" inner="l"] + 2 [comment="type=fiberwrite,index=i,tensor=X,mode=0,format=compressed,segsize=2,crdsize=Diag0_dim,sink=true" label="FiberWrite i: X0\ncompressed" color=green3 shape=box style=filled type="fiberwrite" index="i" tensor="X" mode="0" format="compressed" segsize="2" crdsize="Diag0_dim" sink="true"] + 1 [comment="type=fiberwrite,index=l,tensor=X,mode=1,format=compressed,segsize=Diag0_dim+1,crdsize=Diag0_dim*Diag1_dim,sink=true" label="FiberWrite l: X1\ncompressed" color=green3 shape=box style=filled type="fiberwrite" index="l" tensor="X" mode="1" format="compressed" segsize="Diag0_dim+1" crdsize="Diag0_dim*Diag1_dim" sink="true"] + 44 [comment="type=repsiggen,index=i" label="RepeatSignalGenerator i" color=cyan3 shape=box style=filled type="repsiggen" index="i"] + 43 [comment="type=broadcast" shape=point style=invis type="broadcast"] + 41 [comment="type=repeat,index=i,tensor=B2,root=true" label="Repeat i: B2" color=cyan2 shape=box style=filled type="repeat" index="i" tensor="B2" root="true"] + 34 [comment="type=repeat,index=l,tensor=B2,root=false" label="Repeat l: B2" color=cyan2 shape=box style=filled type="repeat" index="l" tensor="B2" root="false"] + 32 [comment="type=fiberlookup,index=j,tensor=B2,mode=0,format=compressed,src=true,root=false" label="FiberLookup j: B20\ncompressed" color=green4 shape=box style=filled type="fiberlookup" index="j" tensor="B2" mode="0" format="compressed" src="true" root="false"] + 30 [comment="type=intersect,index=j" label="intersect j" color=purple shape=box style=filled type="intersect" index="j"] + 29 [comment="type=broadcast" shape=point style=invis type="broadcast"] + 16 [comment="type=crddrop,outer=j,inner=k" label="CrdDrop j,k" color=orange shape=box style=filled type="crddrop" outer="j" inner="k"] + 15 [comment="type=crddrop,outer=l,inner=j" label="CrdDrop l,j" color=orange shape=box style=filled type="crddrop" outer="l" inner="j"] + 28 [comment="type=repsiggen,index=j" label="RepeatSignalGenerator j" color=cyan3 shape=box style=filled type="repsiggen" index="j"] + 27 [comment="type=broadcast" shape=point style=invis type="broadcast"] + 25 [comment="type=repeat,index=j,tensor=Diag,root=false" label="Repeat j: Diag" color=cyan2 shape=box style=filled type="repeat" index="j" tensor="Diag" root="false"] + 17 [comment="type=repeat,index=k,tensor=Diag,root=false" label="Repeat k: Diag" color=cyan2 shape=box style=filled type="repeat" index="k" tensor="Diag" root="false"] + 8 [comment="type=arrayvals,tensor=Diag" label="Array Vals: Diag" color=green2 shape=box style=filled type="arrayvals" tensor="Diag"] + 7 [comment="type=mul" label="Mul" color=brown shape=box style=filled type="mul"] + 6 [comment="type=reduce" label="Reduce" color=brown shape=box style=filled type="reduce"] + 5 [comment="type=reduce" label="Reduce" color=brown shape=box style=filled type="reduce"] + 0 [comment="type=fiberwrite,mode=vals,tensor=X,size=1*Diag0_dim*Diag1_dim,sink=true" label="FiberWrite Vals: X" color=green3 shape=box style=filled type="fiberwrite" tensor="X" mode="vals" size="1*Diag0_dim*Diag1_dim" sink="true"] + 26 [comment="type=repeat,index=j,tensor=B3,root=false" label="Repeat j: B3" color=cyan2 shape=box style=filled type="repeat" index="j" tensor="B3" root="false"] + 24 [comment="type=fiberlookup,index=k,tensor=B3,mode=0,format=compressed,src=true,root=false" label="FiberLookup k: B30\ncompressed" color=green4 shape=box style=filled type="fiberlookup" index="k" tensor="B3" mode="0" format="compressed" src="true" root="false"] + 22 [comment="type=intersect,index=k" label="intersect k" color=purple shape=box style=filled type="intersect" index="k"] + 21 [comment="type=broadcast" shape=point style=invis type="broadcast"] + 20 [comment="type=repsiggen,index=k" label="RepeatSignalGenerator k" color=cyan3 shape=box style=filled type="repsiggen" index="k"] + 19 [comment="type=broadcast" shape=point style=invis type="broadcast"] + 18 [comment="type=repeat,index=k,tensor=B1,root=false" label="Repeat k: B1" color=cyan2 shape=box style=filled type="repeat" index="k" tensor="B1" root="false"] + 11 [comment="type=arrayvals,tensor=B1" label="Array Vals: B1" color=green2 shape=box style=filled type="arrayvals" tensor="B1"] + 10 [comment="type=mul" label="Mul" color=brown shape=box style=filled type="mul"] + 9 [comment="type=mul" label="Mul" color=brown shape=box style=filled type="mul"] + 12 [comment="type=arrayvals,tensor=B2" label="Array Vals: B2" color=green2 shape=box style=filled type="arrayvals" tensor="B2"] + 13 [comment="type=arrayvals,tensor=B3" label="Array Vals: B3" color=green2 shape=box style=filled type="arrayvals" tensor="B3"] + 23 [comment="type=fiberlookup,index=k,tensor=B2,mode=1,format=compressed,src=true,root=false" label="FiberLookup k: B21\ncompressed" color=green4 shape=box style=filled type="fiberlookup" index="k" tensor="B2" mode="1" format="compressed" src="true" root="false"] + 42 [comment="type=repeat,index=i,tensor=B3,root=true" label="Repeat i: B3" color=cyan2 shape=box style=filled type="repeat" index="i" tensor="B3" root="true"] + 40 [comment="type=fiberlookup,index=l,tensor=B3,mode=1,format=compressed,src=true,root=false" label="FiberLookup l: B31\ncompressed" color=green4 shape=box style=filled type="fiberlookup" index="l" tensor="B3" mode="1" format="compressed" src="true" root="false"] + 38 [comment="type=intersect,index=l" label="intersect l" color=purple shape=box style=filled type="intersect" index="l"] + 37 [comment="type=broadcast" shape=point style=invis type="broadcast"] + 36 [comment="type=repsiggen,index=l" label="RepeatSignalGenerator l" color=cyan3 shape=box style=filled type="repsiggen" index="l"] + 35 [comment="type=broadcast" shape=point style=invis type="broadcast"] + 33 [comment="type=repeat,index=l,tensor=B1,root=false" label="Repeat l: B1" color=cyan2 shape=box style=filled type="repeat" index="l" tensor="B1" root="false"] + 31 [comment="type=fiberlookup,index=j,tensor=B1,mode=1,format=compressed,src=true,root=false" label="FiberLookup j: B11\ncompressed" color=green4 shape=box style=filled type="fiberlookup" index="j" tensor="B1" mode="1" format="compressed" src="true" root="false"] + 39 [comment="type=fiberlookup,index=l,tensor=Diag,mode=1,format=compressed,src=true,root=false" label="FiberLookup l: Diag1\ncompressed" color=green4 shape=box style=filled type="fiberlookup" index="l" tensor="Diag" mode="1" format="compressed" src="true" root="false"] + 48 [comment="type=fiberlookup,index=i,tensor=B1,mode=0,format=compressed,src=true,root=true" label="FiberLookup i: B10\ncompressed" color=green4 shape=box style=filled type="fiberlookup" index="i" tensor="B1" mode="0" format="compressed" src="true" root="true"] + 47 -> 46 [label="crd_in-Diag" style=dashed type="crd" comment="in-Diag"] + 46 -> 45 [label="crd_i" style=dashed type="crd" comment="i"] + 45 -> 14 [label="crd_i" style=dashed type="crd" comment="i"] + 14 -> 2 [label="crd_outer-i" style=dashed type="crd" comment="outer-i"] + 14 -> 1 [label="crd_inner-l" style=dashed type="crd" comment="inner-l"] + 45 -> 44 [label="crd" style=dashed type="crd" comment=""] + 44 -> 43 [label="repsig" style=dotted type="repsig"] + 43 -> 41 [label="repsig" style=dotted type="repsig"] + 41 -> 34 [label="ref" style=bold type="ref"] + 34 -> 32 [label="ref" style=bold type="ref"] + 32 -> 30 [label="crd_in-B2" style=dashed type="crd" comment="in-B2"] + 30 -> 29 [label="crd_j" style=dashed type="crd" comment="j"] + 29 -> 16 [label="crd_j" style=dashed type="crd" comment="j"] + 16 -> 15 [label="crd_outer-j" style=dashed type="crd" comment="outer-j"] + 15 -> 14 [label="crd_outer-l" style=dashed type="crd" comment="outer-l"] + 29 -> 28 [label="crd" style=dashed type="crd" comment=""] + 28 -> 27 [label="repsig" style=dotted type="repsig"] + 27 -> 25 [label="repsig" style=dotted type="repsig"] + 25 -> 17 [label="ref" style=bold type="ref"] + 17 -> 8 [label="ref" style=bold type="ref"] + 8 -> 7 [label="val" type="val"] + 7 -> 6 [label="val" type="val"] + 6 -> 5 [label="val" type="val"] + 5 -> 0 [label="val" type="val"] + 27 -> 26 [label="repsig" style=dotted type="repsig"] + 26 -> 24 [label="ref" style=bold type="ref"] + 24 -> 22 [label="crd_in-B3" style=dashed type="crd" comment="in-B3"] + 22 -> 21 [label="crd_k" style=dashed type="crd" comment="k"] + 21 -> 16 [label="crd_k" style=dashed type="crd" comment="k"] + 21 -> 20 [label="crd" style=dashed type="crd" comment=""] + 20 -> 19 [label="repsig" style=dotted type="repsig"] + 19 -> 17 [label="repsig" style=dotted type="repsig"] + 19 -> 18 [label="repsig" style=dotted type="repsig"] + 18 -> 11 [label="ref" style=bold type="ref"] + 11 -> 10 [label="val" type="val"] + 10 -> 9 [label="val" type="val"] + 9 -> 7 [label="val" type="val"] + 22 -> 12 [label="ref_out-B2" style=bold type="ref" comment="out-B2"] + 12 -> 10 [label="val" type="val"] + 22 -> 13 [label="ref_out-B3" style=bold type="ref" comment="out-B3"] + 13 -> 9 [label="val" type="val"] + 24 -> 22 [label="ref_in-B3" style=bold type="ref" comment="in-B3"] + 30 -> 18 [label="ref_out-B1" style=bold type="ref" comment="out-B1"] + 30 -> 23 [label="ref_out-B2" style=bold type="ref" comment="out-B2"] + 23 -> 22 [label="crd_in-B2" style=dashed type="crd" comment="in-B2"] + 23 -> 22 [label="ref_in-B2" style=bold type="ref" comment="in-B2"] + 32 -> 30 [label="ref_in-B2" style=bold type="ref" comment="in-B2"] + 43 -> 42 [label="repsig" style=dotted type="repsig"] + 42 -> 40 [label="ref" style=bold type="ref"] + 40 -> 38 [label="crd_in-B3" style=dashed type="crd" comment="in-B3"] + 38 -> 37 [label="crd_l" style=dashed type="crd" comment="l"] + 37 -> 15 [label="crd_l" style=dashed type="crd" comment="l"] + 37 -> 36 [label="crd" style=dashed type="crd" comment=""] + 36 -> 35 [label="repsig" style=dotted type="repsig"] + 35 -> 33 [label="repsig" style=dotted type="repsig"] + 33 -> 31 [label="ref" style=bold type="ref"] + 31 -> 30 [label="crd_in-B1" style=dashed type="crd" comment="in-B1"] + 31 -> 30 [label="ref_in-B1" style=bold type="ref" comment="in-B1"] + 35 -> 34 [label="repsig" style=dotted type="repsig"] + 38 -> 25 [label="ref_out-Diag" style=bold type="ref" comment="out-Diag"] + 38 -> 26 [label="ref_out-B3" style=bold type="ref" comment="out-B3"] + 40 -> 38 [label="ref_in-B3" style=bold type="ref" comment="in-B3"] + 46 -> 39 [label="ref_out-Diag" style=bold type="ref" comment="out-Diag"] + 39 -> 38 [label="crd_in-Diag" style=dashed type="crd" comment="in-Diag"] + 39 -> 38 [label="ref_in-Diag" style=bold type="ref" comment="in-Diag"] + 46 -> 33 [label="ref_out-B1" style=bold type="ref" comment="out-B1"] + 47 -> 46 [label="ref_in-Diag" style=bold type="ref" comment="in-Diag"] + 48 -> 46 [label="crd_in-B1" style=dashed type="crd" comment="in-B1"] + 48 -> 46 [label="ref_in-B1" style=bold type="ref" comment="in-B1"] +} diff --git a/compiler/sam-outputs/dot/mat_mask_tri_DA3sum_final.gv b/compiler/sam-outputs/dot/mat_mask_tri_DA3sum_final.gv new file mode 100644 index 00000000..6b699547 --- /dev/null +++ b/compiler/sam-outputs/dot/mat_mask_tri_DA3sum_final.gv @@ -0,0 +1,102 @@ +digraph SAM { + comment="x=none,Diag=ss01,B1=ss01,B2=ss01,B3=ss10" + 42 [comment="type=fiberlookup,index=i,tensor=Diag,mode=0,format=compressed,src=true,root=true" label="FiberLookup i: Diag0\ncompressed" color=green4 shape=box style=filled type="fiberlookup" index="i" tensor="Diag" mode="0" format="compressed" src="true" root="true"] + 41 [comment="type=intersect,index=i" label="intersect i" color=purple shape=box style=filled type="intersect" index="i"] + 40 [comment="type=repsiggen,index=i" label="RepeatSignalGenerator i" color=cyan3 shape=box style=filled type="repsiggen" index="i"] + 39 [comment="type=broadcast" shape=point style=invis type="broadcast"] + 37 [comment="type=repeat,index=i,tensor=B2,root=true" label="Repeat i: B2" color=cyan2 shape=box style=filled type="repeat" index="i" tensor="B2" root="true"] + 31 [comment="type=repeat,index=l,tensor=B2,root=false" label="Repeat l: B2" color=cyan2 shape=box style=filled type="repeat" index="l" tensor="B2" root="false"] + 29 [comment="type=fiberlookup,index=j,tensor=B2,mode=0,format=compressed,src=true,root=false" label="FiberLookup j: B20\ncompressed" color=green4 shape=box style=filled type="fiberlookup" index="j" tensor="B2" mode="0" format="compressed" src="true" root="false"] + 27 [comment="type=intersect,index=j" label="intersect j" color=purple shape=box style=filled type="intersect" index="j"] + 26 [comment="type=repsiggen,index=j" label="RepeatSignalGenerator j" color=cyan3 shape=box style=filled type="repsiggen" index="j"] + 25 [comment="type=broadcast" shape=point style=invis type="broadcast"] + 23 [comment="type=repeat,index=j,tensor=Diag,root=false" label="Repeat j: Diag" color=cyan2 shape=box style=filled type="repeat" index="j" tensor="Diag" root="false"] + 16 [comment="type=repeat,index=k,tensor=Diag,root=false" label="Repeat k: Diag" color=cyan2 shape=box style=filled type="repeat" index="k" tensor="Diag" root="false"] + 10 [comment="type=arrayvals,tensor=Diag" label="Array Vals: Diag" color=green2 shape=box style=filled type="arrayvals" tensor="Diag"] + 9 [comment="type=mul" label="Mul" color=brown shape=box style=filled type="mul"] + 8 [comment="type=reduce" label="Reduce" color=brown shape=box style=filled type="reduce"] + 7 [comment="type=reduce" label="Reduce" color=brown shape=box style=filled type="reduce"] + 6 [comment="type=reduce" label="Reduce" color=brown shape=box style=filled type="reduce"] + 5 [comment="type=reduce" label="Reduce" color=brown shape=box style=filled type="reduce"] + 0 [comment="type=fiberwrite,mode=vals,tensor=x,size=1,sink=true" label="FiberWrite Vals: x" color=green3 shape=box style=filled type="fiberwrite" tensor="x" mode="vals" size="1" sink="true"] + 24 [comment="type=repeat,index=j,tensor=B3,root=false" label="Repeat j: B3" color=cyan2 shape=box style=filled type="repeat" index="j" tensor="B3" root="false"] + 22 [comment="type=fiberlookup,index=k,tensor=B3,mode=0,format=compressed,src=true,root=false" label="FiberLookup k: B30\ncompressed" color=green4 shape=box style=filled type="fiberlookup" index="k" tensor="B3" mode="0" format="compressed" src="true" root="false"] + 20 [comment="type=intersect,index=k" label="intersect k" color=purple shape=box style=filled type="intersect" index="k"] + 19 [comment="type=repsiggen,index=k" label="RepeatSignalGenerator k" color=cyan3 shape=box style=filled type="repsiggen" index="k"] + 18 [comment="type=broadcast" shape=point style=invis type="broadcast"] + 17 [comment="type=repeat,index=k,tensor=B1,root=false" label="Repeat k: B1" color=cyan2 shape=box style=filled type="repeat" index="k" tensor="B1" root="false"] + 13 [comment="type=arrayvals,tensor=B1" label="Array Vals: B1" color=green2 shape=box style=filled type="arrayvals" tensor="B1"] + 12 [comment="type=mul" label="Mul" color=brown shape=box style=filled type="mul"] + 11 [comment="type=mul" label="Mul" color=brown shape=box style=filled type="mul"] + 14 [comment="type=arrayvals,tensor=B2" label="Array Vals: B2" color=green2 shape=box style=filled type="arrayvals" tensor="B2"] + 15 [comment="type=arrayvals,tensor=B3" label="Array Vals: B3" color=green2 shape=box style=filled type="arrayvals" tensor="B3"] + 21 [comment="type=fiberlookup,index=k,tensor=B2,mode=1,format=compressed,src=true,root=false" label="FiberLookup k: B21\ncompressed" color=green4 shape=box style=filled type="fiberlookup" index="k" tensor="B2" mode="1" format="compressed" src="true" root="false"] + 38 [comment="type=repeat,index=i,tensor=B3,root=true" label="Repeat i: B3" color=cyan2 shape=box style=filled type="repeat" index="i" tensor="B3" root="true"] + 36 [comment="type=fiberlookup,index=l,tensor=B3,mode=1,format=compressed,src=true,root=false" label="FiberLookup l: B31\ncompressed" color=green4 shape=box style=filled type="fiberlookup" index="l" tensor="B3" mode="1" format="compressed" src="true" root="false"] + 34 [comment="type=intersect,index=l" label="intersect l" color=purple shape=box style=filled type="intersect" index="l"] + 33 [comment="type=repsiggen,index=l" label="RepeatSignalGenerator l" color=cyan3 shape=box style=filled type="repsiggen" index="l"] + 32 [comment="type=broadcast" shape=point style=invis type="broadcast"] + 30 [comment="type=repeat,index=l,tensor=B1,root=false" label="Repeat l: B1" color=cyan2 shape=box style=filled type="repeat" index="l" tensor="B1" root="false"] + 28 [comment="type=fiberlookup,index=j,tensor=B1,mode=1,format=compressed,src=true,root=false" label="FiberLookup j: B11\ncompressed" color=green4 shape=box style=filled type="fiberlookup" index="j" tensor="B1" mode="1" format="compressed" src="true" root="false"] + 35 [comment="type=fiberlookup,index=l,tensor=Diag,mode=1,format=compressed,src=true,root=false" label="FiberLookup l: Diag1\ncompressed" color=green4 shape=box style=filled type="fiberlookup" index="l" tensor="Diag" mode="1" format="compressed" src="true" root="false"] + 43 [comment="type=fiberlookup,index=i,tensor=B1,mode=0,format=compressed,src=true,root=true" label="FiberLookup i: B10\ncompressed" color=green4 shape=box style=filled type="fiberlookup" index="i" tensor="B1" mode="0" format="compressed" src="true" root="true"] + 42 -> 41 [label="crd_in-Diag" style=dashed type="crd" comment="in-Diag"] + 41 -> 40 [label="crd" style=dashed type="crd"] + 40 -> 39 [label="repsig" style=dotted type="repsig"] + 39 -> 37 [label="repsig" style=dotted type="repsig"] + 37 -> 31 [label="ref" style=bold type="ref"] + 31 -> 29 [label="ref" style=bold type="ref"] + 29 -> 27 [label="crd_in-B2" style=dashed type="crd" comment="in-B2"] + 27 -> 26 [label="crd" style=dashed type="crd"] + 26 -> 25 [label="repsig" style=dotted type="repsig"] + 25 -> 23 [label="repsig" style=dotted type="repsig"] + 23 -> 16 [label="ref" style=bold type="ref"] + 16 -> 10 [label="ref" style=bold type="ref"] + 10 -> 9 [label="val" type="val"] + 9 -> 8 [label="val" type="val"] + 8 -> 7 [label="val" type="val"] + 7 -> 6 [label="val" type="val"] + 6 -> 5 [label="val" type="val"] + 5 -> 0 [label="val" type="val"] + 25 -> 24 [label="repsig" style=dotted type="repsig"] + 24 -> 22 [label="ref" style=bold type="ref"] + 22 -> 20 [label="crd_in-B3" style=dashed type="crd" comment="in-B3"] + 20 -> 19 [label="crd" style=dashed type="crd"] + 19 -> 18 [label="repsig" style=dotted type="repsig"] + 18 -> 16 [label="repsig" style=dotted type="repsig"] + 18 -> 17 [label="repsig" style=dotted type="repsig"] + 17 -> 13 [label="ref" style=bold type="ref"] + 13 -> 12 [label="val" type="val"] + 12 -> 11 [label="val" type="val"] + 11 -> 9 [label="val" type="val"] + 20 -> 14 [label="ref_out-B2" style=bold type="ref" comment="out-B2"] + 14 -> 12 [label="val" type="val"] + 20 -> 15 [label="ref_out-B3" style=bold type="ref" comment="out-B3"] + 15 -> 11 [label="val" type="val"] + 22 -> 20 [label="ref_in-B3" style=bold type="ref" comment="in-B3"] + 27 -> 17 [label="ref_out-B1" style=bold type="ref" comment="out-B1"] + 27 -> 21 [label="ref_out-B2" style=bold type="ref" comment="out-B2"] + 21 -> 20 [label="crd_in-B2" style=dashed type="crd" comment="in-B2"] + 21 -> 20 [label="ref_in-B2" style=bold type="ref" comment="in-B2"] + 29 -> 27 [label="ref_in-B2" style=bold type="ref" comment="in-B2"] + 39 -> 38 [label="repsig" style=dotted type="repsig"] + 38 -> 36 [label="ref" style=bold type="ref"] + 36 -> 34 [label="crd_in-B3" style=dashed type="crd" comment="in-B3"] + 34 -> 33 [label="crd" style=dashed type="crd"] + 33 -> 32 [label="repsig" style=dotted type="repsig"] + 32 -> 30 [label="repsig" style=dotted type="repsig"] + 30 -> 28 [label="ref" style=bold type="ref"] + 28 -> 27 [label="crd_in-B1" style=dashed type="crd" comment="in-B1"] + 28 -> 27 [label="ref_in-B1" style=bold type="ref" comment="in-B1"] + 32 -> 31 [label="repsig" style=dotted type="repsig"] + 34 -> 23 [label="ref_out-Diag" style=bold type="ref" comment="out-Diag"] + 34 -> 24 [label="ref_out-B3" style=bold type="ref" comment="out-B3"] + 36 -> 34 [label="ref_in-B3" style=bold type="ref" comment="in-B3"] + 41 -> 35 [label="ref_out-Diag" style=bold type="ref" comment="out-Diag"] + 35 -> 34 [label="crd_in-Diag" style=dashed type="crd" comment="in-Diag"] + 35 -> 34 [label="ref_in-Diag" style=bold type="ref" comment="in-Diag"] + 41 -> 30 [label="ref_out-B1" style=bold type="ref" comment="out-B1"] + 42 -> 41 [label="ref_in-Diag" style=bold type="ref" comment="in-Diag"] + 43 -> 41 [label="crd_in-B1" style=dashed type="crd" comment="in-B1"] + 43 -> 41 [label="ref_in-B1" style=bold type="ref" comment="in-B1"] +} diff --git a/compiler/sam-outputs/onyx-dot/mat_mask_tri_DA3sum_final.gv b/compiler/sam-outputs/onyx-dot/mat_mask_tri_DA3sum_final.gv new file mode 100644 index 00000000..3ccaa34e --- /dev/null +++ b/compiler/sam-outputs/onyx-dot/mat_mask_tri_DA3sum_final.gv @@ -0,0 +1,102 @@ +digraph SAM { + comment="x=none,E=ss01,B=ss01,C=ss01,D=ss10" + 42 [comment="type=fiberlookup,index=i,tensor=E,mode=0,format=compressed,src=true,root=true" label="FiberLookup i: E0\ncompressed" color=green4 shape=box style=filled type="fiberlookup" index="i" tensor="E" mode="0" format="compressed" src="true" root="true"] + 41 [comment="type=intersect,index=i" label="intersect i" color=purple shape=box style=filled type="intersect" index="i"] + 40 [comment="type=repsiggen,index=i" label="RepeatSignalGenerator i" color=cyan3 shape=box style=filled type="repsiggen" index="i"] + 39 [comment="type=broadcast" shape=point style=invis type="broadcast"] + 37 [comment="type=repeat,index=i,tensor=C,root=true" label="Repeat i: C" color=cyan2 shape=box style=filled type="repeat" index="i" tensor="C" root="true"] + 31 [comment="type=repeat,index=l,tensor=C,root=false" label="Repeat l: C" color=cyan2 shape=box style=filled type="repeat" index="l" tensor="C" root="false"] + 29 [comment="type=fiberlookup,index=j,tensor=C,mode=0,format=compressed,src=true,root=false" label="FiberLookup j: C0\ncompressed" color=green4 shape=box style=filled type="fiberlookup" index="j" tensor="C" mode="0" format="compressed" src="true" root="false"] + 27 [comment="type=intersect,index=j" label="intersect j" color=purple shape=box style=filled type="intersect" index="j"] + 26 [comment="type=repsiggen,index=j" label="RepeatSignalGenerator j" color=cyan3 shape=box style=filled type="repsiggen" index="j"] + 25 [comment="type=broadcast" shape=point style=invis type="broadcast"] + 23 [comment="type=repeat,index=j,tensor=E,root=false" label="Repeat j: E" color=cyan2 shape=box style=filled type="repeat" index="j" tensor="E" root="false"] + 16 [comment="type=repeat,index=k,tensor=E,root=false" label="Repeat k: E" color=cyan2 shape=box style=filled type="repeat" index="k" tensor="E" root="false"] + 10 [comment="type=arrayvals,tensor=E" label="Array Vals: E" color=green2 shape=box style=filled type="arrayvals" tensor="E"] + 9 [comment="type=mul" label="Mul" color=brown shape=box style=filled type="mul"] + 8 [comment="type=reduce" label="Reduce" color=brown shape=box style=filled type="reduce"] + 7 [comment="type=reduce" label="Reduce" color=brown shape=box style=filled type="reduce"] + 6 [comment="type=reduce" label="Reduce" color=brown shape=box style=filled type="reduce"] + 5 [comment="type=reduce" label="Reduce" color=brown shape=box style=filled type="reduce"] + 0 [comment="type=fiberwrite,mode=vals,tensor=x,size=1,sink=true" label="FiberWrite Vals: x" color=green3 shape=box style=filled type="fiberwrite" tensor="x" mode="vals" size="1" sink="true"] + 24 [comment="type=repeat,index=j,tensor=D,root=false" label="Repeat j: D" color=cyan2 shape=box style=filled type="repeat" index="j" tensor="D" root="false"] + 22 [comment="type=fiberlookup,index=k,tensor=D,mode=0,format=compressed,src=true,root=false" label="FiberLookup k: D0\ncompressed" color=green4 shape=box style=filled type="fiberlookup" index="k" tensor="D" mode="0" format="compressed" src="true" root="false"] + 20 [comment="type=intersect,index=k" label="intersect k" color=purple shape=box style=filled type="intersect" index="k"] + 19 [comment="type=repsiggen,index=k" label="RepeatSignalGenerator k" color=cyan3 shape=box style=filled type="repsiggen" index="k"] + 18 [comment="type=broadcast" shape=point style=invis type="broadcast"] + 17 [comment="type=repeat,index=k,tensor=B,root=false" label="Repeat k: B" color=cyan2 shape=box style=filled type="repeat" index="k" tensor="B" root="false"] + 13 [comment="type=arrayvals,tensor=B" label="Array Vals: B" color=green2 shape=box style=filled type="arrayvals" tensor="B"] + 12 [comment="type=mul" label="Mul" color=brown shape=box style=filled type="mul"] + 11 [comment="type=mul" label="Mul" color=brown shape=box style=filled type="mul"] + 14 [comment="type=arrayvals,tensor=C" label="Array Vals: C" color=green2 shape=box style=filled type="arrayvals" tensor="C"] + 15 [comment="type=arrayvals,tensor=D" label="Array Vals: D" color=green2 shape=box style=filled type="arrayvals" tensor="D"] + 21 [comment="type=fiberlookup,index=k,tensor=C,mode=1,format=compressed,src=true,root=false" label="FiberLookup k: C1\ncompressed" color=green4 shape=box style=filled type="fiberlookup" index="k" tensor="C" mode="1" format="compressed" src="true" root="false"] + 38 [comment="type=repeat,index=i,tensor=D,root=true" label="Repeat i: D" color=cyan2 shape=box style=filled type="repeat" index="i" tensor="D" root="true"] + 36 [comment="type=fiberlookup,index=l,tensor=D,mode=1,format=compressed,src=true,root=false" label="FiberLookup l: D1\ncompressed" color=green4 shape=box style=filled type="fiberlookup" index="l" tensor="D" mode="1" format="compressed" src="true" root="false"] + 34 [comment="type=intersect,index=l" label="intersect l" color=purple shape=box style=filled type="intersect" index="l"] + 33 [comment="type=repsiggen,index=l" label="RepeatSignalGenerator l" color=cyan3 shape=box style=filled type="repsiggen" index="l"] + 32 [comment="type=broadcast" shape=point style=invis type="broadcast"] + 30 [comment="type=repeat,index=l,tensor=B,root=false" label="Repeat l: B" color=cyan2 shape=box style=filled type="repeat" index="l" tensor="B" root="false"] + 28 [comment="type=fiberlookup,index=j,tensor=B,mode=1,format=compressed,src=true,root=false" label="FiberLookup j: B1\ncompressed" color=green4 shape=box style=filled type="fiberlookup" index="j" tensor="B" mode="1" format="compressed" src="true" root="false"] + 35 [comment="type=fiberlookup,index=l,tensor=E,mode=1,format=compressed,src=true,root=false" label="FiberLookup l: E1\ncompressed" color=green4 shape=box style=filled type="fiberlookup" index="l" tensor="E" mode="1" format="compressed" src="true" root="false"] + 43 [comment="type=fiberlookup,index=i,tensor=B,mode=0,format=compressed,src=true,root=true" label="FiberLookup i: B0\ncompressed" color=green4 shape=box style=filled type="fiberlookup" index="i" tensor="B" mode="0" format="compressed" src="true" root="true"] + 42 -> 41 [label="crd_in-E" style=dashed type="crd" comment="in-E"] + 41 -> 40 [label="crd" style=dashed type="crd"] + 40 -> 39 [label="repsig" style=dotted type="repsig"] + 39 -> 37 [label="repsig" style=dotted type="repsig"] + 37 -> 31 [label="ref" style=bold type="ref"] + 31 -> 29 [label="ref" style=bold type="ref"] + 29 -> 27 [label="crd_in-C" style=dashed type="crd" comment="in-C"] + 27 -> 26 [label="crd" style=dashed type="crd"] + 26 -> 25 [label="repsig" style=dotted type="repsig"] + 25 -> 23 [label="repsig" style=dotted type="repsig"] + 23 -> 16 [label="ref" style=bold type="ref"] + 16 -> 10 [label="ref" style=bold type="ref"] + 10 -> 9 [label="val" type="val"] + 9 -> 8 [label="val" type="val"] + 8 -> 7 [label="val" type="val"] + 7 -> 6 [label="val" type="val"] + 6 -> 5 [label="val" type="val"] + 5 -> 0 [label="val" type="val"] + 25 -> 24 [label="repsig" style=dotted type="repsig"] + 24 -> 22 [label="ref" style=bold type="ref"] + 22 -> 20 [label="crd_in-D" style=dashed type="crd" comment="in-D"] + 20 -> 19 [label="crd" style=dashed type="crd"] + 19 -> 18 [label="repsig" style=dotted type="repsig"] + 18 -> 16 [label="repsig" style=dotted type="repsig"] + 18 -> 17 [label="repsig" style=dotted type="repsig"] + 17 -> 13 [label="ref" style=bold type="ref"] + 13 -> 12 [label="val" type="val"] + 12 -> 11 [label="val" type="val"] + 11 -> 9 [label="val" type="val"] + 20 -> 14 [label="ref_out-C" style=bold type="ref" comment="out-C"] + 14 -> 12 [label="val" type="val"] + 20 -> 15 [label="ref_out-D" style=bold type="ref" comment="out-D"] + 15 -> 11 [label="val" type="val"] + 22 -> 20 [label="ref_in-D" style=bold type="ref" comment="in-D"] + 27 -> 17 [label="ref_out-B" style=bold type="ref" comment="out-B"] + 27 -> 21 [label="ref_out-C" style=bold type="ref" comment="out-C"] + 21 -> 20 [label="crd_in-C" style=dashed type="crd" comment="in-C"] + 21 -> 20 [label="ref_in-C" style=bold type="ref" comment="in-C"] + 29 -> 27 [label="ref_in-C" style=bold type="ref" comment="in-C"] + 39 -> 38 [label="repsig" style=dotted type="repsig"] + 38 -> 36 [label="ref" style=bold type="ref"] + 36 -> 34 [label="crd_in-D" style=dashed type="crd" comment="in-D"] + 34 -> 33 [label="crd" style=dashed type="crd"] + 33 -> 32 [label="repsig" style=dotted type="repsig"] + 32 -> 30 [label="repsig" style=dotted type="repsig"] + 30 -> 28 [label="ref" style=bold type="ref"] + 28 -> 27 [label="crd_in-B" style=dashed type="crd" comment="in-B"] + 28 -> 27 [label="ref_in-B" style=bold type="ref" comment="in-B"] + 32 -> 31 [label="repsig" style=dotted type="repsig"] + 34 -> 23 [label="ref_out-E" style=bold type="ref" comment="out-E"] + 34 -> 24 [label="ref_out-D" style=bold type="ref" comment="out-D"] + 36 -> 34 [label="ref_in-D" style=bold type="ref" comment="in-D"] + 41 -> 35 [label="ref_out-E" style=bold type="ref" comment="out-E"] + 35 -> 34 [label="crd_in-E" style=dashed type="crd" comment="in-E"] + 35 -> 34 [label="ref_in-E" style=bold type="ref" comment="in-E"] + 41 -> 30 [label="ref_out-B" style=bold type="ref" comment="out-B"] + 42 -> 41 [label="ref_in-E" style=bold type="ref" comment="in-E"] + 43 -> 41 [label="crd_in-B" style=dashed type="crd" comment="in-B"] + 43 -> 41 [label="ref_in-B" style=bold type="ref" comment="in-B"] +} diff --git a/sam.egg-info/PKG-INFO b/sam.egg-info/PKG-INFO deleted file mode 100644 index 78efc11c..00000000 --- a/sam.egg-info/PKG-INFO +++ /dev/null @@ -1,131 +0,0 @@ -Metadata-Version: 2.1 -Name: sam -Version: 0.0.1 -Summary: Sparse Abstract Machine -Home-page: https://github.com/weiya711/sam -Author: Olivia Hsu -Author-email: oliviahsu1107@gmail.com -License: UNKNOWN -Description: # The Sparse Abstract Machine (SAM) IR, Compiler, and Simulator - - ![Master Makefile CI](https://github.com/weiya711/sam/actions/workflows/makefile.yml/badge.svg?branch=master) - ![Master Python CI](https://github.com/weiya711/sam/actions/workflows/python-package-conda.yml/badge.svg?branch=master) - - ## SAM Front-end Compiler - - Overview: - tensor expression + format language + schedule - --> - SAM Graph - --> - dot file and png of dot file - --> - RTL Graph or Simulator Graph - - ### Compiling SAM graphs - Init the taco/ repo as a submodule - ``` - make submodules - ``` - - Setup the compilation for the taco/ repo - ``` - make taco/build - ``` - - Run the script to generate a handful of example sam graphs - ``` - make sam - ``` - - The example sam graphs should now be located in `compiler/sam-outputs/` in both the `dot/` and `png/` folers. - - ### Naming convention - Naming rules - - all (block) types are lower case: repeat, repeat_gen, fiber_lookup, fiber_write, reduce, intersect, union, sparse_accum - - network signal types are: crd, ref, val, repsig, and bv - - Tensor casing: Matrices and higher order tensors are upper case, scalars and vectors are lower case - - Index variables are going to be i, j, k, ..., etc. - - Tensor ranks are going to correspond to 0, 1, 2, ..., etc. (no longer using rows and columns) - - For a given expression result is always 'x' (or 'X') and the inputs start from 'b, c, ..., etc.' of equivalently 'B, C, ..., etc.' - - Metadata Naming - Metadata naming convention for other blocks: - - Metadata naming convention for fiber (lookup and write) blocks: fiber_-___ - Examples: - 1. fiber_lookup_Bi_B0_compressed - 2. repeat_Ci - - ## SAM Simulator - ### Installing SAM Simulator as a Package - ``` - pip install -e . - ``` - - ### Running Tests - The simulator uses pytest to run tests - - To run all tests type - ``` - cd sam/sim/ - pytest - ``` - - Use the following pytest optional arguments below - ``` - --debug-sim Turn on debug mode for sim - --count= Repeat each test for n iterations - -k [] Run only tests with testname and paramlist - -vv Double verbose - -s Forward printouts to stdout - --full-trace Print full trace to stdout - ``` - - - ### Test Naming Convention - Full kernel tests follow the naming convention `test______...` where: - 1. `` is the name of the tensor algebra kernel being tested (e.g. mat_elemmul, mat_mul, vec_elemmul, etc.) - 2. `*format` takes on `u | c | s` for formats uncompressed, compressed, singleton respectively - 3. ` specifies if the test is _randomly generated_ or a _directed (handwritten)_ test - - Primitive unit tests follow the naming convention `test___` where: - 1. `` is the name of the primitive being tested (e.g. array, intersect, union, etc.) - 2. `` is the name of the feature being tested (e.g. for an array we can test both loads and stores) - 3. `` is the name of the order of stream being tested (1d for vectors, - 2d for matrices, ..., and nd for all dimensions/tensor orders, etc.) - - - ### Directory Structure - ``` - sim - │ - │ - │ - └───src - │ │ base.py - │ │ joiner.py - │ │ ... # All primitive block classes - │ - └───test - │ │ test.py - │ │ file022.txt - │ │ - │ └───apps - │ │ test_mat_elemmul.py - │ │ ... # Full kernel/expression tests - │ - └───────primitives - │ test_joiner.py - │ ... # Primitive unit tests - - ``` - - ## SAM Binding to Onyx - See the `README` in `sam/sam/onyx` - - ## License - All files in this project (code, scripts, documentaiton) are released under the [MIT License](LICENSE) - -Platform: UNKNOWN -Requires-Python: >=3.5 -Description-Content-Type: text/markdown diff --git a/sam/onyx/hw_nodes/compute_node.py b/sam/onyx/hw_nodes/compute_node.py index 0107ff83..0593cfcb 100644 --- a/sam/onyx/hw_nodes/compute_node.py +++ b/sam/onyx/hw_nodes/compute_node.py @@ -115,7 +115,7 @@ def connect(self, other, edge, kwargs=None): new_conns = { f'pe_to_crddrop_res_to_{conn}': [ - ([(pe, "res"), (crddrop, f"cmrg_coord_in_{conn}")], 17), + ([(pe, "res"), (crddrop, f"coord_in_{conn}")], 17), ] } return new_conns diff --git a/sam/onyx/hw_nodes/intersect_node.py b/sam/onyx/hw_nodes/intersect_node.py index c95bb3f1..6a7f8445 100644 --- a/sam/onyx/hw_nodes/intersect_node.py +++ b/sam/onyx/hw_nodes/intersect_node.py @@ -150,7 +150,7 @@ def connect(self, other, edge, kwargs=None): new_conns = { f'isect_to_merger_{conn}': [ # Send isect row and isect col to merger inside isect_col - ([(isect, "coord_out"), (merge, f"cmrg_coord_in_{conn}")], 17), + ([(isect, "coord_out"), (merge, f"coord_in_{conn}")], 17), ] } diff --git a/sam/onyx/hw_nodes/merge_node.py b/sam/onyx/hw_nodes/merge_node.py index f074e3e4..3b2e5fd2 100644 --- a/sam/onyx/hw_nodes/merge_node.py +++ b/sam/onyx/hw_nodes/merge_node.py @@ -56,7 +56,7 @@ def connect(self, other, edge, kwargs=None): print(conn) new_conns = { f'merge_{conn}_to_wr_scan': [ - ([(merge, f"cmrg_coord_out_{conn}"), (wr_scan, f"data_in")], 17), + ([(merge, f"coord_out_{conn}"), (wr_scan, f"data_in")], 17), ] } @@ -66,7 +66,7 @@ def connect(self, other, edge, kwargs=None): print("MERGE TO UNION FOR VECTOR REDUCE") new_conns = { f'merge_to_union_inner': [ - ([(merge, f"cmrg_coord_out_{0}"), (isect, f"coord_in_{0}")], 17), + ([(merge, f"coord_out_{0}"), (isect, f"coord_in_{0}")], 17), ] } @@ -77,7 +77,7 @@ def connect(self, other, edge, kwargs=None): other_red = other.get_name() new_conns = { f'merge_to_reduce_inner': [ - ([(merge, f"cmrg_coord_out_{0}"), (other_red, f"reduce_data_in")], 17), + ([(merge, f"coord_out_{0}"), (other_red, f"reduce_data_in")], 17), ] } @@ -103,7 +103,7 @@ def connect(self, other, edge, kwargs=None): new_conns = { f'merger_to_merger_{out_conn}_to_{in_conn}': [ - ([(merge, f"cmrg_coord_out_{out_conn}"), (other_merge, f"cmrg_coord_in_{in_conn}")], 17), + ([(merge, f"coord_out_{out_conn}"), (other_merge, f"coord_in_{in_conn}")], 17), ] } diff --git a/sam/onyx/hw_nodes/read_scanner_node.py b/sam/onyx/hw_nodes/read_scanner_node.py index 91b8149e..3091f058 100644 --- a/sam/onyx/hw_nodes/read_scanner_node.py +++ b/sam/onyx/hw_nodes/read_scanner_node.py @@ -178,7 +178,7 @@ def connect(self, other, edge, kwargs=None): new_conns = { f'rd_scan_to_crddrop_{conn}': [ - ([(rd_scan, out_conn), (crddrop, f"cmrg_coord_in_{conn}")], 17), + ([(rd_scan, out_conn), (crddrop, f"coord_in_{conn}")], 17), ] } diff --git a/sam/util.py b/sam/util.py index b6d88edb..f349d996 100644 --- a/sam/util.py +++ b/sam/util.py @@ -11,6 +11,7 @@ import scipy.io import scipy.sparse import sparse +import struct # All environment variables for SAM should live here or in make file cwd = os.getcwd() @@ -19,6 +20,8 @@ SUITESPARSE_PATH = os.getenv('SUITESPARSE_PATH', default=os.path.join(SAM_HOME, "data", "suitesparse")) SUITESPARSE_FORMATTED_PATH = os.getenv('SUITESPARSE_FORMATTED_PATH', default=os.path.join(SAM_HOME, "data", "suitesparse-formatted")) +SPARSEML_PATH = os.getenv('SPARSE_ML_PATH', default=os.path.join(SAM_HOME, "data", "sparseml")) +SPARSEML_PATH_FORMATTED = os.getenv('SPARSE_ML_PATH_FORMATTED', default=os.path.join(SAM_HOME, "data", "sparseml-formatted")) FROSTT_PATH = os.getenv('FROSTT_PATH', default=os.path.join(SAM_HOME, "data", "frostt")) VALIDATION_OUTPUT_PATH = os.getenv('VALIDATION_OUTPUT_PATH', default=os.path.join(SAM_HOME, "data", "gold")) @@ -36,6 +39,72 @@ def safeCastScipyTensorToInts(tensor): return scipy.sparse.coo_matrix(tensor.coords, data, tensor.shape) +# TODO: this function is duplicated multiple times across aha repository +# and should be moved to a common location +def bfbin2float(bfstr): + sign = bfstr[0] + exp = bfstr[1:9] + lfrac = bfstr[9:16] + if sign == "0" and exp == "11111111" and lfrac != "0000000": + return float('nan') + elif sign == "1" and exp == "11111111" and lfrac != "0000000": + return -float('nan') + elif sign == "0" and exp == "11111111" and lfrac == "0000000": + return float('inf') + elif sign == "1" and exp == "11111111" and lfrac == "0000000": + return -float('inf') + elif sign == "0" and exp == "00000000" and lfrac == "0000000": + return float(0) + elif sign == "1" and exp == "00000000" and lfrac == "0000000": + return -float(0) + else: + mult = 1 + if sign == "1": + mult = -1 + nexp = int(exp, 2) - 127 + if exp != 0: + lfrac = "1" + lfrac + else: + lfrac = "0" + lfrac + nfrac = int(lfrac, 2) + return mult * nfrac * (2 ** (nexp - 7)) + + +# TODO: this function is duplicated multiple times across aha repository +# and should be moved to a common location +def float2bfbin(fnum): + if fnum == "NaN": + sign = "0" + exp = "11111111" + lfrac = "11111111" + elif fnum == "-NaN": + sign = "1" + exp = "11111111" + lfrac = "11111111" + elif fnum == "Inf" or fnum > 3.402823466e+38: + sign = "0" + exp = "11111111" + lfrac = "00000000" + elif fnum == "-Inf" or fnum < -3.402823466e+38: + sign = "1" + exp = "11111111" + lfrac = "00000000" + else: + fstr = "".join("{:08b}".format(elem) for elem in struct.pack("!f", fnum)) + sign = fstr[0] + exp = fstr[1:9] + lfrac = "0" + fstr[9:16] + hfrac = fstr[16:] + # Enable rounding + if (hfrac[0] == "1" and (hfrac[1] == "1" or hfrac[2] == "1")) or (lfrac[7] == "1" and hfrac[0] == "1"): + # bit 8 of the float mantissa is set, so round up + if lfrac[1:8] == "1111111": # roll over mantissa and increase exp if needed + exp = "{:08b}".format((int(exp, 2) + 1)) # exp overflow? + lfrac = "{:08b}".format((int(lfrac, 2) + 1)) + + return sign + exp + lfrac[1:8] + + # ScipyTensorShifter shifts all elements in the last mode # of the input scipy/sparse tensor by one. class ScipyTensorShifter: @@ -242,6 +311,20 @@ def load(self, path): return coo +class NumpyNPYArrayLoader: + def __init__(self): + pass + + def load(self, path): + np_array = numpy.load(path) + if (np_array.dtype == numpy.dtype('S16')): + input_fp_array = numpy.empty_like(np_array, dtype=numpy.float32) + for idx, val in numpy.ndenumerate(np_array): + input_fp_array[idx] = bfbin2float(str(val).split("'")[1]) + coo = scipy.sparse.coo_array(input_fp_array) + return coo + + def shape_str(shape): return str(shape[0]) + " " + str(shape[1]) @@ -283,6 +366,25 @@ def load(self, tensor, cast): return self.tensor +class InputCacheSparseML: + def __init__(self): + self.lastLoaded = None + self.lastName = None + self.tensor = None + + def load(self, tensor, cast): + if self.lastName == str(tensor): + return self.tensor + else: + self.lastLoaded = tensor.load(NumpyNPYArrayLoader()) + self.lastName = str(tensor) + if cast: + self.tensor = self.lastLoaded + else: + self.tensor = self.lastLoaded + return self.tensor + + class FormatWriter: def __init__(self, cast_int=True): self.cast = cast_int @@ -600,6 +702,19 @@ def load(self, loader): return loader.load(self.path) +class SparseMLTensor: + def __init__(self, path): + self.path = path + self.__name__ = self.__str__() + + def __str__(self): + f = os.path.split(self.path)[1] + return f.replace(".npy", "") + + def load(self, loader): + return loader.load(self.path) + + # TensorCollectionSuiteSparse represents the set of all downloaded # SuiteSparse tensors. class TensorCollectionSuiteSparse: