From 12d0c7d376638c9949af678fb8dc7aac1ee157c5 Mon Sep 17 00:00:00 2001
From: Timo Koch <timo.koch@iws.uni-stuttgart.de>
Date: Thu, 27 Feb 2020 09:16:14 +0100
Subject: [PATCH] [istlfactory] Adjust to new structure after rebase

---
 dumux/linear/istlsolverfactorybackend.hh      | 130 ++++++++++++------
 test/freeflow/shallowwater/dambreak/main.cc   |   4 +-
 .../richards/implicit/lens/main.cc            |   5 +-
 3 files changed, 93 insertions(+), 46 deletions(-)

diff --git a/dumux/linear/istlsolverfactorybackend.hh b/dumux/linear/istlsolverfactorybackend.hh
index ca6ed3de9f..db03353bed 100644
--- a/dumux/linear/istlsolverfactorybackend.hh
+++ b/dumux/linear/istlsolverfactorybackend.hh
@@ -39,7 +39,6 @@
 #include <dune/istl/solverfactory.hh>
 
 #include <dumux/linear/solver.hh>
-#include <dumux/linear/linearsolvertraits.hh>
 #include <dumux/linear/parallelhelpers.hh>
 
 #if DUNE_VERSION_NEWER_REV(DUNE_ISTL,2,7,1)
