From a34a1e9f7d23abf07a6be33c7313861f77c28bd6 Mon Sep 17 00:00:00 2001
From: Timo Koch <timokoch@math.uio.no>
Date: Sun, 19 Mar 2023 19:40:13 +0100
Subject: [PATCH] [istlsolvers] Improve parameter handling and allow setting a
 matrix operator for reuse

* The solver accepts now either a parameter group or a Dune::ParameterTree. In the case of the latter,
the parameter tree is forwarded without modificaiton to the dune-istl solvers.

* Introduce a new function to set the maximum number of itertions

* Introduce two new interfaces "setMatrix(A)" and "solve(x, b)". The first one constructs a solver
based on the matrix A. If the new solve interface without matrix is called, this pre-constructed
solver is used in the solver. solve(A, x, b) still exists and construct a solver based on A and
ignores any stored solver. This allows to contruct a solver and then reuse it for many right hand sides.
This can be practical, for example, in parallel for linear problems
where constructing the solver involves the modificaiton
of the matrix and communication and is therefore an expensive step.

* Return solver result that is convertible to bool but contains Dune::InverseOperatorResults

* Simplify parallel code branches

* Make solver copyable by using shared_ptr for parallel helper
---
 dumux/linear/istlsolvers.hh      | 511 +++++++++++++++++++++++--------
 test/linear/test_linearsolver.cc |   7 +-
 2 files changed, 384 insertions(+), 134 deletions(-)

diff --git a/dumux/linear/istlsolvers.hh b/dumux/linear/istlsolvers.hh
index 6d29313ed9..4dba72b3bb 100644
--- a/dumux/linear/istlsolvers.hh
+++ b/dumux/linear/istlsolvers.hh
@@ -25,8 +25,10 @@
 #define DUMUX_LINEAR_ISTL_SOLVERS_HH
 
 #include <memory>
+#include <variant>
 
 #include <dune/common/exceptions.hh>
+#include <dune/common/shared_ptr.hh>
 #include <dune/common/parallel/indexset.hh>
 #include <dune/common/parallel/mpicommunication.hh>
 #include <dune/grid/common/capabilities.hh>
@@ -93,10 +95,67 @@ class IstlDefaultPreconditionerFactory
 
 using IstlAmgPreconditionerFactory = Dune::AMGCreator;
 
