Skip to content

Commit

Permalink
Merge pull request #4 from kamalsaleh/master
Browse files Browse the repository at this point in the history
Allow to save plot instead of showing it & add more tests
  • Loading branch information
kamalsaleh authored Jul 11, 2024
2 parents 1214161 + f49be16 commit e171462
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 4 deletions.
2 changes: 1 addition & 1 deletion PackageInfo.g
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ SetPackageInfo( rec(

PackageName := "MachineLearningForCAP",
Subtitle := "Exploring categorical machine learning in CAP",
Version := "2024.07-10",
Version := "2024.07-12",
Date := (function ( ) if IsBound( GAPInfo.SystemEnvironment.GAP_PKG_RELEASE_DATE ) then return GAPInfo.SystemEnvironment.GAP_PKG_RELEASE_DATE; else return Concatenation( ~.Version{[ 1 .. 4 ]}, "-", ~.Version{[ 6, 7 ]}, "-01" ); fi; end)( ),
License := "GPL-2.0-or-later",

Expand Down
33 changes: 33 additions & 0 deletions examples/Expressions.g
Original file line number Diff line number Diff line change
Expand Up @@ -60,4 +60,37 @@ e := Sin( x1 ) / Cos( x1 ) + Sin( x2 ) ^ 2 + Cos( x2 ) ^ 2;
#! Sin( x1 ) / Cos( x1 ) + Sin( x2 ) ^ 2 + Cos( x2 ) ^ 2
SimplifyExpressionUsingPython( [ e ] );
#! [ "Tan(x1) + 1" ]
Diff( e, 1 )( dummy_input );
#! Sin( x1 ) ^ 2 / Cos( x1 ) ^ 2 + 1
LazyDiff( e, 1 )( dummy_input );;
# Diff( [ "x1", "x2", "x3" ],
# "(((Sin(x1))/(Cos(x1)))+((Sin(x2))^(2)))+((Cos(x2))^(2))", 1 )( [ x1, x2, x3 ] );
JacobianMatrixUsingPython( [ x1*Cos(x2)+Exp(x3), x1*x2*x3 ], [ 1, 2, 3 ] );
#! [ [ "Cos(x2)", "-x1*Sin(x2)", "Exp(x3)" ], [ "x2*x3", "x1*x3", "x1*x2" ] ]
LaTeXOutputUsingPython( e );
#! "\\frac{\\sin{\\left(x_{1} \\right)}}{\\cos{\\left(x_{1} \\right)}}
#! + \\sin^{2}{\\left(x_{2} \\right)} + \\cos^{2}{\\left(x_{2} \\right)}"
sigmoid := Expression( [ "x" ], "Exp(x)/(1+Exp(x))" );
#! Exp( x ) / (1 + Exp( x ))
sigmoid := AsFunction( sigmoid );
#! function( vec ) ... end
sigmoid( [ 0 ] );
#! 0.5
points := List( 0.1 * [ -20 .. 20 ], x -> [ x, sigmoid( [ x ] ) ] );
#! [ [ -2., 0.119203 ], [ -1.9, 0.130108 ], [ -1.8, 0.141851 ], [ -1.7, 0.154465 ],
#! [ -1.6, 0.167982 ], [ -1.5, 0.182426 ], [ -1.4, 0.197816 ], [ -1.3, 0.214165 ],
#! [ -1.2, 0.231475 ], [ -1.1, 0.24974 ], [ -1., 0.268941 ], [ -0.9, 0.28905 ],
#! [ -0.8, 0.310026 ], [ -0.7, 0.331812 ], [ -0.6, 0.354344 ], [ -0.5, 0.377541 ],
#! [ -0.4, 0.401312 ], [ -0.3, 0.425557 ], [ -0.2, 0.450166 ], [ -0.1, 0.475021 ],
#! [ 0., 0.5 ], [ 0.1, 0.524979 ], [ 0.2, 0.549834 ], [ 0.3, 0.574443 ],
#! [ 0.4, 0.598688 ], [ 0.5, 0.622459 ], [ 0.6, 0.645656 ], [ 0.7, 0.668188 ],
#! [ 0.8, 0.689974 ], [ 0.9, 0.71095 ], [ 1., 0.731059 ], [ 1.1, 0.75026 ],
#! [ 1.2, 0.768525 ], [ 1.3, 0.785835 ], [ 1.4, 0.802184 ], [ 1.5, 0.817574 ],
#! [ 1.6, 0.832018 ], [ 1.7, 0.845535 ], [ 1.8, 0.858149 ], [ 1.9, 0.869892 ],
#! [ 2., 0.880797 ] ]
labels := List( points, point -> SelectBasedOnCondition( point[2] < 0.5, 0, 1 ) );
#! [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
#! 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 ]
ScatterPlotUsingPython( points, labels : size := "100", action := "save" );;
# e.g, dir("/tmp/gaptempdirX7Qsal/")
#! @EndExample
11 changes: 8 additions & 3 deletions gap/Tools.gi
Original file line number Diff line number Diff line change
Expand Up @@ -707,7 +707,7 @@ InstallMethod( ScatterPlotUsingPython,
[ IsDenseList, IsDenseList ],

function ( points, labels )
local dir, path, file, size, stream, err, p;
local dir, path, file, size, action, stream, err, p;

dir := DirectoryTemporary( );

Expand All @@ -719,6 +719,8 @@ InstallMethod( ScatterPlotUsingPython,

size := CAP_INTERNAL_RETURN_OPTION_OR_DEFAULT( "size", "20" );

action := CAP_INTERNAL_RETURN_OPTION_OR_DEFAULT( "action", "show" );

IO_Write( file,
Concatenation(
"import matplotlib.pyplot as plt\n",
Expand Down Expand Up @@ -779,7 +781,10 @@ InstallMethod( ScatterPlotUsingPython,
"plt.ylabel('Y-axis')\n",
"plt.title('Scatter Plot using Matplotlib')\n",
"plt.legend()\n",
"plt.show()\n" ) );
SelectBasedOnCondition(
action = "save",
Concatenation( "plt.savefig('", Filename( dir, "plot.png" ), "', dpi=400)\n" ),
"plt.show()\n" ) ) );

IO_Close( file );

Expand All @@ -801,6 +806,6 @@ InstallMethod( ScatterPlotUsingPython,

fi;

return true;
return dir;

end );

0 comments on commit e171462

Please sign in to comment.