From 5ff27831ddf3c9efd73ef801677b4103e53589b0 Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Thu, 2 Jul 2026 09:59:50 +0100 Subject: [PATCH 1/2] Support -snes_adapt --- firedrake/dmhooks.py | 69 +++++ firedrake/mg/ufl_utils.py | 305 ++++++++++++++----- firedrake/variational_solver.py | 44 ++- tests/firedrake/multigrid/test_snes_adapt.py | 148 +++++++++ 4 files changed, 477 insertions(+), 89 deletions(-) create mode 100644 tests/firedrake/multigrid/test_snes_adapt.py diff --git a/firedrake/dmhooks.py b/firedrake/dmhooks.py index a2b60a1f50..9b65792554 100644 --- a/firedrake/dmhooks.py +++ b/firedrake/dmhooks.py @@ -446,6 +446,69 @@ def coarsen(dm, comm): return cdm +def _validate_refinement_markers(markers, mesh): + if not isinstance(markers, (firedrake.Function, firedrake.Cofunction)): + raise TypeError( + f"marking callback must return a Function or Cofunction, not a {type(markers).__name__}" + ) + M = markers.function_space() + if M.mesh() is not mesh: + raise ValueError("marking callback must return markers on the current solution mesh") + if M.finat_element.space_dimension() != 1: + raise ValueError("marking callback must return a DG0 Function or Cofunction") + + +def _attach_refined_context(parent, ctx, coefficient_mapping, coarsener): + dms = [ctx._problem.u_restrict.function_space().dm] + for value in coefficient_mapping.values(): + if isinstance(value, (firedrake.Function, firedrake.Cofunction)): + dm = value.function_space().dm + if dm not in dms: + dms.append(dm) + + for dm in dms: + add_hook(parent, setup=partial(push_parent, dm, parent), teardown=partial(pop_parent, dm, parent), + call_setup=True) + add_hook(parent, setup=partial(push_ctx_coarsener, dm, coarsener), + teardown=partial(pop_ctx_coarsener, dm, coarsener), + call_setup=True) + add_hook(parent, setup=partial(push_appctx, dm, ctx), teardown=partial(pop_appctx, dm, ctx), + call_setup=True) + + +def _adaptively_refine(dm, comm): + from firedrake.mg.adaptive_hierarchy import AdaptiveMeshHierarchy + from firedrake.mg.ufl_utils import refine + from firedrake.mg.utils import get_level + + ctx = get_appctx(dm) + callback = getattr(ctx, "_adapt_marking_callback", None) + + current_solution = ctx._x + mesh = current_solution.function_space().mesh() + hierarchy, level = get_level(mesh) + if hierarchy is None: + hierarchy = AdaptiveMeshHierarchy(mesh) + level = 0 + if not isinstance(hierarchy, AdaptiveMeshHierarchy): + raise RuntimeError("Adaptive SNES refinement requires an AdaptiveMeshHierarchy") + + if level == len(hierarchy) - 1: + solver = None + markers = callback(solver, current_solution) + _validate_refinement_markers(markers, mesh) + hierarchy.add_mesh(mesh.refine_marked_elements(markers)) + + coefficient_mapping = {} + refined_ctx = refine(ctx, refine, coefficient_mapping=coefficient_mapping) + refined_ctx._adapt_marking_callback = callback + + parent = get_parent(dm) + coarsener = get_ctx_coarsener(dm) + _attach_refined_context(parent, refined_ctx, coefficient_mapping, coarsener) + return refined_ctx._problem.dm + + @PETSc.Log.EventDecorator() def refine(dm, comm): """Callback to refine a DM. @@ -453,11 +516,17 @@ def refine(dm, comm): :arg DM: The DM to refine. :arg comm: The communicator for the new DM (ignored) """ + ctx = get_appctx(dm) + if getattr(ctx, "_adapt_marking_callback", None) is not None: + return _adaptively_refine(dm, comm) + from firedrake.mg.utils import get_level V = get_function_space(dm) if V is None: raise RuntimeError("No functionspace found on DM") hierarchy, level = get_level(V.mesh()) + if hierarchy is None: + raise RuntimeError("No mesh hierarchy available") if level >= len(hierarchy) - 1: raise RuntimeError("Cannot refine finest DM") if hasattr(V, "_fine"): diff --git a/firedrake/mg/ufl_utils.py b/firedrake/mg/ufl_utils.py index b40bfbd7ac..f3d577e523 100644 --- a/firedrake/mg/ufl_utils.py +++ b/firedrake/mg/ufl_utils.py @@ -14,7 +14,7 @@ from . import utils -__all__ = ["coarsen"] +__all__ = ["coarsen", "refine"] class CoarseningError(Exception): @@ -64,6 +64,15 @@ def coarsen(expr, self, coefficient_mapping=None): return expr +@singledispatch +def refine(expr, self, coefficient_mapping=None): + # Most coarsen handlers will simply reconstruct the expression tree. And + # very few of them branch on coarsen vs refine to handle both directions. + # Delegating here lets those shared handlers do the right thing when called + # via `refine(...)`. + return coarsen(expr, self, coefficient_mapping=coefficient_mapping) + + @coarsen.register(ufl.Mesh) @coarsen.register(ufl.MeshSequence) def coarsen_mesh(mesh, self, coefficient_mapping=None): @@ -73,9 +82,18 @@ def coarsen_mesh(mesh, self, coefficient_mapping=None): return hierarchy[level - 1] +@refine.register(ufl.Mesh) +@refine.register(ufl.MeshSequence) +def refine_mesh(mesh, self, coefficient_mapping=None): + hierarchy, level = utils.get_level(mesh) + if hierarchy is None: + raise CoarseningError("No mesh hierarchy available") + return hierarchy[level + 1] + + @coarsen.register(ufl.BaseForm) @coarsen.register(ufl.classes.Expr) -def coarse_expr(expr, self, coefficient_mapping=None): +def coarsen_expr(expr, self, coefficient_mapping=None): if expr is None: return None mapper = CoarsenIntegrand(self, coefficient_mapping) @@ -126,7 +144,7 @@ def coarsen_formsum(form, self, coefficient_mapping=None): @coarsen.register(firedrake.DirichletBC) def coarsen_bc(bc, self, coefficient_mapping=None): V = self(bc.function_space(), self, coefficient_mapping=coefficient_mapping) - val = self(bc.function_arg, self, coefficient_mapping=coefficient_mapping) + val = self(bc._original_arg, self, coefficient_mapping=coefficient_mapping) subdomain = bc.sub_domain return type(bc)(V, val, subdomain) @@ -149,18 +167,46 @@ def coarsen_equation_bc(ebc, self, coefficient_mapping=None): @coarsen.register(firedrake.functionspaceimpl.WithGeometryBase) def coarsen_function_space(V, self, coefficient_mapping=None): - if hasattr(V, "_coarse"): + # Handle MixedFunctionSpace : V.reconstruct requires MeshSequence. + mesh = V.mesh() if V.index is None else V.parent.mesh() + new_mesh = self(mesh, self) + if hasattr(V, "_coarse") and V._coarse.mesh() == new_mesh: return V._coarse - - V_fine = V - # Handle MixedFunctionSpace : V_fine.reconstruct requires MeshSequence. - fine_mesh = V_fine.mesh() if V_fine.index is None else V_fine.parent.mesh() - mesh_coarse = self(fine_mesh, self) - name = f"coarse_{V.name}" if V.name else None - V_coarse = V_fine.reconstruct(mesh=mesh_coarse, name=name) - V_coarse._fine = V_fine - V_fine._coarse = V_coarse - return V_coarse + # Get the parent name + V_parent = V + while hasattr(V_parent, "_fine") and V_parent._fine: + V_parent = V_parent._fine + name = V_parent.name + if name is not None: + mh, level = utils.get_level(new_mesh) + name = f"{name}_level_{level}" + # Reconstruct the space + V_new = V.reconstruct(mesh=new_mesh, name=name) + V_new._fine = V + V._coarse = V_new + return V_new + + +@refine.register(firedrake.functionspaceimpl.WithGeometryBase) +def refine_function_space(V, self, coefficient_mapping=None): + # Handle MixedFunctionSpace : V.reconstruct requires MeshSequence. + mesh = V.mesh() if V.index is None else V.parent.mesh() + new_mesh = self(mesh, self) + if hasattr(V, "_fine") and V._fine.mesh() == new_mesh: + return V._fine + # Get the parent name + V_parent = V + while hasattr(V_parent, "_coarse") and V_parent._coarse: + V_parent = V_parent._coarse + name = V_parent.name + if name is not None: + mh, level = utils.get_level(new_mesh) + name = f"{name}_level_{level}" + # Reconstruct the space + V_new = V.reconstruct(mesh=new_mesh, name=name) + V_new._coarse = V + V._fine = V_new + return V_new @coarsen.register(firedrake.Cofunction) @@ -170,10 +216,19 @@ def coarsen_function(expr, self, coefficient_mapping=None): coefficient_mapping = {} new = coefficient_mapping.get(expr) if new is None: - Vf = expr.function_space() - Vc = self(Vf, self) - new = firedrake.Function(Vc, name=f"coarse_{expr.name()}") - manager = get_transfer_manager(Vf.dm) + V = expr.function_space() + Vnew = self(V, self) + name = expr.name() + if name is not None: + try: + name, prev_level = name.split("_level_") + except ValueError: + prev_level = 0 + level = int(prev_level) - 1 + name = f"{name}_level_{level}" + + new = firedrake.Function(Vnew, name=name) + manager = get_transfer_manager(V.dm) if is_dual(expr): manager.restrict(expr, new) else: @@ -182,10 +237,40 @@ def coarsen_function(expr, self, coefficient_mapping=None): return new +@refine.register(firedrake.Cofunction) +@refine.register(firedrake.Function) +def refine_function(expr, self, coefficient_mapping=None): + if coefficient_mapping is None: + coefficient_mapping = {} + new = coefficient_mapping.get(expr) + if new is None: + V = expr.function_space() + Vnew = self(V, self) + name = expr.name() + if name is not None: + try: + name, prev_level = name.split("_level_") + except ValueError: + prev_level = 0 + level = int(prev_level) + 1 + name = f"{name}_level_{level}" + + new = firedrake.Function(Vnew, name=name) + new.interpolate(expr) + coefficient_mapping[expr] = new + return new + + @coarsen.register(firedrake.NonlinearVariationalProblem) def coarsen_nlvp(problem, self, coefficient_mapping=None): - if hasattr(problem, "_coarse"): - return problem._coarse + # Have we done this already? + mh, _ = utils.get_level(problem.u.function_space().mesh()) + if self == coarsen and hasattr(problem, "_coarse"): + if mh is utils.get_level(problem._coarse.u.function_space().mesh())[0]: + return problem._coarse + elif self == refine and hasattr(problem, "_fine"): + if mh is utils.get_level(problem._fine.u.function_space().mesh())[0]: + return problem._fine def inject_on_restrict(fine, restriction, rscale, injection, coarse): manager = get_transfer_manager(fine) @@ -226,10 +311,40 @@ def inject_on_restrict(fine, restriction, rscale, injection, coarse): Jp = self(problem.Jp, self, coefficient_mapping=coefficient_mapping) u = coefficient_mapping[problem.u_restrict] - fine = problem + orig = problem problem = firedrake.NonlinearVariationalProblem(F, u, bcs=bcs, J=J, Jp=Jp, is_linear=problem.is_linear, form_compiler_parameters=problem.form_compiler_parameters) - fine._coarse = problem + if self == coarsen: + orig._coarse = problem + elif self == refine: + orig._fine = problem + return problem + + +@coarsen.register(firedrake.LinearEigenproblem) +def coarsen_eigenproblem(problem, self, coefficient_mapping=None): + # Have we done this already? + mh, _ = utils.get_level(problem.output_space().mesh()) + if self == coarsen and hasattr(problem, "_coarse"): + if mh is utils.get_level(problem._coarse.output_space.mesh())[0]: + return problem._coarse + elif self == refine and hasattr(problem, "_fine"): + if mh is utils.get_level(problem._fine.output_space.mesh())[0]: + return problem._fine + + if coefficient_mapping is None: + coefficient_mapping = {} + bcs = [self(bc, self, coefficient_mapping=coefficient_mapping) + for bc in problem._original_bcs] + A = self(problem._original_A, self, coefficient_mapping=coefficient_mapping) + M = self(problem._original_M, self, coefficient_mapping=coefficient_mapping) + orig = problem + problem = firedrake.LinearEigenproblem(A, M, bcs=bcs, + bc_shift=orig.bc_shift, restrict=orig.restrict) + if self == coarsen: + orig._coarse = problem + elif self == refine: + orig._fine = problem return problem @@ -264,10 +379,17 @@ def coarsen_snescontext(context, self, coefficient_mapping=None): if coefficient_mapping is None: coefficient_mapping = {} + if self == refine: + new_attr = "_fine" + old_attr = "_coarse" + else: + new_attr = "_coarse" + old_attr = "_fine" + # Have we already done this? - coarse = context._coarse - if coarse is not None: - return coarse + new_context = getattr(context, new_attr) + if new_context is not None: + return new_context problem = self(context._problem, self, coefficient_mapping=coefficient_mapping) appctx = context.appctx @@ -282,74 +404,89 @@ def coarsen_snescontext(context, self, coefficient_mapping=None): # Assume not something that needs coarsening (e.g. float) new_appctx[k] = v - # Get options prefix for current level - parent_context = context - while parent_context._fine: - parent_context = parent_context._fine - parent_prefix = parent_context.options_prefix - opts = PETSc.Options(parent_prefix) - if opts.getString("snes_type", "") == "fas": - solver_prefix = "fas_" - else: - solver_prefix = "mg_" - _, level = utils.get_level(problem.u_restrict.function_space().mesh()) - if level == 0: - levels_prefix = f"{solver_prefix}coarse_" - else: - levels_prefix = f"{solver_prefix}levels_" - current_level_prefix = f"{solver_prefix}levels_{level}_" - options_prefix = f"{parent_prefix}{current_level_prefix}" - - # Use different mat_type on each level - mat_type = None - pmat_type = None - sub_mat_type = None - sub_pmat_type = None - for prefix in (levels_prefix, current_level_prefix): - mat_type = opts.getString(f"{prefix}mat_type", "") or mat_type - pmat_type = opts.getString(f"{prefix}pmat_type", "") or pmat_type - sub_mat_type = opts.getString(f"{prefix}sub_mat_type", "") or sub_mat_type - sub_pmat_type = opts.getString(f"{prefix}sub_pmat_type", "") or sub_pmat_type - - pmat_type = pmat_type or mat_type - sub_pmat_type = sub_pmat_type or sub_mat_type - coarse = context.reconstruct(problem=problem, - mat_type=mat_type, - pmat_type=pmat_type, - sub_mat_type=sub_mat_type, - sub_pmat_type=sub_pmat_type, - appctx=new_appctx, - options_prefix=options_prefix, - ) - coarse._coefficient_mapping = coefficient_mapping - coarse._fine = context - context._coarse = coarse + mat_type = context.mat_type + pmat_type = context.pmat_type + sub_mat_type = context.sub_mat_type + sub_pmat_type = context.sub_pmat_type + options_prefix = context.options_prefix + + if self == coarsen: + # Get options prefix for current level + parent_context = context + while parent_context._fine: + parent_context = parent_context._fine + parent_prefix = parent_context.options_prefix + opts = PETSc.Options(parent_prefix) + if opts.getString("snes_type", "") == "fas": + solver_prefix = "fas_" + else: + solver_prefix = "mg_" + _, level = utils.get_level(problem.u_restrict.function_space().mesh()) + if level == 0: + levels_prefix = f"{solver_prefix}coarse_" + else: + levels_prefix = f"{solver_prefix}levels_" + current_level_prefix = f"{solver_prefix}levels_{level}_" + options_prefix = f"{parent_prefix}{current_level_prefix}" + + # Use different mat_type on each level + mat_type = None + pmat_type = None + sub_mat_type = None + sub_pmat_type = None + for prefix in (levels_prefix, current_level_prefix): + mat_type = opts.getString(f"{prefix}mat_type", "") or mat_type + pmat_type = opts.getString(f"{prefix}pmat_type", "") or pmat_type + sub_mat_type = opts.getString(f"{prefix}sub_mat_type", "") or sub_mat_type + sub_pmat_type = opts.getString(f"{prefix}sub_pmat_type", "") or sub_pmat_type + + pmat_type = pmat_type or mat_type + sub_pmat_type = sub_pmat_type or sub_mat_type + + new_context = context.reconstruct(problem=problem, + mat_type=mat_type, + pmat_type=pmat_type, + sub_mat_type=sub_mat_type, + sub_pmat_type=sub_pmat_type, + appctx=new_appctx, + options_prefix=options_prefix, + ) + new_context._coefficient_mapping = coefficient_mapping + setattr(new_context, old_attr, context) + setattr(context, new_attr, new_context) + + for attr in ("_adapt_solver", "_adapt_marking_callback"): + if hasattr(context, attr): + setattr(new_context, attr, getattr(context, attr)) solutiondm = context._problem.u_restrict.function_space().dm parentdm = get_parent(solutiondm) - # Now that we have the coarse snescontext, push it to the coarsened DMs - # Otherwise they won't have the right transfer manager when they are - # coarsened in turn + # Now that we have the reconstructed snescontext, push it to the reconstructed DMs. + # Otherwise they will not have the right transfer manager when they are reconstructed in turn. for val in coefficient_mapping.values(): if isinstance(val, (firedrake.Function, firedrake.Cofunction)): V = val.function_space() - coarseneddm = V.dm + newdm = V.dm # Now attach the hook to the parent DM - if get_appctx(coarseneddm) is None: - push_appctx(coarseneddm, coarse) + if get_appctx(newdm) is None: + push_appctx(newdm, new_context) if parentdm.getAttr("__setup_hooks__"): - add_hook(parentdm, teardown=partial(pop_appctx, coarseneddm, coarse)) - - ises = problem.J.arguments()[0].function_space()._ises - coarse._nullspace = self(context._nullspace, self, coefficient_mapping=coefficient_mapping) - coarse.set_nullspace(coarse._nullspace, ises, transpose=False, near=False) - coarse._nullspace_T = self(context._nullspace_T, self, coefficient_mapping=coefficient_mapping) - coarse.set_nullspace(coarse._nullspace_T, ises, transpose=True, near=False) - coarse._near_nullspace = self(context._near_nullspace, self, coefficient_mapping=coefficient_mapping) - coarse.set_nullspace(coarse._near_nullspace, ises, transpose=False, near=True) - - return coarse + add_hook(parentdm, teardown=partial(pop_appctx, newdm, new_context)) + + ises = new_context._x.function_space()._ises + for attr, transpose, near in (("_nullspace", False, False), + ("_nullspace_T", True, False), + ("_near_nullspace", False, True)): + nullspace = getattr(context, attr) + try: + nullspace = self(nullspace, self, coefficient_mapping=coefficient_mapping) + except CoarseningError: + pass + setattr(new_context, attr, nullspace) + new_context.set_nullspace(nullspace, ises, transpose=transpose, near=near) + + return new_context @coarsen.register(firedrake.slate.AssembledVector) diff --git a/firedrake/variational_solver.py b/firedrake/variational_solver.py index 4031bf3c5c..193bace341 100644 --- a/firedrake/variational_solver.py +++ b/firedrake/variational_solver.py @@ -194,7 +194,8 @@ def __init__(self, problem, *, solver_parameters=None, post_jacobian_callback=None, pre_function_callback=None, post_function_callback=None, - pre_apply_bcs=True): + pre_apply_bcs=True, + marking_callback=None): r""" :arg problem: A :class:`NonlinearVariationalProblem` to solve. :kwarg nullspace: an optional :class:`.VectorSpaceBasis` (or @@ -223,9 +224,14 @@ def __init__(self, problem, *, solver_parameters=None, before residual assembly. :kwarg post_function_callback: As above, but called immediately after residual assembly. - :kwarg pre_apply_bcs: If `True`, the bcs are applied before the solve. + :kwarg pre_apply_bcs: If True, the bcs are applied before the solve. Otherwise, the problem is linearised around the initial guess before imposing bcs, and the bcs are appended to the nonlinear system. + :kwarg marking_callback: An optional callable of the form + ``callback(solver, u)`` for PETSc-driven adaptive refinement. + The callback receives this solver and the current Firedrake + solution, and must return a DG0 :class:`.Function` or + :class:`.Cofunction` with positive values on cells to refine. Example usage of the ``solver_parameters`` option: to set the nonlinear solver type to just use a linear solver, use @@ -303,6 +309,10 @@ def update_diffusivity(current_solution): self._ctx = ctx self._work = problem.u_restrict.dof_dset.layout_vec.duplicate() self.snes.setDM(problem.dm) + self._marking_callback = None + if marking_callback is not None: + self.set_marking_callback(marking_callback) + ctx._adapt_marking_callback = self._marking_callback ctx.set_function(self.snes) ctx.set_jacobian(self.snes) @@ -328,6 +338,24 @@ def update_diffusivity(current_solution): self._transfer_operators = () self._setup = False + def set_marking_callback(self, callback): + r"""Set the callback used by PETSc-driven adaptive refinement. + + The callback is called as ``callback(solver, u)`` when PETSc asks the + solution DM to refine. It must return a DG0 :class:`.Function` or + :class:`.Cofunction` on the current solution mesh, with positive values + on cells to refine. + """ + if not callable(callback): + raise TypeError(f"marking callback must be callable, not a {type(callback).__name__}") + self._marking_callback = callback + self.parameters.setdefault("adaptor_criterion", "refine") + self._ctx._adapt_marking_callback = callback + + def get_adapted_solution(self): + r"""Return the current solution, including after PETSc adapts the DM.""" + return self._ctx._problem.u + def set_transfer_manager(self, manager): r"""Set the object that manages transfer between grid levels. Typically a :class:`~.TransferManager` object. @@ -353,11 +381,12 @@ def solve(self, bounds=None): ``vinewtonssls`` or ``vinewtonrsls``. """ # Make sure the DM has this solver's callback functions + self._ctx._adapt_marking_callback = self._marking_callback self._ctx.set_function(self.snes) self._ctx.set_jacobian(self.snes) # Make sure appcontext is attached to every DM from every coefficient and DirichletBC before we solve. - problem = self._problem + problem = self._ctx._problem forms = (problem.F, problem.J, problem.Jp) coefficients = utils.unique(chain.from_iterable(form.coefficients() for form in forms if form is not None)) solution_dm = self.snes.getDM() @@ -396,14 +425,19 @@ def solve(self, bounds=None): self._transfer_operators): stack.enter_context(ctx) self.snes.solve(None, work) - work.copy(u) + # The appctx might have been refined + self._ctx = dmhooks.get_appctx(self.snes.getDM()) + problem = self._ctx._problem + solution = self.snes.getSolution() + with problem.u_restrict.dat.vec as u: + solution.copy(u) self._setup = True if problem.restrict: problem.u.assign(problem.u_restrict) solving_utils.check_snes_convergence(self.snes) # Grab the comm associated with the `_problem` and call PETSc's garbage cleanup routine - comm = self._problem.u_restrict.function_space().mesh().comm + comm = problem.u_restrict.function_space().mesh().comm PETSc.garbage_cleanup(comm) diff --git a/tests/firedrake/multigrid/test_snes_adapt.py b/tests/firedrake/multigrid/test_snes_adapt.py new file mode 100644 index 0000000000..2972f93c05 --- /dev/null +++ b/tests/firedrake/multigrid/test_snes_adapt.py @@ -0,0 +1,148 @@ +import pytest +from firedrake import * +from firedrake import dmhooks +from firedrake.mg.utils import get_level + + +def test_marking_callback_configures_refine_adaptor(): + def mark_cells(solver, current_solution): + M = FunctionSpace(current_solution.mesh(), "DG", 0) + return Function(M).assign(1) + + mesh = UnitSquareMesh(1, 1) + V = FunctionSpace(mesh, "CG", 1) + u = Function(V) + v = TestFunction(V) + problem = NonlinearVariationalProblem((u - 1.0)*v*dx, u) + solver = NonlinearVariationalSolver(problem, marking_callback=mark_cells) + + assert solver.parameters["adaptor_criterion"] == "refine" + assert solver._ctx._adapt_marking_callback is mark_cells + + +@pytest.mark.skipnetgen +def test_marking_callback_refine_hook_reconstructs_problem(): + from netgen.geom2d import SplineGeometry + seen = [] + + def mark_cells(solver, current_solution): + current_mesh = current_solution.function_space().mesh() + seen.append(current_mesh) + M = FunctionSpace(current_mesh, "DG", 0) + markers = Function(M) + markers.assign(1) + return markers + + geo = SplineGeometry() + geo.AddRectangle((0, 0), (1, 1), bc="boundary") + mesh = Mesh(geo.GenerateMesh(maxh=0.5)) + V = FunctionSpace(mesh, "CG", 1) + old_dim = V.dim() + u = Function(V) + v = TestFunction(V) + problem = NonlinearVariationalProblem((u - 1.0)*v*dx, u) + solver = NonlinearVariationalSolver(problem, marking_callback=mark_cells) + + dm = solver.snes.getDM() + with dmhooks.add_hooks(dm, solver, appctx=solver._ctx): + newdm = dm.refine() + solver._ctx = dmhooks.get_appctx(newdm) + + adapted = solver.get_adapted_solution() + adapted_mesh = adapted.function_space().mesh() + hierarchy, level = get_level(adapted_mesh) + + assert seen[0] is mesh + assert newdm == solver._ctx._problem.dm + assert adapted_mesh is not mesh + assert level == 1 + assert hierarchy[1] is adapted_mesh + assert adapted.function_space().dim() > old_dim + + +@pytest.mark.skipnetgen +def test_snes_adapt_sequence_with_adaptive_multigrid(): + from netgen.occ import WorkPlane, Axes, OCCGeometry, X, Z + + rect1 = WorkPlane(Axes((0, 0, 0), n=Z, h=X)).Rectangle(1, 2).Face() + rect2 = WorkPlane(Axes((0, 1, 0), n=Z, h=X)).Rectangle(2, 1).Face() + mesh = Mesh(OCCGeometry(rect1 + rect2, dim=2).GenerateMesh(maxh=0.8)) + amh = AdaptiveMeshHierarchy(mesh) + atm = AdaptiveTransferManager() + + V = FunctionSpace(mesh, "CG", 1) + old_dim = V.dim() + u = TrialFunction(V) + v = TestFunction(V) + uh = Function(V, name="solution") + a = inner(grad(u), grad(v))*dx + L = inner(Constant(1), v)*dx + bcs = DirichletBC(V, 0, "on_boundary") + problem = LinearVariationalProblem(a, L, uh, bcs=bcs) + + def estimate_error(current_solution): + current_mesh = current_solution.function_space().mesh() + Q = FunctionSpace(current_mesh, "DG", 0) + eta_sq = Function(Q) + p = TrialFunction(Q) + q = TestFunction(Q) + residual = Constant(1) + div(grad(current_solution)) + h = CellDiameter(current_mesh) + n = FacetNormal(current_mesh) + vol = CellVolume(current_mesh) + + a = inner(p, q / vol) * dx + L = (inner(residual**2, q * h**2) * dx + + inner(jump(grad(current_solution), n)**2, avg(q * h)) * dS) + sp = {"mat_type": "matfree", "ksp_type": "preonly", "pc_type": "jacobi"} + solve(a == L, eta_sq, solver_parameters=sp) + return Function(Q).interpolate(sqrt(eta_sq)) + + seen = [] + + def mark_cells(solver, current_solution): + current_mesh = current_solution.function_space().mesh() + seen.append(current_mesh) + eta = estimate_error(current_solution) + with eta.dat.vec_ro as eta_vec: + _, eta_max = eta_vec.max() + markers = Function(eta.function_space()) + markers.interpolate(conditional(gt(eta, 0.5 * eta_max), 1, 0)) + return markers + + refinements = 5 + params = { + "mat_type": "aij", + "snes_adapt_sequence": refinements, + "ksp_type": "cg", + "ksp_max_it": 10, + "ksp_monitor": None, + "pc_type": "mg", + "mg_levels": { + "ksp_type": "chebyshev", + "ksp_max_it": 1, + "pc_type": "jacobi", + }, + "mg_levels_0": { + "mat_type": "aij", + "ksp_type": "preonly", + "pc_type": "lu", + }, + } + solver = LinearVariationalSolver(problem, + solver_parameters=params, + marking_callback=mark_cells) + solver.set_transfer_manager(atm) + solver.solve() + + u_adapted = solver.get_adapted_solution() + adapted_mesh = u_adapted.function_space().mesh() + hierarchy, level = get_level(adapted_mesh) + + assert seen[0] == mesh + assert hierarchy is amh + assert level == refinements + assert len(amh) == refinements + 1 + assert adapted_mesh is not mesh + assert u_adapted is not uh + assert u_adapted.function_space().dim() > old_dim From 1c7406aa0424317c2e8c99b8b3abf56af11f1d9a Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Thu, 2 Jul 2026 09:59:50 +0100 Subject: [PATCH 2/2] Support -snes_adapt --- firedrake/dmhooks.py | 69 ++++++ firedrake/mg/ufl_utils.py | 228 ++++++++++++++----- firedrake/solving_utils.py | 5 +- firedrake/variational_solver.py | 39 +++- tests/firedrake/multigrid/test_snes_adapt.py | 148 ++++++++++++ 5 files changed, 430 insertions(+), 59 deletions(-) create mode 100644 tests/firedrake/multigrid/test_snes_adapt.py diff --git a/firedrake/dmhooks.py b/firedrake/dmhooks.py index a2b60a1f50..9b65792554 100644 --- a/firedrake/dmhooks.py +++ b/firedrake/dmhooks.py @@ -446,6 +446,69 @@ def coarsen(dm, comm): return cdm +def _validate_refinement_markers(markers, mesh): + if not isinstance(markers, (firedrake.Function, firedrake.Cofunction)): + raise TypeError( + f"marking callback must return a Function or Cofunction, not a {type(markers).__name__}" + ) + M = markers.function_space() + if M.mesh() is not mesh: + raise ValueError("marking callback must return markers on the current solution mesh") + if M.finat_element.space_dimension() != 1: + raise ValueError("marking callback must return a DG0 Function or Cofunction") + + +def _attach_refined_context(parent, ctx, coefficient_mapping, coarsener): + dms = [ctx._problem.u_restrict.function_space().dm] + for value in coefficient_mapping.values(): + if isinstance(value, (firedrake.Function, firedrake.Cofunction)): + dm = value.function_space().dm + if dm not in dms: + dms.append(dm) + + for dm in dms: + add_hook(parent, setup=partial(push_parent, dm, parent), teardown=partial(pop_parent, dm, parent), + call_setup=True) + add_hook(parent, setup=partial(push_ctx_coarsener, dm, coarsener), + teardown=partial(pop_ctx_coarsener, dm, coarsener), + call_setup=True) + add_hook(parent, setup=partial(push_appctx, dm, ctx), teardown=partial(pop_appctx, dm, ctx), + call_setup=True) + + +def _adaptively_refine(dm, comm): + from firedrake.mg.adaptive_hierarchy import AdaptiveMeshHierarchy + from firedrake.mg.ufl_utils import refine + from firedrake.mg.utils import get_level + + ctx = get_appctx(dm) + callback = getattr(ctx, "_adapt_marking_callback", None) + + current_solution = ctx._x + mesh = current_solution.function_space().mesh() + hierarchy, level = get_level(mesh) + if hierarchy is None: + hierarchy = AdaptiveMeshHierarchy(mesh) + level = 0 + if not isinstance(hierarchy, AdaptiveMeshHierarchy): + raise RuntimeError("Adaptive SNES refinement requires an AdaptiveMeshHierarchy") + + if level == len(hierarchy) - 1: + solver = None + markers = callback(solver, current_solution) + _validate_refinement_markers(markers, mesh) + hierarchy.add_mesh(mesh.refine_marked_elements(markers)) + + coefficient_mapping = {} + refined_ctx = refine(ctx, refine, coefficient_mapping=coefficient_mapping) + refined_ctx._adapt_marking_callback = callback + + parent = get_parent(dm) + coarsener = get_ctx_coarsener(dm) + _attach_refined_context(parent, refined_ctx, coefficient_mapping, coarsener) + return refined_ctx._problem.dm + + @PETSc.Log.EventDecorator() def refine(dm, comm): """Callback to refine a DM. @@ -453,11 +516,17 @@ def refine(dm, comm): :arg DM: The DM to refine. :arg comm: The communicator for the new DM (ignored) """ + ctx = get_appctx(dm) + if getattr(ctx, "_adapt_marking_callback", None) is not None: + return _adaptively_refine(dm, comm) + from firedrake.mg.utils import get_level V = get_function_space(dm) if V is None: raise RuntimeError("No functionspace found on DM") hierarchy, level = get_level(V.mesh()) + if hierarchy is None: + raise RuntimeError("No mesh hierarchy available") if level >= len(hierarchy) - 1: raise RuntimeError("Cannot refine finest DM") if hasattr(V, "_fine"): diff --git a/firedrake/mg/ufl_utils.py b/firedrake/mg/ufl_utils.py index b40bfbd7ac..3dbd8d9e7c 100644 --- a/firedrake/mg/ufl_utils.py +++ b/firedrake/mg/ufl_utils.py @@ -14,7 +14,7 @@ from . import utils -__all__ = ["coarsen"] +__all__ = ["coarsen", "refine"] class CoarseningError(Exception): @@ -64,6 +64,15 @@ def coarsen(expr, self, coefficient_mapping=None): return expr +@singledispatch +def refine(expr, self, coefficient_mapping=None): + # Most coarsen handlers will simply reconstruct the expression tree. And + # very few of them branch on coarsen vs refine to handle both directions. + # Delegating here lets those shared handlers do the right thing when called + # via `refine(...)`. + return coarsen(expr, self, coefficient_mapping=coefficient_mapping) + + @coarsen.register(ufl.Mesh) @coarsen.register(ufl.MeshSequence) def coarsen_mesh(mesh, self, coefficient_mapping=None): @@ -73,9 +82,18 @@ def coarsen_mesh(mesh, self, coefficient_mapping=None): return hierarchy[level - 1] +@refine.register(ufl.Mesh) +@refine.register(ufl.MeshSequence) +def refine_mesh(mesh, self, coefficient_mapping=None): + hierarchy, level = utils.get_level(mesh) + if hierarchy is None: + raise CoarseningError("No mesh hierarchy available") + return hierarchy[level + 1] + + @coarsen.register(ufl.BaseForm) @coarsen.register(ufl.classes.Expr) -def coarse_expr(expr, self, coefficient_mapping=None): +def coarsen_expr(expr, self, coefficient_mapping=None): if expr is None: return None mapper = CoarsenIntegrand(self, coefficient_mapping) @@ -126,7 +144,7 @@ def coarsen_formsum(form, self, coefficient_mapping=None): @coarsen.register(firedrake.DirichletBC) def coarsen_bc(bc, self, coefficient_mapping=None): V = self(bc.function_space(), self, coefficient_mapping=coefficient_mapping) - val = self(bc.function_arg, self, coefficient_mapping=coefficient_mapping) + val = self(bc._original_arg, self, coefficient_mapping=coefficient_mapping) subdomain = bc.sub_domain return type(bc)(V, val, subdomain) @@ -149,18 +167,46 @@ def coarsen_equation_bc(ebc, self, coefficient_mapping=None): @coarsen.register(firedrake.functionspaceimpl.WithGeometryBase) def coarsen_function_space(V, self, coefficient_mapping=None): - if hasattr(V, "_coarse"): + # Handle MixedFunctionSpace : V.reconstruct requires MeshSequence. + mesh = V.mesh() if V.index is None else V.parent.mesh() + new_mesh = self(mesh, self) + if hasattr(V, "_coarse") and V._coarse.mesh() == new_mesh: return V._coarse - - V_fine = V - # Handle MixedFunctionSpace : V_fine.reconstruct requires MeshSequence. - fine_mesh = V_fine.mesh() if V_fine.index is None else V_fine.parent.mesh() - mesh_coarse = self(fine_mesh, self) - name = f"coarse_{V.name}" if V.name else None - V_coarse = V_fine.reconstruct(mesh=mesh_coarse, name=name) - V_coarse._fine = V_fine - V_fine._coarse = V_coarse - return V_coarse + # Get the parent name + V_parent = V + while hasattr(V_parent, "_fine") and V_parent._fine: + V_parent = V_parent._fine + name = V_parent.name + if name is not None: + mh, level = utils.get_level(new_mesh) + name = f"{name}_level_{level}" + # Reconstruct the space + V_new = V.reconstruct(mesh=new_mesh, name=name) + V_new._fine = V + V._coarse = V_new + return V_new + + +@refine.register(firedrake.functionspaceimpl.WithGeometryBase) +def refine_function_space(V, self, coefficient_mapping=None): + # Handle MixedFunctionSpace : V.reconstruct requires MeshSequence. + mesh = V.mesh() if V.index is None else V.parent.mesh() + new_mesh = self(mesh, self) + if hasattr(V, "_fine") and V._fine.mesh() == new_mesh: + return V._fine + # Get the parent name + V_parent = V + while hasattr(V_parent, "_coarse") and V_parent._coarse: + V_parent = V_parent._coarse + name = V_parent.name + if name is not None: + mh, level = utils.get_level(new_mesh) + name = f"{name}_level_{level}" + # Reconstruct the space + V_new = V.reconstruct(mesh=new_mesh, name=name) + V_new._coarse = V + V._fine = V_new + return V_new @coarsen.register(firedrake.Cofunction) @@ -170,10 +216,19 @@ def coarsen_function(expr, self, coefficient_mapping=None): coefficient_mapping = {} new = coefficient_mapping.get(expr) if new is None: - Vf = expr.function_space() - Vc = self(Vf, self) - new = firedrake.Function(Vc, name=f"coarse_{expr.name()}") - manager = get_transfer_manager(Vf.dm) + V = expr.function_space() + Vnew = self(V, self) + name = expr.name() + if name is not None: + try: + name, prev_level = name.split("_level_") + except ValueError: + prev_level = 0 + level = int(prev_level) - 1 + name = f"{name}_level_{level}" + + new = firedrake.Function(Vnew, name=name) + manager = get_transfer_manager(V.dm) if is_dual(expr): manager.restrict(expr, new) else: @@ -182,10 +237,40 @@ def coarsen_function(expr, self, coefficient_mapping=None): return new +@refine.register(firedrake.Cofunction) +@refine.register(firedrake.Function) +def refine_function(expr, self, coefficient_mapping=None): + if coefficient_mapping is None: + coefficient_mapping = {} + new = coefficient_mapping.get(expr) + if new is None: + V = expr.function_space() + Vnew = self(V, self) + name = expr.name() + if name is not None: + try: + name, prev_level = name.split("_level_") + except ValueError: + prev_level = 0 + level = int(prev_level) + 1 + name = f"{name}_level_{level}" + + new = firedrake.Function(Vnew, name=name) + new.interpolate(expr) + coefficient_mapping[expr] = new + return new + + @coarsen.register(firedrake.NonlinearVariationalProblem) def coarsen_nlvp(problem, self, coefficient_mapping=None): - if hasattr(problem, "_coarse"): - return problem._coarse + # Have we done this already? + mh, _ = utils.get_level(problem.u.function_space().mesh()) + if self == coarsen and hasattr(problem, "_coarse"): + if mh is utils.get_level(problem._coarse.u.function_space().mesh())[0]: + return problem._coarse + elif self == refine and hasattr(problem, "_fine"): + if mh is utils.get_level(problem._fine.u.function_space().mesh())[0]: + return problem._fine def inject_on_restrict(fine, restriction, rscale, injection, coarse): manager = get_transfer_manager(fine) @@ -226,10 +311,40 @@ def inject_on_restrict(fine, restriction, rscale, injection, coarse): Jp = self(problem.Jp, self, coefficient_mapping=coefficient_mapping) u = coefficient_mapping[problem.u_restrict] - fine = problem + orig = problem problem = firedrake.NonlinearVariationalProblem(F, u, bcs=bcs, J=J, Jp=Jp, is_linear=problem.is_linear, form_compiler_parameters=problem.form_compiler_parameters) - fine._coarse = problem + if self == coarsen: + orig._coarse = problem + elif self == refine: + orig._fine = problem + return problem + + +@coarsen.register(firedrake.LinearEigenproblem) +def coarsen_eigenproblem(problem, self, coefficient_mapping=None): + # Have we done this already? + mh, _ = utils.get_level(problem.output_space().mesh()) + if self == coarsen and hasattr(problem, "_coarse"): + if mh is utils.get_level(problem._coarse.output_space.mesh())[0]: + return problem._coarse + elif self == refine and hasattr(problem, "_fine"): + if mh is utils.get_level(problem._fine.output_space.mesh())[0]: + return problem._fine + + if coefficient_mapping is None: + coefficient_mapping = {} + bcs = [self(bc, self, coefficient_mapping=coefficient_mapping) + for bc in problem._original_bcs] + A = self(problem._original_A, self, coefficient_mapping=coefficient_mapping) + M = self(problem._original_M, self, coefficient_mapping=coefficient_mapping) + orig = problem + problem = firedrake.LinearEigenproblem(A, M, bcs=bcs, + bc_shift=orig.bc_shift, restrict=orig.restrict) + if self == coarsen: + orig._coarse = problem + elif self == refine: + orig._fine = problem return problem @@ -264,10 +379,17 @@ def coarsen_snescontext(context, self, coefficient_mapping=None): if coefficient_mapping is None: coefficient_mapping = {} + if self == refine: + new_attr = "_fine" + old_attr = "_coarse" + else: + new_attr = "_coarse" + old_attr = "_fine" + # Have we already done this? - coarse = context._coarse - if coarse is not None: - return coarse + new_context = getattr(context, new_attr) + if new_context is not None: + return new_context problem = self(context._problem, self, coefficient_mapping=coefficient_mapping) appctx = context.appctx @@ -284,8 +406,9 @@ def coarsen_snescontext(context, self, coefficient_mapping=None): # Get options prefix for current level parent_context = context - while parent_context._fine: - parent_context = parent_context._fine + while getattr(parent_context, old_attr, None): + parent_context = getattr(parent_context, old_attr, None) + parent_prefix = parent_context.options_prefix opts = PETSc.Options(parent_prefix) if opts.getString("snes_type", "") == "fas": @@ -313,43 +436,42 @@ def coarsen_snescontext(context, self, coefficient_mapping=None): pmat_type = pmat_type or mat_type sub_pmat_type = sub_pmat_type or sub_mat_type - coarse = context.reconstruct(problem=problem, - mat_type=mat_type, - pmat_type=pmat_type, - sub_mat_type=sub_mat_type, - sub_pmat_type=sub_pmat_type, - appctx=new_appctx, - options_prefix=options_prefix, - ) - coarse._coefficient_mapping = coefficient_mapping - coarse._fine = context - context._coarse = coarse + new_context = context.reconstruct(problem=problem, + mat_type=mat_type, + pmat_type=pmat_type, + sub_mat_type=sub_mat_type, + sub_pmat_type=sub_pmat_type, + appctx=new_appctx, + options_prefix=options_prefix, + ) + new_context._coefficient_mapping = coefficient_mapping + setattr(new_context, old_attr, context) + setattr(context, new_attr, new_context) solutiondm = context._problem.u_restrict.function_space().dm parentdm = get_parent(solutiondm) - # Now that we have the coarse snescontext, push it to the coarsened DMs - # Otherwise they won't have the right transfer manager when they are - # coarsened in turn + # Now that we have the reconstructed snescontext, push it to the reconstructed DMs. + # Otherwise they will not have the right transfer manager when they are reconstructed in turn. for val in coefficient_mapping.values(): if isinstance(val, (firedrake.Function, firedrake.Cofunction)): V = val.function_space() - coarseneddm = V.dm + newdm = V.dm # Now attach the hook to the parent DM - if get_appctx(coarseneddm) is None: - push_appctx(coarseneddm, coarse) + if get_appctx(newdm) is None: + push_appctx(newdm, new_context) if parentdm.getAttr("__setup_hooks__"): - add_hook(parentdm, teardown=partial(pop_appctx, coarseneddm, coarse)) + add_hook(parentdm, teardown=partial(pop_appctx, newdm, new_context)) - ises = problem.J.arguments()[0].function_space()._ises - coarse._nullspace = self(context._nullspace, self, coefficient_mapping=coefficient_mapping) - coarse.set_nullspace(coarse._nullspace, ises, transpose=False, near=False) - coarse._nullspace_T = self(context._nullspace_T, self, coefficient_mapping=coefficient_mapping) - coarse.set_nullspace(coarse._nullspace_T, ises, transpose=True, near=False) - coarse._near_nullspace = self(context._near_nullspace, self, coefficient_mapping=coefficient_mapping) - coarse.set_nullspace(coarse._near_nullspace, ises, transpose=False, near=True) + ises = new_context._x.function_space()._ises + new_context._nullspace = self(context._nullspace, self, coefficient_mapping=coefficient_mapping) + new_context.set_nullspace(new_context._nullspace, ises, transpose=False, near=False) + new_context._nullspace_T = self(context._nullspace_T, self, coefficient_mapping=coefficient_mapping) + new_context.set_nullspace(new_context._nullspace_T, ises, transpose=True, near=False) + new_context._near_nullspace = self(context._near_nullspace, self, coefficient_mapping=coefficient_mapping) + new_context.set_nullspace(new_context._near_nullspace, ises, transpose=False, near=True) - return coarse + return new_context @coarsen.register(firedrake.slate.AssembledVector) diff --git a/firedrake/solving_utils.py b/firedrake/solving_utils.py index db18865712..dab0aa05c9 100644 --- a/firedrake/solving_utils.py +++ b/firedrake/solving_utils.py @@ -189,7 +189,8 @@ def __init__(self, problem, post_jacobian_callback=None, post_function_callback=None, options_prefix: str | None = None, transfer_manager=None, - pre_apply_bcs: bool = True): + pre_apply_bcs: bool = True, + marking_callback=None): from firedrake.assemble import get_assembler if pmat_type is None: @@ -276,6 +277,7 @@ def __init__(self, problem, self._near_nullspace = None self._coefficient_mapping = None self._transfer_manager = transfer_manager + self._marking_callback = marking_callback def reconstruct(self, problem=None, mat_type=None, pmat_type=None, **kwargs): """Reconstruct this _SNESContext instance with new arguments.""" @@ -290,6 +292,7 @@ def reconstruct(self, problem=None, mat_type=None, pmat_type=None, **kwargs): "options_prefix": self.options_prefix, "transfer_manager": self.transfer_manager, "pre_apply_bcs": self.pre_apply_bcs, + "marking_callback": self._marking_callback, } for k, v in default_options.items(): if kwargs.get(k) is None: diff --git a/firedrake/variational_solver.py b/firedrake/variational_solver.py index 4031bf3c5c..ebd62c2196 100644 --- a/firedrake/variational_solver.py +++ b/firedrake/variational_solver.py @@ -194,7 +194,8 @@ def __init__(self, problem, *, solver_parameters=None, post_jacobian_callback=None, pre_function_callback=None, post_function_callback=None, - pre_apply_bcs=True): + pre_apply_bcs=True, + marking_callback=None): r""" :arg problem: A :class:`NonlinearVariationalProblem` to solve. :kwarg nullspace: an optional :class:`.VectorSpaceBasis` (or @@ -223,9 +224,14 @@ def __init__(self, problem, *, solver_parameters=None, before residual assembly. :kwarg post_function_callback: As above, but called immediately after residual assembly. - :kwarg pre_apply_bcs: If `True`, the bcs are applied before the solve. + :kwarg pre_apply_bcs: If True, the bcs are applied before the solve. Otherwise, the problem is linearised around the initial guess before imposing bcs, and the bcs are appended to the nonlinear system. + :kwarg marking_callback: An optional callable of the form + ``callback(solver, u)`` for PETSc-driven adaptive refinement. + The callback receives this solver and the current Firedrake + solution, and must return a DG0 :class:`.Function` or + :class:`.Cofunction` with positive values on cells to refine. Example usage of the ``solver_parameters`` option: to set the nonlinear solver type to just use a linear solver, use @@ -293,6 +299,7 @@ def update_diffusivity(current_solution): pre_function_callback=pre_function_callback, post_jacobian_callback=post_jacobian_callback, post_function_callback=post_function_callback, + marking_callback=marking_callback, options_prefix=self.options_prefix, pre_apply_bcs=pre_apply_bcs) @@ -328,6 +335,23 @@ def update_diffusivity(current_solution): self._transfer_operators = () self._setup = False + def set_marking_callback(self, callback): + r"""Set the callback used by PETSc-driven adaptive refinement. + + The callback is called as ``callback(solver, u)`` when PETSc asks the + solution DM to refine. It must return a DG0 :class:`.Function` or + :class:`.Cofunction` on the current solution mesh, with positive values + on cells to refine. + """ + if not callable(callback): + raise TypeError(f"marking callback must be callable, not a {type(callback).__name__}") + self.parameters.setdefault("adaptor_criterion", "refine") + self._ctx._marking_callback = callback + + def get_adapted_solution(self): + r"""Return the current solution, including after PETSc adapts the DM.""" + return self._ctx._problem.u + def set_transfer_manager(self, manager): r"""Set the object that manages transfer between grid levels. Typically a :class:`~.TransferManager` object. @@ -357,7 +381,7 @@ def solve(self, bounds=None): self._ctx.set_jacobian(self.snes) # Make sure appcontext is attached to every DM from every coefficient and DirichletBC before we solve. - problem = self._problem + problem = self._ctx._problem forms = (problem.F, problem.J, problem.Jp) coefficients = utils.unique(chain.from_iterable(form.coefficients() for form in forms if form is not None)) solution_dm = self.snes.getDM() @@ -396,14 +420,19 @@ def solve(self, bounds=None): self._transfer_operators): stack.enter_context(ctx) self.snes.solve(None, work) - work.copy(u) + # The appctx might have been refined + self._ctx = dmhooks.get_appctx(self.snes.getDM()) + problem = self._ctx._problem + solution = self.snes.getSolution() + with problem.u_restrict.dat.vec as u: + solution.copy(u) self._setup = True if problem.restrict: problem.u.assign(problem.u_restrict) solving_utils.check_snes_convergence(self.snes) # Grab the comm associated with the `_problem` and call PETSc's garbage cleanup routine - comm = self._problem.u_restrict.function_space().mesh().comm + comm = problem.u_restrict.function_space().mesh().comm PETSc.garbage_cleanup(comm) diff --git a/tests/firedrake/multigrid/test_snes_adapt.py b/tests/firedrake/multigrid/test_snes_adapt.py new file mode 100644 index 0000000000..2972f93c05 --- /dev/null +++ b/tests/firedrake/multigrid/test_snes_adapt.py @@ -0,0 +1,148 @@ +import pytest +from firedrake import * +from firedrake import dmhooks +from firedrake.mg.utils import get_level + + +def test_marking_callback_configures_refine_adaptor(): + def mark_cells(solver, current_solution): + M = FunctionSpace(current_solution.mesh(), "DG", 0) + return Function(M).assign(1) + + mesh = UnitSquareMesh(1, 1) + V = FunctionSpace(mesh, "CG", 1) + u = Function(V) + v = TestFunction(V) + problem = NonlinearVariationalProblem((u - 1.0)*v*dx, u) + solver = NonlinearVariationalSolver(problem, marking_callback=mark_cells) + + assert solver.parameters["adaptor_criterion"] == "refine" + assert solver._ctx._adapt_marking_callback is mark_cells + + +@pytest.mark.skipnetgen +def test_marking_callback_refine_hook_reconstructs_problem(): + from netgen.geom2d import SplineGeometry + seen = [] + + def mark_cells(solver, current_solution): + current_mesh = current_solution.function_space().mesh() + seen.append(current_mesh) + M = FunctionSpace(current_mesh, "DG", 0) + markers = Function(M) + markers.assign(1) + return markers + + geo = SplineGeometry() + geo.AddRectangle((0, 0), (1, 1), bc="boundary") + mesh = Mesh(geo.GenerateMesh(maxh=0.5)) + V = FunctionSpace(mesh, "CG", 1) + old_dim = V.dim() + u = Function(V) + v = TestFunction(V) + problem = NonlinearVariationalProblem((u - 1.0)*v*dx, u) + solver = NonlinearVariationalSolver(problem, marking_callback=mark_cells) + + dm = solver.snes.getDM() + with dmhooks.add_hooks(dm, solver, appctx=solver._ctx): + newdm = dm.refine() + solver._ctx = dmhooks.get_appctx(newdm) + + adapted = solver.get_adapted_solution() + adapted_mesh = adapted.function_space().mesh() + hierarchy, level = get_level(adapted_mesh) + + assert seen[0] is mesh + assert newdm == solver._ctx._problem.dm + assert adapted_mesh is not mesh + assert level == 1 + assert hierarchy[1] is adapted_mesh + assert adapted.function_space().dim() > old_dim + + +@pytest.mark.skipnetgen +def test_snes_adapt_sequence_with_adaptive_multigrid(): + from netgen.occ import WorkPlane, Axes, OCCGeometry, X, Z + + rect1 = WorkPlane(Axes((0, 0, 0), n=Z, h=X)).Rectangle(1, 2).Face() + rect2 = WorkPlane(Axes((0, 1, 0), n=Z, h=X)).Rectangle(2, 1).Face() + mesh = Mesh(OCCGeometry(rect1 + rect2, dim=2).GenerateMesh(maxh=0.8)) + amh = AdaptiveMeshHierarchy(mesh) + atm = AdaptiveTransferManager() + + V = FunctionSpace(mesh, "CG", 1) + old_dim = V.dim() + u = TrialFunction(V) + v = TestFunction(V) + uh = Function(V, name="solution") + a = inner(grad(u), grad(v))*dx + L = inner(Constant(1), v)*dx + bcs = DirichletBC(V, 0, "on_boundary") + problem = LinearVariationalProblem(a, L, uh, bcs=bcs) + + def estimate_error(current_solution): + current_mesh = current_solution.function_space().mesh() + Q = FunctionSpace(current_mesh, "DG", 0) + eta_sq = Function(Q) + p = TrialFunction(Q) + q = TestFunction(Q) + residual = Constant(1) + div(grad(current_solution)) + h = CellDiameter(current_mesh) + n = FacetNormal(current_mesh) + vol = CellVolume(current_mesh) + + a = inner(p, q / vol) * dx + L = (inner(residual**2, q * h**2) * dx + + inner(jump(grad(current_solution), n)**2, avg(q * h)) * dS) + sp = {"mat_type": "matfree", "ksp_type": "preonly", "pc_type": "jacobi"} + solve(a == L, eta_sq, solver_parameters=sp) + return Function(Q).interpolate(sqrt(eta_sq)) + + seen = [] + + def mark_cells(solver, current_solution): + current_mesh = current_solution.function_space().mesh() + seen.append(current_mesh) + eta = estimate_error(current_solution) + with eta.dat.vec_ro as eta_vec: + _, eta_max = eta_vec.max() + markers = Function(eta.function_space()) + markers.interpolate(conditional(gt(eta, 0.5 * eta_max), 1, 0)) + return markers + + refinements = 5 + params = { + "mat_type": "aij", + "snes_adapt_sequence": refinements, + "ksp_type": "cg", + "ksp_max_it": 10, + "ksp_monitor": None, + "pc_type": "mg", + "mg_levels": { + "ksp_type": "chebyshev", + "ksp_max_it": 1, + "pc_type": "jacobi", + }, + "mg_levels_0": { + "mat_type": "aij", + "ksp_type": "preonly", + "pc_type": "lu", + }, + } + solver = LinearVariationalSolver(problem, + solver_parameters=params, + marking_callback=mark_cells) + solver.set_transfer_manager(atm) + solver.solve() + + u_adapted = solver.get_adapted_solution() + adapted_mesh = u_adapted.function_space().mesh() + hierarchy, level = get_level(adapted_mesh) + + assert seen[0] == mesh + assert hierarchy is amh + assert level == refinements + assert len(amh) == refinements + 1 + assert adapted_mesh is not mesh + assert u_adapted is not uh + assert u_adapted.function_space().dim() > old_dim