+template<class M, bool convert = false>
+struct MatrixForSolver { using type = M; };
+
+template<class M>
+struct MatrixForSolver<M, true>
+{ using type = std::decay_t<decltype(MatrixConverter<M>::multiTypeToBCRSMatrix(std::declval<M>()))>; };
+
+template<class V, bool convert = false>
+struct VectorForSolver { using type = V; };
+
+template<class V>
+struct VectorForSolver<V, true>
+{ using type = std::decay_t<decltype(VectorConverter<V>::multiTypeToBlockVector(std::declval<V>()))>; };
+
+template<class LSTraits, class LATraits, bool convert, bool parallel = LSTraits::canCommunicate>
+struct MatrixOperator;
+
+template<class LSTraits, class LATraits, bool convert>
+struct MatrixOperator<LSTraits, LATraits, convert, true>
+{
+    using M = typename MatrixForSolver<typename LATraits::Matrix, convert>::type;
+    using V = typename VectorForSolver<typename LATraits::Vector, convert>::type;
+#if HAVE_MPI
+    using type = std::variant<
+        std::shared_ptr<typename LSTraits::template Sequential<M, V>::LinearOperator>,
+        std::shared_ptr<typename LSTraits::template ParallelOverlapping<M, V>::LinearOperator>,
+        std::shared_ptr<typename LSTraits::template ParallelNonoverlapping<M, V>::LinearOperator>
+    >;
+#else
+    using type = std::variant<
+        std::shared_ptr<typename LSTraits::template Sequential<M, V>::LinearOperator>
+    >;
+#endif
+};
+
+template<class LSTraits, class LATraits, bool convert>
+struct MatrixOperator<LSTraits, LATraits, convert, false>
+{
+    using M = typename MatrixForSolver<typename LATraits::Matrix, convert>::type;
+    using V = typename VectorForSolver<typename LATraits::Vector, convert>::type;
+    using type = std::variant<
+        std::shared_ptr<typename LSTraits::template Sequential<M, V>::LinearOperator>
+    >;
+};
+
 } // end namespace Dumux::Detail::IstlSolvers
 
 namespace Dumux::Detail {
 
+struct IstlSolverResult : public Dune::InverseOperatorResult
+{
+    IstlSolverResult() = default;
+    IstlSolverResult(const IstlSolverResult&) = default;
+    IstlSolverResult(IstlSolverResult&&) = default;
+
+    IstlSolverResult(const Dune::InverseOperatorResult& o) : InverseOperatorResult(o) {}
+    IstlSolverResult(Dune::InverseOperatorResult&& o) : InverseOperatorResult(std::move(o)) {}
+
+    operator bool() const { return this->converged; }
+};
+
 /*!
  * \ingroup Linear
  * \brief Standard dune-istl iterative linear solvers
@@ -110,25 +169,36 @@ class IstlIterativeLinearSolver
     using XVector = typename LinearAlgebraTraits::Vector;
     using BVector = typename LinearAlgebraTraits::Vector;
     using Scalar = typename InverseOperator::real_type;
+
     using ScalarProduct = Dune::ScalarProduct<typename InverseOperator::domain_type>;
+
     static constexpr bool convertMultiTypeVectorAndMatrix
         = convertMultiTypeLATypes && isMultiTypeBlockVector<XVector>::value;
+    using MatrixForSolver = typename Detail::IstlSolvers::MatrixForSolver<Matrix, convertMultiTypeVectorAndMatrix>::type;
+    using BVectorForSolver = typename Detail::IstlSolvers::VectorForSolver<BVector, convertMultiTypeVectorAndMatrix>::type;
+    using XVectorForSolver = typename Detail::IstlSolvers::VectorForSolver<XVector, convertMultiTypeVectorAndMatrix>::type;
+    // a variant type that can hold sequential, overlapping, and non-overlapping operators
+    using MatrixOperatorHolder = typename Detail::IstlSolvers::MatrixOperator<
+        LinearSolverTraits, LinearAlgebraTraits, convertMultiTypeVectorAndMatrix
+    >::type;
+
 #if HAVE_MPI
     using Comm = Dune::OwnerOverlapCopyCommunication<Dune::bigunsignedint<96>, int>;
     using ParallelHelper = ParallelISTLHelper<LinearSolverTraits>;
 #endif
 
+    using ParameterInitializer = std::variant<std::string, Dune::ParameterTree>;
 public:
 
     /*!
      * \brief Constructor for sequential solvers
      */
-    IstlIterativeLinearSolver(const std::string& paramGroup = "")
+    IstlIterativeLinearSolver(const ParameterInitializer& params = "")
     {
         if (Dune::MPIHelper::getCommunication().size() > 1)
             DUNE_THROW(Dune::InvalidStateException, "Using sequential constructor for parallel run. Use signature with gridView and dofMapper!");
 
-        initializeParameters_(paramGroup);
+        initializeParameters_(params);
         solverCategory_ = Dune::SolverCategory::sequential;
         scalarProduct_ = std::make_shared<ScalarProduct>();
     }
@@ -139,9 +209,9 @@ public:
     template <class GridView, class DofMapper>
     IstlIterativeLinearSolver(const GridView& gridView,
                               const DofMapper& dofMapper,
-                              const std::string& paramGroup = "")
+                              const ParameterInitializer& params = "")
     {
-        initializeParameters_(paramGroup);
+        initializeParameters_(params);
 #if HAVE_MPI
         solverCategory_ = Detail::solverCategory<LinearSolverTraits>(gridView);
         if constexpr (LinearSolverTraits::canCommunicate)
@@ -149,7 +219,7 @@ public:
 
             if (solverCategory_ != Dune::SolverCategory::sequential)
             {
-                parallelHelper_ = std::make_unique<ParallelISTLHelper<LinearSolverTraits>>(gridView, dofMapper);
+                parallelHelper_ = std::make_shared<ParallelISTLHelper<LinearSolverTraits>>(gridView, dofMapper);
                 communication_ = std::make_shared<Comm>(gridView.comm(), solverCategory_);
                 scalarProduct_ = Dune::createScalarProduct<XVector>(*communication_, solverCategory_);
                 parallelHelper_->createParallelIndexSet(*communication_);
@@ -174,9 +244,9 @@ public:
                               std::shared_ptr<ScalarProduct> scalarProduct,
                               const GridView& gridView,
                               const DofMapper& dofMapper,
-                              const std::string& paramGroup = "")
+                              const ParameterInitializer& params = "")
     {
-        initializeParameters_(paramGroup);
+        initializeParameters_(params);
         solverCategory_ = Detail::solverCategory(gridView);
         scalarProduct_ = scalarProduct;
         communication_ = communication;
@@ -184,22 +254,49 @@ public:
         {
             if (solverCategory_ != Dune::SolverCategory::sequential)
             {
-                parallelHelper_ = std::make_unique<ParallelISTLHelper<LinearSolverTraits>>(gridView, dofMapper);
+                parallelHelper_ = std::make_shared<ParallelISTLHelper<LinearSolverTraits>>(gridView, dofMapper);
                 parallelHelper_->createParallelIndexSet(communication);
             }
         }
     }
 #endif
 
