diff --git a/MANIFEST.in b/MANIFEST.in new file mode 100644 index 0000000..ebc54b3 --- /dev/null +++ b/MANIFEST.in @@ -0,0 +1 @@ +include pyar/AIMNet2/models/aimnet2_wb97m-d3_0.jpt \ No newline at end of file diff --git a/pyar/interface/aimnet_2.py b/pyar/interface/aimnet_2.py index f145051..83118a0 100755 --- a/pyar/interface/aimnet_2.py +++ b/pyar/interface/aimnet_2.py @@ -8,6 +8,8 @@ from pyar.AIMNet2.calculators import aimnet2_ase_opt # noqa: F401 from pyar.AIMNet2.calculators import aimnet2ase # noqa: F401 import sys +import pkg_resources +import os Aimnet2_logger = logging.getLogger('pyar.aimnet-2') @@ -22,8 +24,10 @@ device = torch.device('cpu') print(device) -#Plese adjust the path according the cloned repo. -aimnet2 = torch.jit.load('aimnet2_wb97m-d3_0.jpt', map_location=device) +model_path = pkg_resources.resource_filename('pyar', 'AIMNet2/models/aimnet2_wb97m-d3_0.jpt') +aimnet2_script = pkg_resources.resource_filename('pyar', 'AIMNet2/calculators/aimnet2_ase_opt.py') +# Load the model +aimnet2 = torch.jit.load(model_path, map_location=device) class Aimnet2(SF): def __init__(self, molecule, qc_params): @@ -46,7 +50,7 @@ def __init__(self, molecule, qc_params): self.inp_min_file = 'trial_' + self.job_name + '_min.xyz' self.out_file = 'trial_' + self.job_name + '.out' - self.cmd = f"python aimnet2_ase_opt.py aimnet2_wb97m-d3_0.jpt --traj result.traj {self.inp_file} {self.inp_min_file}" + self.cmd = f"python {aimnet2_script} {model_path} --traj result.traj {self.inp_file} {self.inp_min_file}" if self.charge != 0: self.cmd = "{} -c {}".format(self.cmd, self.charge) diff --git a/setup.py b/setup.py index 229a2f3..618c63a 100644 --- a/setup.py +++ b/setup.py @@ -13,9 +13,11 @@ 'pyar/scripts/pyar-tabu', 'pyar/scripts/pyar-clustering', 'pyar/scripts/pyar-similarity', - 'pyar/AIMNet2/calculators/aimnet2_ase_opt.py', - 'pyar/AIMNet2/models/aimnet2_wb97m-d3_0.jpt' + 'pyar/AIMNet2/calculators/aimnet2_ase_opt.py' ], + package_data={ + 'pyar': ['AIMNet2/models/aimnet2_wb97m-d3_0.jpt'] + }, url='https://github.com/anooplab/pyar', license='GPL v3', author='Anoop et al', @@ -28,7 +30,6 @@ 'pandas', 'matplotlib', 'pyh5md', - 'hdbscan', 'h5py', 'DBCV @ git+https://github.com/christopherjenness/DBCV.git', 'dscribe' @@ -45,3 +46,4 @@ ], python_requires='>=3.6', ) +