From bbe76d906f26e3e39048036c8f4cde3e6b0f44d6 Mon Sep 17 00:00:00 2001
From: Timo Koch <timo.koch@iws.uni-stuttgart.de>
Date: Mon, 11 Jul 2022 17:55:20 +0200
Subject: [PATCH] [multidomain] Implement hook to enable multithreaded assembly
 by a compatible coupling manager

---
 dumux/multidomain/fvassembler.hh | 116 ++++++++++++++++++++++---------
 1 file changed, 84 insertions(+), 32 deletions(-)

diff --git a/dumux/multidomain/fvassembler.hh b/dumux/multidomain/fvassembler.hh
index f1e6bd9718..18bac71019 100644
--- a/dumux/multidomain/fvassembler.hh
+++ b/dumux/multidomain/fvassembler.hh
@@ -39,6 +39,7 @@
 #include <dumux/assembly/diffmethod.hh>
 #include <dumux/assembly/jacobianpattern.hh>
 #include <dumux/linear/parallelhelpers.hh>
+#include <dumux/parallel/multithreading.hh>
 
 #include "couplingjacobianpattern.hh"
 #include "subdomaincclocalassembler.hh"
@@ -50,6 +51,14 @@
 
 namespace Dumux {
 
+/*!
+ * \ingroup MultiDomain
+ * \brief trait that is specialized for coupling manager supporting multithreaded assembly
+ */
+template<class CM>
+struct CouplingManagerSupportsMultithreadedAssembly : public std::false_type
+{};
+
 /*!
  * \ingroup MultiDomain
  * \ingroup Assembly
@@ -162,6 +171,12 @@ public:
     {
         static_assert(isImplicit(), "Explicit assembler for stationary problem doesn't make sense!");
         std::cout << "Instantiated assembler for a stationary problem." << std::endl;
+
+        enableMultithreading_ = CouplingManagerSupportsMultithreadedAssembly<CouplingManager>::value
+            && !Multithreading::isSerial()
+            && getParam<bool>("Assembly.Multithreading", true);
+
+        maybeComputeColors_();
     }
 
     /*!
@@ -185,6 +200,12 @@ public:
     , warningIssued_(false)
     {
         std::cout << "Instantiated assembler for an instationary problem." << std::endl;
+
+        enableMultithreading_ = CouplingManagerSupportsMultithreadedAssembly<CouplingManager>::value
+            && !Multithreading::isSerial()
+            && getParam<bool>("Assembly.Multithreading", true);
+
+        maybeComputeColors_();
     }
 
     /*!
@@ -238,7 +259,7 @@ public:
     Scalar residualNorm(const SolutionVector& curSol)
     {
         ResidualType residual;
-        setResidualSize(residual);
+        setResidualSize_(residual);
         assembleResidual(residual, curSol);
 
         // calculate the squared norm of the residual
@@ -300,8 +321,8 @@ public:
         residual_ = r;
 
         setJacobianBuildMode(*jacobian_);
-        setJacobianPattern(*jacobian_);
-        setResidualSize(*residual_);
+        setJacobianPattern_(*jacobian_);
+        setResidualSize_(*residual_);
     }
 
     /*!
@@ -314,8 +335,8 @@ public:
         residual_ = std::make_shared<SolutionVector>();
 
         setJacobianBuildMode(*jacobian_);
-        setJacobianPattern(*jacobian_);
-        setResidualSize(*residual_);
+        setJacobianPattern_(*jacobian_);
+        setResidualSize_(*residual_);
     }
 
     /*!
@@ -337,30 +358,14 @@ public:
         });
     }
 
-    /*!
-     * \brief Sets the jacobian sparsity pattern.
+     /*!
+     * \brief Resizes jacobian and residual and recomputes colors
      */
-    void setJacobianPattern(JacobianMatrix& jac) const
+    void updateAfterGridAdaption()
     {
-        using namespace Dune::Hybrid;
-        forEach(std::make_index_sequence<JacobianMatrix::N()>(), [&](const auto domainI)
-        {
-            forEach(integralRange(Dune::Hybrid::size(jac[domainI])), [&](const auto domainJ)
-            {
-                const auto pattern = this->getJacobianPattern_(domainI, domainJ);
-                pattern.exportIdx(jac[domainI][domainJ]);
-            });
-        });
-    }
-
-    /*!
-     * \brief Resizes the residual
-     */
-    void setResidualSize(SolutionVector& res) const
-    {
-        using namespace Dune::Hybrid;
-        forEach(integralRange(Dune::Hybrid::size(res)), [&](const auto domainId)
-        { res[domainId].resize(this->numDofs(domainId)); });
+        setResidualSize_();
+        setJacobianPattern_();
+        maybeComputeColors_();
     }
 
     /*!
@@ -461,13 +466,39 @@ protected:
     std::shared_ptr<CouplingManager> couplingManager_;
 
 private:
+    /*!
+     * \brief Sets the jacobian sparsity pattern.
+     */
+    void setJacobianPattern_(JacobianMatrix& jac) const
+    {
+        using namespace Dune::Hybrid;
+        forEach(std::make_index_sequence<JacobianMatrix::N()>(), [&](const auto domainI)
+        {
+            forEach(integralRange(Dune::Hybrid::size(jac[domainI])), [&](const auto domainJ)
+            {
+                const auto pattern = this->getJacobianPattern_(domainI, domainJ);
+                pattern.exportIdx(jac[domainI][domainJ]);
+            });
+        });
+    }
+
+    /*!
+     * \brief Resizes the residual
+     */
+    void setResidualSize_(SolutionVector& res) const
+    {
+        using namespace Dune::Hybrid;
+        forEach(integralRange(Dune::Hybrid::size(res)), [&](const auto domainId)
+        { res[domainId].resize(this->numDofs(domainId)); });
+    }
+
     // reset the residual vector to 0.0
     void resetResidual_()
     {
         if(!residual_)
         {
             residual_ = std::make_shared<SolutionVector>();
-            setResidualSize(*residual_);
+            setResidualSize_(*residual_);
         }
 
         (*residual_) = 0.0;
@@ -480,12 +511,20 @@ private:
         {
             jacobian_ = std::make_shared<JacobianMatrix>();
             setJacobianBuildMode(*jacobian_);
-            setJacobianPattern(*jacobian_);
+            setJacobianPattern_(*jacobian_);
         }
 
        (*jacobian_)  = 0.0;
     }
 
+    //! Computes the colors
+    void maybeComputeColors_()
+    {
+        if constexpr (CouplingManagerSupportsMultithreadedAssembly<CouplingManager>::value)
+            if (enableMultithreading_)
+                couplingManager_->computeColorsForAssembly();
+    }
+
     // check if the assembler is in a correct state for assembly
     void checkAssemblerState_() const
     {
@@ -528,9 +567,19 @@ private:
     template<std::size_t i, class AssembleElementFunc>
     void assemble_(Dune::index_constant<i> domainId, AssembleElementFunc&& assembleElement) const
     {
-        // let the local assembler add the element contributions
-        for (const auto& element : elements(gridView(domainId)))
-            assembleElement(element);
+        if constexpr (CouplingManagerSupportsMultithreadedAssembly<CouplingManager>::value)
+        {
+            if (enableMultithreading_)
+                couplingManager_->assembleMultithreaded(domainId, assembleElement);
+        }
+
+        // fallback for coupling managers that don't support multithreaded assembly (yet)
+        else
+        {
+            // let the local assembler add the element contributions
+            for (const auto& element : elements(gridView(domainId)))
+                assembleElement(element);
+        }
     }
 
     // get diagonal block pattern
@@ -620,6 +669,9 @@ private:
 
     //! Issue a warning if the calculation is used in parallel with overlap. This could be a static local variable if it wasn't for g++7 yielding a linker error.
     bool warningIssued_;
+
+    //! if multithreaded assembly is enabled
+    bool enableMultithreading_ = false;
 };
 
 } // end namespace Dumux
-- 
GitLab