From 6882a882a4651e39b3881f3cd8719f000cfeb836 Mon Sep 17 00:00:00 2001
From: Timo Koch <timo.koch@iws.uni-stuttgart.de>
Date: Mon, 11 Jul 2022 17:56:02 +0200
Subject: [PATCH] [ff][md] Enable multihtreaded assembly for freeflow coupling
 manager

---
 .../staggeredfreeflow/couplingmanager.hh      | 130 +++++++++++++-----
 1 file changed, 97 insertions(+), 33 deletions(-)

diff --git a/dumux/multidomain/staggeredfreeflow/couplingmanager.hh b/dumux/multidomain/staggeredfreeflow/couplingmanager.hh
index c0cb5e1ce7..131c4e6b71 100644
--- a/dumux/multidomain/staggeredfreeflow/couplingmanager.hh
+++ b/dumux/multidomain/staggeredfreeflow/couplingmanager.hh
@@ -29,6 +29,7 @@
 #include <memory>
 #include <tuple>
 #include <vector>
+#include <deque>
 
 #include <dune/common/exceptions.hh>
 #include <dune/common/indices.hh>
@@ -39,6 +40,9 @@
 #include <dumux/multidomain/couplingmanager.hh>
 #include <dumux/discretization/facecentered/staggered/consistentlyorientedgrid.hh>
 