@@ -53,22 +52,12 @@ namespace Dumux {
  * \note the solvers are configured via the input file
  * \note requires Dune version 2.7.1 or newer
  */
-template <class Matrix, class Vector, class GridGeometry>
+template <class LinearSolverTraits>
 class IstlSolverFactoryBackend : public LinearSolver
 {
-    using GridView = typename GridGeometry::GridView;
-    using LinearSolverTraits =  Dumux::LinearSolverTraits<Matrix, Vector, GridGeometry>;
-    using Grid = typename GridView::Grid;
-    using LinearOperator = typename LinearSolverTraits::LinearOperator;
-    using ScalarProduct = typename LinearSolverTraits::ScalarProduct;
-    using VType = typename LinearSolverTraits::VType;
-    using Comm = typename LinearSolverTraits::Comm;
-    using BCRSMat = typename LinearSolverTraits::LinearOperator::matrix_type;
-    using DofMapper = typename LinearSolverTraits::DofMapper;
-
 public:
     //! translation table for solver parameters
-    static std::vector<std::array<std::string,2> > dumuxToIstlSolverParams;
+    static std::vector<std::array<std::string, 2>> dumuxToIstlSolverParams;
 
     /*!
      * \brief Construct the backend for the sequential case only
@@ -77,9 +66,9 @@ public:
      */
     IstlSolverFactoryBackend(const std::string& paramGroup = "")
     : paramGroup_(paramGroup)
-    , firstCall_(true)
+    , isParallel_(Dune::MPIHelper::getCollectiveCommunication().size() > 1)
     {
-        if (Dune::MPIHelper::getCollectiveCommunication().size() > 1)
+        if (isParallel_)
             DUNE_THROW(Dune::InvalidStateException, "Using sequential constructor for parallel run. Use signature with gridView and dofMapper!");
 
         reset();
@@ -92,12 +81,12 @@ public:
      * \param dofMapper an index mapper for dof entities
      * \param paramGroup the parameter group for parameter lookup
      */
-    IstlSolverFactoryBackend(const GridView& gridView,
-                             const DofMapper& dofMapper,
+    IstlSolverFactoryBackend(const typename LinearSolverTraits::GridView& gridView,
+                             const typename LinearSolverTraits::DofMapper& dofMapper,
                              const std::string& paramGroup = "")
     : paramGroup_(paramGroup)
-    , parallelHelper_(std::make_unique<ParallelISTLHelper<GridView, LinearSolverTraits>>(gridView, dofMapper))
-    , firstCall_(true)
+    , parallelHelper_(std::make_unique<ParallelISTLHelper<LinearSolverTraits>>(gridView, dofMapper))
+    , isParallel_(Dune::MPIHelper::getCollectiveCommunication().size() > 1)
     {
         reset();
     }
@@ -109,27 +98,14 @@ public:
      * \param x the seeked solution vector, containing the initial solution upon entry
      * \param b the right hand side vector
      */
+    template<class Matrix, class Vector>
     bool solve(Matrix& A, Vector& x, Vector& b)
     {
-        std::shared_ptr<Comm> comm;
-        std::shared_ptr<LinearOperator> linearOperator;
-        std::shared_ptr<ScalarProduct> scalarProduct; // not used.
-
 #if HAVE_MPI
-        if constexpr (LinearSolverTraits::isParallel)
-            prepareLinearAlgebraParallel<LinearSolverTraits>(A, b, comm, linearOperator, scalarProduct, *parallelHelper_, firstCall_);
-        else
-            prepareLinearAlgebraSequential<LinearSolverTraits>(A, comm, linearOperator, scalarProduct);
+        solveSequentialOrParallel_(A, x, b);
 #else
-        prepareLinearAlgebraSequential<LinearSolverTraits>(A, comm, linearOperator, scalarProduct);
+        solveSequential_(A, x, b);
 #endif
-
-        // construct solver
-        auto solver = getSolverFromFactory_(linearOperator);
-
-        // solve linear system
-        solver->apply(x, b, result_);
-
         firstCall_ = false;
         return result_.converged;
     }
@@ -137,10 +113,10 @@ public:
     //! reset the linear solver factory
     void reset()
     {
+        firstCall_ = true;
         resetDefaultParameters();
         convertParameterTree_(paramGroup_);
         checkMandatoryParameters_();
-        Dune::initSolverFactories<typename LinearSolverTraits::LinearOperator>();
         name_ = params_.get<std::string>("preconditioner.type") + "-preconditioned " + params_.get<std::string>("type");
         if (params_.get<int>("verbose", 0) > 0)
             std::cout << "Initialized linear solver of type: " << name_ << std::endl;
@@ -168,7 +144,6 @@ public:
     }
 
 private:
-
     void convertParameterTree_(const std::string& paramGroup="")
     {
         auto linearSolverGroups = getParamSubGroups("LinearSolver", paramGroup);
@@ -199,6 +174,79 @@ private:
             DUNE_THROW(Dune::InvalidStateException, "Solver factory needs \"LinearSolver.Preconditioner.Type\" parameter to select the preconditioner");
     }
 
+#if HAVE_MPI
+    template<class Matrix, class Vector>
+    void solveSequentialOrParallel_(Matrix& A, Vector& x, Vector& b)
+    {
+        if constexpr (LinearSolverTraits::canCommunicate)
+        {
+            if (isParallel_)
+            {
+                if (LinearSolverTraits::isNonOverlapping(parallelHelper_->gridView()))
+                {
+                    using PTraits = typename LinearSolverTraits::template ParallelNonoverlapping<Matrix, Vector>;
+                    solveParallel_<PTraits>(A, x, b);
+                }
+                else
+                {
+                    using PTraits = typename LinearSolverTraits::template ParallelOverlapping<Matrix, Vector>;
+                    solveParallel_<PTraits>(A, x, b);
+                }
+            }
+            else
+                solveSequential_(A, x, b);
+        }
+        else
+        {
+            solveSequential_(A, x, b);
+        }
+    }
+
+    template<class ParallelTraits, class Matrix, class Vector>
+    void solveParallel_(Matrix& A, Vector& x, Vector& b)
+    {
+        using Comm = typename ParallelTraits::Comm;
+        using LinearOperator = typename ParallelTraits::LinearOperator;
+        using ScalarProduct = typename ParallelTraits::ScalarProduct;
+
+        if (firstCall_)
+        {
+            Dune::initSolverFactories<LinearOperator>();
+            parallelHelper_->initGhostsAndOwners();
+        }
+
+        std::shared_ptr<Comm> comm;
+        std::shared_ptr<LinearOperator> linearOperator;
+        std::shared_ptr<ScalarProduct> scalarProduct;
+        prepareLinearAlgebraParallel<LinearSolverTraits, ParallelTraits>(A, b, comm, linearOperator, scalarProduct, *parallelHelper_);
+
+        // construct solver
+        auto solver = getSolverFromFactory_(linearOperator);
+
+        // solve linear system
+        solver->apply(x, b, result_);
+    }
+#endif // HAVE_MPI
+
+    template<class Matrix, class Vector>
+    void solveSequential_(Matrix& A, Vector& x, Vector& b)
+    {
+        // construct linear operator
+        using Traits = typename LinearSolverTraits::template Sequential<Matrix, Vector>;
+        using LinearOperator = typename Traits::LinearOperator;
+        auto linearOperator = std::make_shared<LinearOperator>(A);
+
+        if (firstCall_)
+            Dune::initSolverFactories<LinearOperator>();
+
+        // construct solver
+        auto solver = getSolverFromFactory_(linearOperator);
+
+        // solve linear system
+        solver->apply(x, b, result_);
+    }
+
+    template<class LinearOperator>
     auto getSolverFromFactory_(std::shared_ptr<LinearOperator>& fop)
     {
         try { return Dune::getSolverFromFactory(fop, params_); }
@@ -211,17 +259,19 @@ private:
     }
 
     const std::string paramGroup_;
-    std::unique_ptr<ParallelISTLHelper<GridView, LinearSolverTraits>> parallelHelper_;
+    std::unique_ptr<ParallelISTLHelper<LinearSolverTraits>> parallelHelper_;
+    bool isParallel_;
     bool firstCall_;
+
     Dune::InverseOperatorResult result_;
     Dune::ParameterTree params_;
     std::string name_;
 };
 
 //! translation table for solver parameters
-template<class Matrix, class Vector, class Geometry>
+template<class LinearSolverTraits>
 std::vector<std::array<std::string, 2>>
-IstlSolverFactoryBackend<Matrix, Vector, Geometry>::dumuxToIstlSolverParams =
+IstlSolverFactoryBackend<LinearSolverTraits>::dumuxToIstlSolverParams =
 {
     // solver params
     {"Verbosity", "verbose"},
diff --git a/test/freeflow/shallowwater/dambreak/main.cc b/test/freeflow/shallowwater/dambreak/main.cc
index 7b3d94b4f5..18080ada72 100644
--- a/test/freeflow/shallowwater/dambreak/main.cc
+++ b/test/freeflow/shallowwater/dambreak/main.cc
@@ -132,9 +132,7 @@ int main(int argc, char** argv) try
 
     // the linear solver
 #if DUNE_VERSION_NEWER_REV(DUNE_ISTL,2,7,1)
-    using Matrix = GetPropType<TypeTag, Properties::JacobianMatrix>;
-    using Vector = Dune::BlockVector<Dune::FieldVector<typename SolutionVector::block_type::value_type, SolutionVector::block_type::dimension>>;
-    using LinearSolver = IstlSolverFactoryBackend<Matrix, Vector, GridGeometry>;
+    using LinearSolver = IstlSolverFactoryBackend<LinearSolverTraits<GridGeometry>>;
 #else
     using LinearSolver = AMGBiCGSTABBackend<LinearSolverTraits<GridGeometry>>;
 #endif
diff --git a/test/porousmediumflow/richards/implicit/lens/main.cc b/test/porousmediumflow/richards/implicit/lens/main.cc
index bb5251ec2a..40e3cbbb1d 100644
--- a/test/porousmediumflow/richards/implicit/lens/main.cc
+++ b/test/porousmediumflow/richards/implicit/lens/main.cc
@@ -42,6 +42,7 @@
 
 #if DUNE_VERSION_NEWER_REV(DUNE_ISTL,2,7,1)
 #include <dumux/linear/istlsolverfactorybackend.hh>
+#include <dumux/linear/linearsolvertraits.hh>
 #else
 #include <dumux/linear/amgbackend.hh>
 #endif
@@ -150,9 +151,7 @@ int main(int argc, char** argv) try
 
     // the linear solver
 #if DUNE_VERSION_NEWER_REV(DUNE_ISTL,2,7,1)
-    using Matrix = GetPropType<TypeTag, Properties::JacobianMatrix>;
-    using Vector = Dune::BlockVector<Dune::FieldVector<typename SolutionVector::block_type::value_type, SolutionVector::block_type::dimension>>;
-    using LinearSolver = IstlSolverFactoryBackend<Matrix, Vector, GridGeometry>;
+    using LinearSolver = IstlSolverFactoryBackend<LinearSolverTraits<GridGeometry>>;
 #else
     using LinearSolver = AMGBiCGSTABBackend<LinearSolverTraits<GridGeometry>>;
 #endif
-- 
GitLab