-    bool solve(Matrix& A, XVector& x, BVector& b)
+    /*!
+     * \brief Solve the linear system Ax = b
+     */
+    IstlSolverResult solve(Matrix& A, XVector& x, BVector& b)
+    { return solveSequentialOrParallel_(A, x, b); }
+
+    /*!
+     * \brief Set the matrix A of the linear system Ax = b for reuse
+     */
+    void setMatrix(std::shared_ptr<Matrix> A)
     {
-#if HAVE_MPI
-        return solveSequentialOrParallel_(A, x, b);
-#else
-        return solveSequential_(A, x, b);
-#endif
+        linearOperator_ = makeParallelOrSequentialLinearOperator_(std::move(A));
+        solver_ = constructPreconditionedSolver_(linearOperator_);
     }
 
+    /*!
+     * \brief Set the matrix A of the linear system Ax = b for reuse
+     * \note The client has to take care of the lifetime management of A
+     */
+    void setMatrix(Matrix& A)
+    { setMatrix(Dune::stackobject_to_shared_ptr(A)); }
+
+    /*!
+     * \brief Solve the linear system Ax = b where A has been set with \ref setMatrix
+     */
+    IstlSolverResult solve(XVector& x, BVector& b) const
+    {
+        if (!solver_)
+            DUNE_THROW(Dune::InvalidStateException, "Called solve(x, b) but no linear operator has been set");
+
+        return solveSequentialOrParallel_(x, b, *solver_);
+    }
+
+    /*!
+     * \brief Compute the 2-norm of vector x
+     */
     Scalar norm(const XVector& x) const
     {
 #if HAVE_MPI
@@ -225,133 +322,244 @@ public:
             return scalarProduct_->norm(x);
     }
 
-    const Dune::InverseOperatorResult& result() const
-    {
-        return result_;
-    }
-
+    /*!
+     * \brief The name of the linear solver
+     */
     const std::string& name() const
     {
         return name_;
     }
 
+    /*!
+     * \brief Set the residual reduction tolerance
+     */
     void setResidualReduction(double residReduction)
-    { params_["reduction"] = std::to_string(residReduction); }
+    {
+        params_["reduction"] = std::to_string(residReduction);
+
+        // reconstruct the solver with new parameters
+        if (solver_)
+            solver_ = constructPreconditionedSolver_(linearOperator_);
+    }
+
+    /*!
+     * \brief Set the maximum number of linear solver iterations
+     */
+    void setMaxIter(std::size_t maxIter)
+    {
+        params_["maxit"] = std::to_string(maxIter);
+
+        // reconstruct the solver with new parameters
+        if (solver_)
+            solver_ = constructPreconditionedSolver_(linearOperator_);
+    }
+
+    /*!
+     * \brief Set the linear solver parameters
+     * \param params Either a std::string giving a parameter group (parameters are read from input file) or a Dune::ParameterTree
+     * \note In case of a Dune::ParameterTree, the parameters are passed trait to the linear solver and preconditioner
+     */
+    void setParams(const ParameterInitializer& params)
+    {
+        initializeParameters_(params);
+
+        // reconstruct the solver with new parameters
+        if (solver_)
+            solver_ = constructPreconditionedSolver_(linearOperator_);
+    }
 
 private:
 