+#include <dumux/parallel/parallel_for.hh>
+#include <dumux/assembly/coloring.hh>
+
 namespace Dumux {
 
 /*!
@@ -64,6 +68,7 @@ private:
     template<std::size_t id> using GridGeometry = GetPropType<SubDomainTypeTag<id>, Properties::GridGeometry>;
     template<std::size_t id> using GridView = typename GridGeometry<id>::GridView;
     template<std::size_t id> using Element = typename GridView<id>::template Codim<0>::Entity;
+    template<std::size_t id> using ElementSeed = typename GridView<id>::Grid::template Codim<0>::EntitySeed;
     template<std::size_t id> using FVElementGeometry = typename GridGeometry<id>::LocalView;
     template<std::size_t id> using SubControlVolume = typename FVElementGeometry<id>::SubControlVolume;
     template<std::size_t id> using SubControlVolumeFace = typename FVElementGeometry<id>::SubControlVolumeFace;
@@ -242,7 +247,7 @@ public:
         assert(!(considerPreviousTimeStep && !isTransient_));
         bindCouplingContext_(Dune::index_constant<freeFlowMomentumIndex>(), element, fvGeometry.elementIndex());
         const auto& insideMomentumScv = fvGeometry.scv(scvf.insideScvIdx());
-        const auto& insideMassScv = momentumCouplingContext_[0].fvGeometry.scv(insideMomentumScv.elementIndex());
+        const auto& insideMassScv = momentumCouplingContext_()[0].fvGeometry.scv(insideMomentumScv.elementIndex());
 
         const auto rho = [&](const auto& elemVolVars)
         {
@@ -251,14 +256,14 @@ public:
             else
             {
                 const auto& outsideMomentumScv = fvGeometry.scv(scvf.outsideScvIdx());
-                const auto& outsideMassScv = momentumCouplingContext_[0].fvGeometry.scv(outsideMomentumScv.elementIndex());
+                const auto& outsideMassScv = momentumCouplingContext_()[0].fvGeometry.scv(outsideMomentumScv.elementIndex());
                 // TODO distance weighting
                 return 0.5*(elemVolVars[insideMassScv].density() + elemVolVars[outsideMassScv].density());
             }
         };
 
-        return considerPreviousTimeStep ? rho(momentumCouplingContext_[0].prevElemVolVars)
-                                        : rho(momentumCouplingContext_[0].curElemVolVars);
+        return considerPreviousTimeStep ? rho(momentumCouplingContext_()[0].prevElemVolVars)
+                                        : rho(momentumCouplingContext_()[0].curElemVolVars);
     }
 
     auto insideAndOutsideDensity(const Element<freeFlowMomentumIndex>& element,
@@ -269,7 +274,7 @@ public:
         assert(!(considerPreviousTimeStep && !isTransient_));
         bindCouplingContext_(Dune::index_constant<freeFlowMomentumIndex>(), element, fvGeometry.elementIndex());
         const auto& insideMomentumScv = fvGeometry.scv(scvf.insideScvIdx());
-        const auto& insideMassScv = momentumCouplingContext_[0].fvGeometry.scv(insideMomentumScv.elementIndex());
+        const auto& insideMassScv = momentumCouplingContext_()[0].fvGeometry.scv(insideMomentumScv.elementIndex());
 
         const auto result = [&](const auto& elemVolVars)
         {
@@ -278,13 +283,13 @@ public:
             else
             {
                 const auto& outsideMomentumScv = fvGeometry.scv(scvf.outsideScvIdx());
-                const auto& outsideMassScv = momentumCouplingContext_[0].fvGeometry.scv(outsideMomentumScv.elementIndex());
+                const auto& outsideMassScv = momentumCouplingContext_()[0].fvGeometry.scv(outsideMomentumScv.elementIndex());
                 return std::make_pair(elemVolVars[insideMassScv].density(), elemVolVars[outsideMassScv].density());
             }
         };
 
-        return considerPreviousTimeStep ? result(momentumCouplingContext_[0].prevElemVolVars)
-                                        : result(momentumCouplingContext_[0].curElemVolVars);
+        return considerPreviousTimeStep ? result(momentumCouplingContext_()[0].prevElemVolVars)
+                                        : result(momentumCouplingContext_()[0].curElemVolVars);
     }
 
     /*!
@@ -296,10 +301,10 @@ public:
     {
         assert(!(considerPreviousTimeStep && !isTransient_));
         bindCouplingContext_(Dune::index_constant<freeFlowMomentumIndex>(), element, scv.elementIndex());
-        const auto& massScv = (*scvs(momentumCouplingContext_[0].fvGeometry).begin());
+        const auto& massScv = (*scvs(momentumCouplingContext_()[0].fvGeometry).begin());
 
-        return considerPreviousTimeStep ? momentumCouplingContext_[0].prevElemVolVars[massScv].density()
-                                        : momentumCouplingContext_[0].curElemVolVars[massScv].density();
+        return considerPreviousTimeStep ? momentumCouplingContext_()[0].prevElemVolVars[massScv].density()
+                                        : momentumCouplingContext_()[0].curElemVolVars[massScv].density();
     }
 
     /*!
@@ -312,13 +317,13 @@ public:
         bindCouplingContext_(Dune::index_constant<freeFlowMomentumIndex>(), element, fvGeometry.elementIndex());
 
         const auto& insideMomentumScv = fvGeometry.scv(scvf.insideScvIdx());
-        const auto& insideMassScv = momentumCouplingContext_[0].fvGeometry.scv(insideMomentumScv.elementIndex());
+        const auto& insideMassScv = momentumCouplingContext_()[0].fvGeometry.scv(insideMomentumScv.elementIndex());
 
         if (scvf.boundary())
-            return momentumCouplingContext_[0].curElemVolVars[insideMassScv].viscosity();
+            return momentumCouplingContext_()[0].curElemVolVars[insideMassScv].viscosity();
 
         const auto& outsideMomentumScv = fvGeometry.scv(scvf.outsideScvIdx());
-        const auto& outsideMassScv = momentumCouplingContext_[0].fvGeometry.scv(outsideMomentumScv.elementIndex());
+        const auto& outsideMassScv = momentumCouplingContext_()[0].fvGeometry.scv(outsideMomentumScv.elementIndex());
 
         const auto mu = [&](const auto& elemVolVars)
         {
@@ -326,7 +331,7 @@ public:
             return 0.5*(elemVolVars[insideMassScv].viscosity() + elemVolVars[outsideMassScv].viscosity());
         };
 
-        return mu(momentumCouplingContext_[0].curElemVolVars);
+        return mu(momentumCouplingContext_()[0].curElemVolVars);
     }
 
      /*!
@@ -339,8 +344,8 @@ public:
         bindCouplingContext_(Dune::index_constant<freeFlowMassIndex>(), element, scvf.insideScvIdx()/*eIdx*/);
 
         // the TPFA scvf index corresponds to the staggered scv index (might need mapping)
