diff --git a/festim/exports/derived_quantities/surface_flux.py b/festim/exports/derived_quantities/surface_flux.py index e68f47c1d..511485797 100644 --- a/festim/exports/derived_quantities/surface_flux.py +++ b/festim/exports/derived_quantities/surface_flux.py @@ -116,6 +116,16 @@ def __init__(self, field, surface, azimuth_range=(0, 2 * np.pi)) -> None: self.r = None self.azimuth_range = azimuth_range + @property + def export_unit(self): + # obtain domain dimension + dim = self.function.function_space().mesh().topology().dim() + # return unit depending on field and dimension of domain + if self.field == "T": + return f"W m{dim-2}".replace(" m0", "") + else: + return f"H m{dim-2} s-1".replace(" m0", "") + @property def allowed_meshes(self): return ["cylindrical"] @@ -128,10 +138,7 @@ def title(self): quantity_title = f"{self.field} flux surface {self.surface}" if self.show_units: - if self.field == "T": - return quantity_title + " (W)" - else: - return quantity_title + " (H s-1)" + return quantity_title + f" ({self.export_unit})" else: return quantity_title @@ -208,6 +215,13 @@ def __init__( self.polar_range = polar_range self.azimuth_range = azimuth_range + @property + def export_unit(self): + if self.field == "T": + return f"W" + else: + return f"H s-1" + @property def allowed_meshes(self): return ["spherical"] @@ -220,10 +234,7 @@ def title(self): quantity_title = f"{self.field} flux surface {self.surface}" if self.show_units: - if self.field == "T": - return quantity_title + " (W)" - else: - return quantity_title + " (H s-1)" + return quantity_title + f" ({self.export_unit})" else: return quantity_title diff --git a/test/unit/test_exports/test_derived_quantities/test_derived_quantities.py b/test/unit/test_exports/test_derived_quantities/test_derived_quantities.py index 1ca73cb15..4d8bc61de 100644 --- a/test/unit/test_exports/test_derived_quantities/test_derived_quantities.py +++ b/test/unit/test_exports/test_derived_quantities/test_derived_quantities.py @@ -120,8 +120,6 @@ def test_with_units_simple(self): self.avg_surface_2, self.point_1, self.point_2, - self.cyl_surface_flux_1, - self.cyl_surface_flux_2, self.sph_surface_flux_1, self.sph_surface_flux_2, self.ads_h, @@ -144,8 +142,6 @@ def test_with_units_simple(self): "Average T surface 6 (K)", "retention value at [2] (H m-3)", "T value at [9] (K)", - "solute flux surface 2 (H s-1)", - "Heat flux surface 3 (W)", "solute flux surface 5 (H s-1)", "Heat flux surface 6 (W)", "Adsorbed H on surface 1 (H m-2)", diff --git a/test/unit/test_exports/test_derived_quantities/test_surface_flux.py b/test/unit/test_exports/test_derived_quantities/test_surface_flux.py index 1d92d3a44..e2b2e4015 100644 --- a/test/unit/test_exports/test_derived_quantities/test_surface_flux.py +++ b/test/unit/test_exports/test_derived_quantities/test_surface_flux.py @@ -335,6 +335,23 @@ def test_cylindrical_flux_title_no_units_temperature(): assert my_heat_flux.title == "Heat flux surface 4" +@pytest.mark.parametrize( + "function, field, expected_title", + [ + (c_1D, "solute", "solute flux surface 3 (H m-1 s-1)"), + (c_1D, "T", "Heat flux surface 3 (W m-1)"), + (c_2D, "solute", "solute flux surface 3 (H s-1)"), + (c_2D, "T", "Heat flux surface 3 (W)"), + ], +) +def test_cylindrical_flux_with_units(function, field, expected_title): + my_flux = SurfaceFluxCylindrical(field=field, surface=3) + my_flux.function = function + my_flux.show_units = True + + assert my_flux.title == expected_title + + def test_spherical_flux_title_no_units_solute(): """A simple test to check that the title is set correctly in festim.SphericalSurfaceFlux with a solute field without units""" @@ -349,3 +366,18 @@ def test_spherical_flux_title_no_units_temperature(): my_heat_flux = SurfaceFluxSpherical("T", 5) assert my_heat_flux.title == "Heat flux surface 5" + + +@pytest.mark.parametrize( + "function, field, expected_title", + [ + (c_1D, "solute", "solute flux surface 3 (H s-1)"), + (c_1D, "T", "Heat flux surface 3 (W)"), + ], +) +def test_spherical_flux_with_units(function, field, expected_title): + my_flux = SurfaceFluxSpherical(field=field, surface=3) + my_flux.function = function + my_flux.show_units = True + + assert my_flux.title == expected_title