-    void initializeParameters_(const std::string& paramGroup)
+    void initializeParameters_(const ParameterInitializer& params)
     {
-        params_ = Dumux::LinearSolverParameters<LinearSolverTraits>::createParameterTree(paramGroup);
+        if (std::holds_alternative<std::string>(params))
+            params_ = Dumux::LinearSolverParameters<LinearSolverTraits>::createParameterTree(std::get<std::string>(params));
+        else
+            params_ = std::get<Dune::ParameterTree>(params);
     }
 
-    bool solveSequential_(Matrix& A, XVector& x, BVector& b)
+    MatrixOperatorHolder makeSequentialLinearOperator_(std::shared_ptr<Matrix> A)
     {
+        using SequentialTraits = typename LinearSolverTraits::template Sequential<MatrixForSolver, XVectorForSolver>;
         if constexpr (convertMultiTypeVectorAndMatrix)
         {
             // create the BCRS matrix the IterativeSolver backend can handle
-            auto M = MatrixConverter<Matrix>::multiTypeToBCRSMatrix(A);
+            auto M = std::make_shared<MatrixForSolver>(MatrixConverter<Matrix>::multiTypeToBCRSMatrix(*A));
+            return std::make_shared<typename SequentialTraits::LinearOperator>(M);
+        }
+        else
+        {
+            return std::make_shared<typename SequentialTraits::LinearOperator>(A);
+        }
+    }
 
-            // get the new matrix sizes
-            const std::size_t numRows = M.N();
-            assert(numRows == M.M());
+    template<class ParallelTraits>
+    MatrixOperatorHolder makeParallelLinearOperator_(std::shared_ptr<Matrix> A, ParallelTraits = {})
+    {
+#if HAVE_MPI
+        // make matrix consistent
+        prepareMatrixParallel<LinearSolverTraits, ParallelTraits>(*A, *parallelHelper_);
+        return std::make_shared<typename ParallelTraits::LinearOperator>(std::move(A), *communication_);
+#else
+        DUNE_THROW(Dune::InvalidStateException, "Calling makeParallelLinearOperator for sequential run");
+#endif
+    }
 
