Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'master' into JHopeCollins/vector_sub1
Browse files Browse the repository at this point in the history
JHopeCollins committed Dec 3, 2024
2 parents 0774d82 + 2be9fb2 commit a2a9448
Showing 2 changed files with 28 additions and 11 deletions.
17 changes: 12 additions & 5 deletions firedrake/output/vtk_output.py
Original file line number Diff line number Diff line change
@@ -452,7 +452,7 @@ def __init__(self, filename, project_output=False, comm=None, mode="w",
@no_annotations
def _prepare_output(self, function, max_elem):
from firedrake import FunctionSpace, VectorFunctionSpace, \
TensorFunctionSpace, Function
TensorFunctionSpace, Function, Cofunction

name = function.name()
# Need to project/interpolate?
@@ -477,16 +477,23 @@ def _prepare_output(self, function, max_elem):
shape=shape)
else:
raise ValueError("Unsupported shape %s" % (shape, ))
output = Function(V)
if isinstance(function, Function):
output = Function(V)
else:
assert isinstance(function, Cofunction)
output = Function(V.dual())

if self.project:
if isinstance(function, Cofunction):
raise ValueError("Can not project Cofunctions")
output.project(function)
else:
output.interpolate(function)

return OFunction(array=get_array(output), name=name, function=output)

def _write_vtu(self, *functions):
from firedrake.function import Function
from firedrake import Function, Cofunction

# Check if the user has requested to write out a plain mesh
if len(functions) == 1 and isinstance(functions[0], ufl.Mesh):
@@ -496,8 +503,8 @@ def _write_vtu(self, *functions):
functions = [Function(V)]

for f in functions:
if not isinstance(f, Function):
raise ValueError("Can only output Functions or a single mesh, not %r" % type(f))
if not isinstance(f, (Function, Cofunction)):
raise ValueError(f"Can only output Functions, Cofunctions or a single mesh, not {type(f).__name__}")
meshes = tuple(extract_unique_domain(f) for f in functions)
if not all(m == meshes[0] for m in meshes):
raise ValueError("All functions must be on same mesh")
22 changes: 16 additions & 6 deletions tests/firedrake/output/test_pvd_output.py
Original file line number Diff line number Diff line change
@@ -79,11 +79,16 @@ def test_bad_file_name(tmpdir):
VTKFile(str(tmpdir.join("foo.vtu")))


def test_different_functions(mesh, pvd):
@pytest.mark.parametrize("space",
["primal", "dual"])
def test_different_functions(mesh, pvd, space):
V = FunctionSpace(mesh, "DG", 0)

f = Function(V, name="foo")
g = Function(V, name="bar")
if space == "primal":
f = Function(V, name="foo")
g = Function(V, name="bar")
else:
f = Cofunction(V.dual(), name="foo")
g = Cofunction(V.dual(), name="bar")

pvd.write(f)

@@ -136,9 +141,14 @@ def test_not_function(mesh, pvd):
pvd.write(grad(f))


def test_append(mesh, tmpdir):
@pytest.mark.parametrize("space",
["primal", "dual"])
def test_append(mesh, tmpdir, space):
V = FunctionSpace(mesh, "DG", 0)
g = Function(V)
if space == "primal":
g = Function(V)
else:
g = Cofunction(V.dual())

outfile = VTKFile(str(tmpdir.join("restart_test.pvd")))
outfile.write(g)

0 comments on commit a2a9448

Please sign in to comment.