-        const auto localMomentumScvIdx = massScvfToMomentumScvIdx_(scvf, massAndEnergyCouplingContext_[0].fvGeometry);
-        const auto& scvJ = massAndEnergyCouplingContext_[0].fvGeometry.scv(localMomentumScvIdx);
+        const auto localMomentumScvIdx = massScvfToMomentumScvIdx_(scvf, massAndEnergyCouplingContext_()[0].fvGeometry);
+        const auto& scvJ = massAndEnergyCouplingContext_()[0].fvGeometry.scv(localMomentumScvIdx);
 
         // create a unit normal vector oriented in positive coordinate direction
         typename SubControlVolumeFace<freeFlowMassIndex>::GlobalPosition velocity;
@@ -446,19 +451,58 @@ public:
             const auto& problem = this->problem(domainJ);
             const auto& deflectedElement = problem.gridGeometry().element(dofIdxGlobalJ);
             const auto elemSol = elementSolution(deflectedElement, this->curSol(domainJ), problem.gridGeometry());
-            const auto& fvGeometry = momentumCouplingContext_[0].fvGeometry;
+            const auto& fvGeometry = momentumCouplingContext_()[0].fvGeometry;
             const auto scvIdxJ = dofIdxGlobalJ;
             const auto& scv = fvGeometry.scv(scvIdxJ);
 
             if constexpr (ElementVolumeVariables<freeFlowMassIndex>::GridVolumeVariables::cachingEnabled)
                 gridVars_(freeFlowMassIndex).curGridVolVars().volVars(scv).update(std::move(elemSol), problem, deflectedElement, scv);
             else
-                momentumCouplingContext_[0].curElemVolVars[scv].update(std::move(elemSol), problem, deflectedElement, scv);
+                momentumCouplingContext_()[0].curElemVolVars[scv].update(std::move(elemSol), problem, deflectedElement, scv);
         }
     }
 
     // \}
 
+    /*!
+     * \brief Compute colors for multithreaded assembly
+     *
+     * \param domainI the domain index of domain i
+     * \param assembleElement kernel function to execute for one element
+     */
+    void computeColorsForAssembly()
+    {
+        // use coloring of the fc staggered discretization for both domains
+        elementSets_ = computeColoring(this->problem(freeFlowMomentumIndex).gridGeometry()).sets;
+    }
+
+    /*!
+     * \brief Execute assembly kernel in parallel
+     *
+     * \param domainI the domain index of domain i
+     * \param assembleElement kernel function to execute for one element
+     */
+    template<std::size_t i, class AssembleElementFunc>
+    void assembleMultithreaded(Dune::index_constant<i> domainId, AssembleElementFunc&& assembleElement) const
+    {
+        if (elementSets_.empty())
+            DUNE_THROW(Dune::InvalidStateException, "Call computeColorsForAssembly before assembling in parallel!");
+
+        // make this element loop run in parallel
+        // for this we have to color the elements so that we don't get
+        // race conditions when writing into the global matrix
+        // each color can be assembled using multiple threads
+        const auto& grid = this->problem(freeFlowMomentumIndex).gridGeometry().gridView().grid();
+        for (const auto& elements : elementSets_)
+        {
+            Dumux::parallelFor(elements.size(), [&](const std::size_t eIdx)
+            {
+                const auto element = grid.entity(elements[eIdx]);
+                assembleElement(element);
+            });
+        }
+    }
+
 private:
     void bindCouplingContext_(Dune::index_constant<freeFlowMomentumIndex> domainI,
                               const Element<freeFlowMomentumIndex>& elementI) const