+    MatrixOperatorHolder makeParallelOrSequentialLinearOperator_(std::shared_ptr<Matrix> A)
+    {
+        return executeSequentialOrParallel_(
+            [&]{ return makeSequentialLinearOperator_(std::move(A)); },
+            [&](auto traits){ return makeParallelLinearOperator_(std::move(A), traits); }
+        );
+    }
+
+    MatrixOperatorHolder makeSequentialLinearOperator_(Matrix& A)
+    { return makeSequentialLinearOperator_(Dune::stackobject_to_shared_ptr<Matrix>(A)); }
+
+    MatrixOperatorHolder makeParallelOrSequentialLinearOperator_(Matrix& A)
+    { return makeParallelOrSequentialLinearOperator_(Dune::stackobject_to_shared_ptr<Matrix>(A)); }
+
+    template<class ParallelTraits>
+    MatrixOperatorHolder makeParallelLinearOperator_(Matrix& A, ParallelTraits = {})
+    { return makeParallelLinearOperator_<ParallelTraits>(Dune::stackobject_to_shared_ptr<Matrix>(A)); }
+
+    IstlSolverResult solveSequential_(Matrix& A, XVector& x, BVector& b)
+    {
+        // construct solver from linear operator
+        auto linearOperatorHolder = makeSequentialLinearOperator_(A);
+        auto solver = constructPreconditionedSolver_(linearOperatorHolder);
+
+        return solveSequential_(x, b, *solver);
+    }
+
+    IstlSolverResult solveSequential_(XVector& x, BVector& b, InverseOperator& solver) const
+    {
+        Dune::InverseOperatorResult result;
+        if constexpr (convertMultiTypeVectorAndMatrix)
+        {
             // create the vector the IterativeSolver backend can handle
-            auto bTmp = VectorConverter<BVector>::multiTypeToBlockVector(b);
-            assert(bTmp.size() == numRows);
+            BVectorForSolver bTmp = VectorConverter<BVector>::multiTypeToBlockVector(b);
 
             // create a block vector to which the linear solver writes the solution
-            using VectorBlock = typename Dune::FieldVector<Scalar, 1>;
-            using BlockVector = typename Dune::BlockVector<VectorBlock>;
-            BlockVector y(numRows);
-
-            auto linearOperator = std::make_shared<Dune::MatrixAdapter<decltype(M), decltype(y), decltype(bTmp)>>(M);
-            auto solver = constructPreconditionedSolver_(linearOperator);
+            XVectorForSolver y(bTmp.size());
 
             // solve linear system
-            solver.apply(y, bTmp, result_);
+            solver.apply(y, bTmp, result);
 
             // copy back the result y into x
-            if(result_.converged)
+            if (result.converged)
                 VectorConverter<XVector>::retrieveValues(x, y);
         }
         else
         {
-            // construct solver from linear operator
-            using SequentialTraits = typename LinearSolverTraits::template Sequential<Matrix, XVector>;
-            auto linearOperator = std::make_shared<typename SequentialTraits::LinearOperator>(A);
-            auto solver = constructPreconditionedSolver_(linearOperator);
-
             // solve linear system
-            solver.apply(x, b, result_);
+            solver.apply(x, b, result);
         }
 
-        return result_.converged;
+        return result;
+    }
+
+    IstlSolverResult solveSequentialOrParallel_(Matrix& A, XVector& x, BVector& b)
+    {
+        return executeSequentialOrParallel_(
+            [&]{ return solveSequential_(A, x, b); },
+            [&](auto traits){ return solveParallel_(A, x, b, traits); }
+        );
+    }
+
+    IstlSolverResult solveSequentialOrParallel_(XVector& x, BVector& b, InverseOperator& solver) const
+    {
+        return executeSequentialOrParallel_(
+            [&]{ return solveSequential_(x, b, solver); },
+            [&](auto traits){ return solveParallel_(x, b, solver, traits); }
+        );
+    }
+
+    template<class ParallelTraits>
+    IstlSolverResult solveParallel_(Matrix& A, XVector& x, BVector& b, ParallelTraits = {})
+    {
+        // construct solver from linear operator
+        auto linearOperatorHolder = makeParallelLinearOperator_<ParallelTraits>(A);
+        auto solver = constructPreconditionedSolver_(linearOperatorHolder);
+        return solveParallel_<ParallelTraits>(x, b, *solver);
+    }
+
+    template<class ParallelTraits>
+    IstlSolverResult solveParallel_(XVector& x, BVector& b, InverseOperator& solver, ParallelTraits = {}) const
+    {
+#if HAVE_MPI
+        // make right hand side consistent
+        prepareVectorParallel<LinearSolverTraits, ParallelTraits>(b, *parallelHelper_);
+
+        // solve linear system
+        Dune::InverseOperatorResult result;
+        solver.apply(x, b, result);
+        return result;
+#else
+        DUNE_THROW(Dune::InvalidStateException, "Calling makeParallelLinearOperator for sequential run");
+#endif
     }
 
