Skip to content

Commit

Permalink
Fixes bug for "data" input to "project.py" and "test_inputs.py"
Browse files Browse the repository at this point in the history
  • Loading branch information
DrPaulSharp committed Jun 10, 2024
1 parent 8252dd5 commit 6b979a1
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 19 deletions.
12 changes: 4 additions & 8 deletions RAT/inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,14 +231,10 @@ def make_cells(project: RAT.Project) -> Cells:

data_index = project.data.index(contrast.data)

if 'data' in project.data[data_index].model_fields_set:
all_data.append(project.data[data_index].data)
data_limits.append(project.data[data_index].data_range)
simulation_limits.append(project.data[data_index].simulation_range)
else:
all_data.append([0.0, 0.0, 0.0])
data_limits.append([0.0, 0.0])
simulation_limits.append([0.0, 0.0])
all_data.append(project.data[data_index].data)
data_limits.append(project.data[data_index].data_range)
simulation_limits.append(project.data[data_index].simulation_range)


# Populate the set of cells
cells = Cells()
Expand Down
4 changes: 2 additions & 2 deletions RAT/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def model_post_init(self, __context: Any) -> None:
"""If the "data_range" and "simulation_range" fields are not set, but "data" is supplied, the ranges should be
set to the min and max values of the first column (assumed to be q) of the supplied data.
"""
if len(self.data[:, 0]) > 0:
if self.data.shape[0] > 0:
data_min = np.min(self.data[:, 0])
data_max = np.max(self.data[:, 0])
for field in ["data_range", "simulation_range"]:
Expand All @@ -135,7 +135,7 @@ def check_ranges(self) -> 'Data':
"""The limits of the "data_range" field must lie within the range of the supplied data, whilst the limits
of the "simulation_range" field must lie outside the range of the supplied data.
"""
if len(self.data[:, 0]) > 0:
if self.data.shape[0] > 0:
data_min = np.min(self.data[:, 0])
data_max = np.max(self.data[:, 0])
if "data_range" in self.model_fields_set and (self.data_range[0] < data_min or
Expand Down
5 changes: 4 additions & 1 deletion RAT/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ class Project(BaseModel, validate_assignment=True, extra='forbid', arbitrary_typ
value_1='Resolution Param 1'))

custom_files: ClassList = ClassList()
data: ClassList = ClassList(RAT.models.Data(name='Simulation'))
data: ClassList = ClassList()
layers: ClassList = ClassList()
domain_contrasts: ClassList = ClassList()
contrasts: ClassList = ClassList()
Expand Down Expand Up @@ -187,6 +187,9 @@ def model_post_init(self, __context: Any) -> None:
self.parameters.remove('Substrate Roughness')
self.parameters.insert(0, RAT.models.ProtectedParameter(**substrate_roughness_values))

if 'Simulation' not in self.data.get_names():
self.data.insert(0, RAT.models.Data(name='Simulation'))

self._all_names = self.get_all_names()
self._contrast_model_field = self.get_contrast_model_field()
self._protected_parameters = self.get_all_protected_parameters()
Expand Down
16 changes: 8 additions & 8 deletions tests/test_inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,13 @@
@pytest.fixture
def standard_layers_project():
"""Add parameters to the default project for a non polarised calculation."""
test_project = RAT.Project(data=RAT.ClassList([RAT.models.Data(name='Simulation', data=np.array([[1.0, 1.0, 1.0]]))]))
test_project = RAT.Project(data=RAT.ClassList([RAT.models.Data(name='Test Data', data=np.array([[1.0, 1.0, 1.0]]))]))
test_project.parameters.append(name='Test Thickness')
test_project.parameters.append(name='Test SLD')
test_project.parameters.append(name='Test Roughness')
test_project.custom_files.append(name='Test Custom File', filename='matlab_test.m', language='matlab')
test_project.layers.append(name='Test Layer', thickness='Test Thickness', SLD='Test SLD', roughness='Test Roughness')
test_project.contrasts.append(name='Test Contrast', data='Simulation', background='Background 1', bulk_in='SLD Air',
test_project.contrasts.append(name='Test Contrast', data='Test Data', background='Background 1', bulk_in='SLD Air',
bulk_out='SLD D2O', scalefactor='Scalefactor 1', resolution='Resolution 1',
model=['Test Layer'])
return test_project
Expand All @@ -31,15 +31,15 @@ def standard_layers_project():
def domains_project():
"""Add parameters to the default project for a domains calculation."""
test_project = RAT.Project(calculation=Calculations.Domains,
data=RAT.ClassList([RAT.models.Data(name='Simulation', data=np.array([[1.0, 1.0, 1.0]]))]))
data=RAT.ClassList([RAT.models.Data(name='Test Data', data=np.array([[1.0, 1.0, 1.0]]))]))
test_project.parameters.append(name='Test Thickness')
test_project.parameters.append(name='Test SLD')
test_project.parameters.append(name='Test Roughness')
test_project.custom_files.append(name='Test Custom File', filename='matlab_test.m', language='matlab')
test_project.layers.append(name='Test Layer', thickness='Test Thickness', SLD='Test SLD', roughness='Test Roughness')
test_project.domain_contrasts.append(name='up', model=['Test Layer'])
test_project.domain_contrasts.append(name='down', model=['Test Layer'])
test_project.contrasts.append(name='Test Contrast', data='Simulation', background='Background 1', bulk_in='SLD Air',
test_project.contrasts.append(name='Test Contrast', data='Test Data', background='Background 1', bulk_in='SLD Air',
bulk_out='SLD D2O', scalefactor='Scalefactor 1', resolution='Resolution 1',
domain_ratio='Domain Ratio 1', model=['down', 'up'])
return test_project
Expand Down Expand Up @@ -165,7 +165,7 @@ def custom_xy_problem():
problem.contrastCustomFiles = [1]
problem.contrastDomainRatios = [0]
problem.resample = [False]
problem.dataPresent = [1]
problem.dataPresent = [0]
problem.oilChiDataPresent = [0]
problem.numberOfContrasts = 1
problem.numberOfLayers = 0
Expand Down Expand Up @@ -240,9 +240,9 @@ def custom_xy_cells():
"""The expected cells object from "custom_xy_project"."""
cells = Cells()
cells.f1 = [[0, 1]]
cells.f2 = [np.array([[0.0, 0.0, 0.0]])]
cells.f3 = [[0.0, 0.0]]
cells.f4 = [[0.0, 0.0]]
cells.f2 = [np.empty([0, 3])]
cells.f3 = [[]]
cells.f4 = [[]]
cells.f5 = [0]
cells.f6 = [0]
cells.f7 = ['Substrate Roughness', 'Test Thickness', 'Test SLD', 'Test Roughness']
Expand Down

0 comments on commit 6b979a1

Please sign in to comment.