From 7a71c46394973ff718da637fe4d4702c7f0c9f95 Mon Sep 17 00:00:00 2001 From: Misko Date: Mon, 12 Aug 2024 22:55:47 +0000 Subject: [PATCH] add resolution flag to escn --- src/fairchem/core/models/escn/escn.py | 3 ++- src/fairchem/core/models/escn/so3.py | 9 ++++----- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/fairchem/core/models/escn/escn.py b/src/fairchem/core/models/escn/escn.py index 62a582b4cc..c288f3f251 100644 --- a/src/fairchem/core/models/escn/escn.py +++ b/src/fairchem/core/models/escn/escn.py @@ -87,6 +87,7 @@ def __init__( basis_width_scalar: float = 1.0, distance_resolution: float = 0.02, show_timing_info: bool = False, + resolution: int | None = None, ) -> None: if mmax_list is None: mmax_list = [2] @@ -176,7 +177,7 @@ def __init__( for lval in range(max(self.lmax_list) + 1): SO3_m_grid = nn.ModuleList() for m in range(max(self.lmax_list) + 1): - SO3_m_grid.append(SO3_Grid(lval, m)) + SO3_m_grid.append(SO3_Grid(lval, m, resolution=resolution)) self.SO3_grid.append(SO3_m_grid) diff --git a/src/fairchem/core/models/escn/so3.py b/src/fairchem/core/models/escn/so3.py index 988797df2e..34f505d51e 100644 --- a/src/fairchem/core/models/escn/so3.py +++ b/src/fairchem/core/models/escn/so3.py @@ -452,11 +452,7 @@ class SO3_Grid(torch.nn.Module): mmax (int): Maximum order of the spherical harmonics """ - def __init__( - self, - lmax: int, - mmax: int, - ) -> None: + def __init__(self, lmax: int, mmax: int, resolution: int | None = None) -> None: super().__init__() self.lmax = lmax self.mmax = mmax @@ -465,6 +461,9 @@ def __init__( self.long_resolution = 2 * (self.mmax + 1) + 1 else: self.long_resolution = 2 * (self.mmax) + 1 + if resolution: + self.long_resolution=resolution + self.lat_resolution=resolution self.initialized = False