+
+    std::shared_ptr<InverseOperator> constructPreconditionedSolver_(MatrixOperatorHolder& ops)
+    {
+        return std::visit([&](auto&& op)
+        {
+            using LinearOperator = typename std::decay_t<decltype(op)>::element_type;
+            const auto& params = params_.sub("preconditioner");
+            using Prec = Dune::Preconditioner<typename LinearOperator::domain_type, typename LinearOperator::range_type>;
+            using TL = Dune::TypeList<typename LinearOperator::matrix_type, typename LinearOperator::domain_type, typename LinearOperator::range_type>;
+            std::shared_ptr<Prec> prec = PreconditionerFactory{}(TL{}, op, params);
+
 #if HAVE_MPI
-    bool solveSequentialOrParallel_(Matrix& A, XVector& x, BVector& b)
+            if (prec->category() != op->category() && prec->category() == Dune::SolverCategory::sequential)
+                prec = Dune::wrapPreconditioner4Parallel(prec, op);
+#endif
+            return std::make_shared<InverseOperator>(op, scalarProduct_, prec, params_);
+        }, ops);
+    }
+
+    template<class Seq, class Par>
+    decltype(auto) executeSequentialOrParallel_(Seq&& sequentialAction, Par&& parallelAction) const
     {
+#if HAVE_MPI
         // For Dune::MultiTypeBlockMatrix there is currently no generic way
         // of handling parallelism, we therefore can only solve these types of systems sequentially
         if constexpr (isMultiTypeBlockMatrix<Matrix>::value || !LinearSolverTraits::canCommunicate)
-            return solveSequential_(A, x, b);
+            return sequentialAction();
         else
         {
             switch (solverCategory_)
             {
                 case Dune::SolverCategory::sequential:
-                    return solveSequential_(A, x, b);
+                    return sequentialAction();
                 case Dune::SolverCategory::nonoverlapping:
                     using NOTraits = typename LinearSolverTraits::template ParallelNonoverlapping<Matrix, XVector>;
-                    return solveParallel_<NOTraits>(A, x, b);
+                    return parallelAction(NOTraits{});
                 case Dune::SolverCategory::overlapping:
                     using OTraits = typename LinearSolverTraits::template ParallelOverlapping<Matrix, XVector>;
-                    return solveParallel_<OTraits>(A, x, b);
+                    return parallelAction(OTraits{});
                 default: DUNE_THROW(Dune::InvalidStateException, "Unknown solver category");
             }
         }
-    }
-
-    template<class ParallelTraits>
-    bool solveParallel_(Matrix& A, XVector& x, BVector& b)
-    {
-        // make linear algebra consistent
-        prepareLinearAlgebraParallel<LinearSolverTraits, ParallelTraits>(A, b, *parallelHelper_);
-
-        // construct solver from linear operator
-        auto linearOperator = std::make_shared<typename ParallelTraits::LinearOperator>(A, *communication_);
-        auto solver = constructPreconditionedSolver_(linearOperator);
-
-        // solve linear system
-        solver.apply(x, b, result_);
-        return result_.converged;
-    }
-#endif // HAVE_MPI
-
-    template<class LinearOperator>
-    InverseOperator constructPreconditionedSolver_(std::shared_ptr<LinearOperator>& op)
-    {
-        const auto& params = params_.sub("preconditioner");
-        using Prec = Dune::Preconditioner<typename LinearOperator::domain_type, typename LinearOperator::range_type>;
-        using TL = Dune::TypeList<typename LinearOperator::matrix_type, typename LinearOperator::domain_type, typename LinearOperator::range_type>;
-        std::shared_ptr<Prec> prec = PreconditionerFactory{}(TL{}, op, params);
-
-#if HAVE_MPI
-        if (prec->category() != op->category() && prec->category() == Dune::SolverCategory::sequential)
-            prec = Dune::wrapPreconditioner4Parallel(prec, op);
+#else
+        return sequentialAction();
 #endif
-        return {op, scalarProduct_, prec, params_};
     }
 
 #if HAVE_MPI
-    std::unique_ptr<ParallelHelper> parallelHelper_;
+    std::shared_ptr<const ParallelHelper> parallelHelper_;
     std::shared_ptr<Comm> communication_;
 #endif
+
     Dune::SolverCategory::Category solverCategory_;
     std::shared_ptr<ScalarProduct> scalarProduct_;
 
-    Dune::InverseOperatorResult result_;
+    // for stored solvers (reuse matrix)
+    MatrixOperatorHolder linearOperator_;
+    // for stored solvers (reuse matrix)
+    std::shared_ptr<InverseOperator> solver_;
+
     Dune::ParameterTree params_;
     std::string name_;
 };