@@ -471,7 +515,7 @@ private:
                               const Element<freeFlowMomentumIndex>& elementI,
                               const std::size_t eIdx) const
     {
-        if (momentumCouplingContext_.empty())
+        if (momentumCouplingContext_().empty())
         {
             auto fvGeometry = localView(this->problem(freeFlowMassIndex).gridGeometry());
             fvGeometry.bind(elementI);
@@ -485,16 +529,16 @@ private:
             if (isTransient_)
                 prevElemVolVars.bindElement(elementI, fvGeometry, (*prevSol_)[freeFlowMassIndex]);
 
-            momentumCouplingContext_.emplace_back(MomentumCouplingContext{std::move(fvGeometry), std::move(curElemVolVars), std::move(prevElemVolVars), eIdx});
+            momentumCouplingContext_().emplace_back(MomentumCouplingContext{std::move(fvGeometry), std::move(curElemVolVars), std::move(prevElemVolVars), eIdx});
         }
-        else if (eIdx != momentumCouplingContext_[0].eIdx)
+        else if (eIdx != momentumCouplingContext_()[0].eIdx)
         {
-            momentumCouplingContext_[0].eIdx = eIdx;
-            momentumCouplingContext_[0].fvGeometry.bind(elementI);
-            momentumCouplingContext_[0].curElemVolVars.bind(elementI, momentumCouplingContext_[0].fvGeometry, this->curSol(freeFlowMassIndex));
+            momentumCouplingContext_()[0].eIdx = eIdx;
+            momentumCouplingContext_()[0].fvGeometry.bind(elementI);
+            momentumCouplingContext_()[0].curElemVolVars.bind(elementI, momentumCouplingContext_()[0].fvGeometry, this->curSol(freeFlowMassIndex));
 
             if (isTransient_)
-                momentumCouplingContext_[0].prevElemVolVars.bindElement(elementI, momentumCouplingContext_[0].fvGeometry, (*prevSol_)[freeFlowMassIndex]);
+                momentumCouplingContext_()[0].prevElemVolVars.bindElement(elementI, momentumCouplingContext_()[0].fvGeometry, (*prevSol_)[freeFlowMassIndex]);
         }
     }
 
@@ -509,17 +553,17 @@ private:
                               const Element<freeFlowMassIndex>& elementI,
                               const std::size_t eIdx) const
     {
-        if (massAndEnergyCouplingContext_.empty())
+        if (massAndEnergyCouplingContext_().empty())
         {
             const auto& gridGeometry = this->problem(freeFlowMomentumIndex).gridGeometry();
             auto fvGeometry = localView(gridGeometry);
             fvGeometry.bindElement(elementI);
-            massAndEnergyCouplingContext_.emplace_back(std::move(fvGeometry), eIdx);
+            massAndEnergyCouplingContext_().emplace_back(std::move(fvGeometry), eIdx);
         }
-        else if (eIdx != massAndEnergyCouplingContext_[0].eIdx)
+        else if (eIdx != massAndEnergyCouplingContext_()[0].eIdx)
         {
-            massAndEnergyCouplingContext_[0].eIdx = eIdx;
-            massAndEnergyCouplingContext_[0].fvGeometry.bindElement(elementI);
+            massAndEnergyCouplingContext_()[0].eIdx = eIdx;
+            massAndEnergyCouplingContext_()[0].fvGeometry.bindElement(elementI);
         }
     }
 
@@ -613,16 +657,36 @@ private:
     std::vector<CouplingStencilType> momentumToMassAndEnergyStencils_;
     std::vector<CouplingStencilType> massAndEnergyToMomentumStencils_;
 
-    mutable std::vector<MomentumCouplingContext> momentumCouplingContext_;
-    mutable std::vector<MassAndEnergyCouplingContext> massAndEnergyCouplingContext_;
+    // the coupling context exists for each thread
+    // TODO this is a bad pattern, just like mutable caches
+    // we should really construct and pass the context and not store it globally
+    std::vector<MomentumCouplingContext>& momentumCouplingContext_() const
+    {
+        thread_local static std::vector<MomentumCouplingContext> c;
+        return c;
+    }
+
+    // the coupling context exists for each thread
+    std::vector<MassAndEnergyCouplingContext>& massAndEnergyCouplingContext_() const
+    {
+        thread_local static std::vector<MassAndEnergyCouplingContext> c;
+        return c;
+    }
 
     //! A tuple of std::shared_ptrs to the grid variables of the sub problems
     GridVariablesTuple gridVariables_;
 
     const SolutionVector* prevSol_;
     bool isTransient_;
+
+    std::deque<std::vector<ElementSeed<freeFlowMomentumIndex>>> elementSets_;
 };
 
+//! we support multithreaded assembly
+template<class T>
+struct CouplingManagerSupportsMultithreadedAssembly<StaggeredFreeFlowCouplingManager<T>>
+: public std::true_type {};
+
 } // end namespace Dumux
 
 #endif
-- 
GitLab