@@ -507,90 +715,129 @@ namespace Dumux::Detail {
 template<class LSTraits, class LATraits, template<class M> class Solver>
 class DirectIstlSolver : public LinearSolver
 {
+    using Matrix = typename LATraits::Matrix;
+    using XVector = typename LATraits::Vector;
+    using BVector = typename LATraits::Vector;
+
+    static constexpr bool convertMultiTypeVectorAndMatrix = isMultiTypeBlockVector<XVector>::value;
+    using MatrixForSolver = typename Detail::IstlSolvers::MatrixForSolver<Matrix, convertMultiTypeVectorAndMatrix>::type;
+    using BVectorForSolver = typename Detail::IstlSolvers::VectorForSolver<BVector, convertMultiTypeVectorAndMatrix>::type;
+    using XVectorForSolver = typename Detail::IstlSolvers::VectorForSolver<XVector, convertMultiTypeVectorAndMatrix>::type;
+    using InverseOperator = Dune::InverseOperator<XVectorForSolver, BVectorForSolver>;
 public:
     using LinearSolver::LinearSolver;
 
-    template<class Matrix, class Vector>
-    bool solve(const Matrix& A, Vector& x, const Vector& b)
+    /*!
+     * \brief Solve the linear system Ax = b
+     */
+    IstlSolverResult solve(const Matrix& A, XVector& x, const BVector& b)
     {
-        // support dune-istl multi-type block vector/matrix
-        if constexpr (isMultiTypeBlockVector<Vector>())
-        {
-            auto [AA, xx, bb] = convertIstlMultiTypeToBCRSSystem_(A, x, b);
-            bool converged = solve_(AA, xx, bb);
-            if (converged)
-                VectorConverter<Vector>::retrieveValues(x, xx);
-            return converged;
-        }
+        return solve_(A, x, b);
+    }
+
+    /*!
+     * \brief Solve the linear system Ax = b using the matrix set with \ref setMatrix
+     */
+    IstlSolverResult solve(XVector& x, const BVector& b)
+    {
+        if (!solver_)
+            DUNE_THROW(Dune::InvalidStateException, "Called solve(x, b) but no linear operator has been set");
 
+        return solve_(x, b, *solver_);
+    }
+
+    /*!
+     * \brief Set the matrix A of the linear system Ax = b for reuse
+     */
+    void setMatrix(std::shared_ptr<Matrix> A)
+    {
+        if constexpr (convertMultiTypeVectorAndMatrix)
+            matrix_ = std::make_shared<MatrixForSolver>(MatrixConverter<Matrix>::multiTypeToBCRSMatrix(A));
         else
-            return solve_(A, x, b);
+            matrix_ = A;
+
+        solver_ = std::make_shared<Solver<MatrixForSolver>>(*matrix_);
     }
 
+    /*!
+     * \brief Set the matrix A of the linear system Ax = b for reuse
+     * \note The client has to take care of the lifetime management of A
+     */
+    void setMatrix(Matrix& A)
+    { setMatrix(Dune::stackobject_to_shared_ptr(A)); }
+
+    /*!
+     * \brief name of the linear solver
+     */
     std::string name() const
     {
         return "Direct solver";
     }
 
-    const Dune::InverseOperatorResult& result() const
+private:
+    IstlSolverResult solve_(const Matrix& A, XVector& x, const BVector& b)
     {
-        return result_;
+        // support dune-istl multi-type block vector/matrix by copying
+        if constexpr (isMultiTypeBlockVector<BVector>())
+        {
+            const auto AA = MatrixConverter<Matrix>::multiTypeToBCRSMatrix(A);
+            Solver<MatrixForSolver> solver(AA, this->verbosity() > 0);
+            return solve_(x, b, solver);
+        }
+        else
+        {
+            Solver<MatrixForSolver> solver(A, this->verbosity() > 0);
+            return solve_(x, b, solver);
+        }
     }
 
-private:
-    Dune::InverseOperatorResult result_;
-
-    template<class Matrix, class Vector>
-    bool solve_(const Matrix& A, Vector& x, const Vector& b)
+    IstlSolverResult solve_(XVector& x, const BVector& b, InverseOperator& solver) const
     {
-        static_assert(isBCRSMatrix<Matrix>::value, "Direct solver only works with BCRS matrices!");
-        using BlockType = typename Matrix::block_type;
-        static_assert(BlockType::rows == BlockType::cols, "Matrix block must be quadratic!");
-        constexpr auto blockSize = BlockType::rows;
-
-        Solver<Matrix> solver(A, this->verbosity() > 0);
+        static_assert(isBCRSMatrix<MatrixForSolver>::value, "Direct solver only works with BCRS matrices!");
+        static_assert(MatrixForSolver::block_type::rows == MatrixForSolver::block_type::cols, "Matrix block must be quadratic!");
 
-        Vector bTmp(b);
-        solver.apply(x, bTmp, result_);
+        Dune::InverseOperatorResult result;
+        if constexpr (isMultiTypeBlockVector<BVector>())
+        {
+            auto bb = VectorConverter<BVector>::multiTypeToBlockVector(b);
+            XVectorForSolver xx(bb.size());
+            solver.apply(xx, bb, result);
+            checkResult_(xx, result);
+            if (result.converged)
+                VectorConverter<XVector>::retrieveValues(x, xx);
+            return result;
+        }
+        else
+        {
+            BVectorForSolver bTmp(b);
+            solver.apply(x, bTmp, result);
+            checkResult_(x, result);
+            return result;
+        }
+    }
 
+    void checkResult_(const XVectorForSolver& x, Dune::InverseOperatorResult& result) const
+    {
         int size = x.size();
         for (int i = 0; i < size; i++)
         {
-            for (int j = 0; j < blockSize; j++)
+            for (int j = 0; j < x[i].size(); j++)
             {
                 using std::isnan;
                 using std::isinf;
                 if (isnan(x[i][j]) || isinf(x[i][j]))
                 {
-                    result_.converged = false;
+                    result.converged = false;
                     break;
                 }
             }
         }
-
-        return result_.converged;
     }
 
-    template<class Matrix, class Vector>
-    auto convertIstlMultiTypeToBCRSSystem_(const Matrix& A, Vector& x, const Vector& b)
-    {
-        const auto AA = MatrixConverter<Matrix>::multiTypeToBCRSMatrix(A);
-
-        // get the new matrix sizes
-        const std::size_t numRows = AA.N();
-        assert(numRows == AA.M());
-
-        // create the vector the IterativeSolver backend can handle
-        const auto bb = VectorConverter<Vector>::multiTypeToBlockVector(b);
-        assert(bb.size() == numRows);
-
-        // create a blockvector to which the linear solver writes the solution
-        using VectorBlock = typename Dune::FieldVector<Scalar, 1>;
-        using BlockVector = typename Dune::BlockVector<VectorBlock>;
-        BlockVector xx(numRows);
-
-        return std::make_tuple(std::move(AA), std::move(xx), std::move(bb));
-    }
+    //! matrix when using the setMatrix interface for matrix reuse
+    std::shared_ptr<MatrixForSolver> matrix_;
+    //! solver when using the setMatrix interface for matrix reuse
+    std::shared_ptr<InverseOperator> solver_;
 };
 
 } // end namespace Dumux::Detail
diff --git a/test/linear/test_linearsolver.cc b/test/linear/test_linearsolver.cc
index 4128648940..c268d399d7 100644
--- a/test/linear/test_linearsolver.cc
+++ b/test/linear/test_linearsolver.cc
@@ -83,9 +83,12 @@ int main(int argc, char* argv[])
         LinearSolver solver(testSolverName);
 
         std::cout << "Solving Laplace problem with " << solver.name() << "\n";
-        solver.solve(A, x, b);
-        if (!solver.result().converged)
+        auto result = solver.solve(A, x, b);
+        if (!result.converged)
             DUNE_THROW(Dune::Exception, testSolverName << " did not converge!");
+
+        if (!result)
+            DUNE_THROW(Dune::Exception, "Solver result cannot be implicitly converted to bool");
     }
 
     // IstlSolverFactoryBackend
-- 
GitLab