From 21842d590582cba792ce061fea261adf7c36c47b Mon Sep 17 00:00:00 2001
From: Timo Koch <timo.koch@iws.uni-stuttgart.de>
Date: Fri, 19 Jun 2020 21:51:25 +0200
Subject: [PATCH] [newton] Add some algebra backend helper functions

---
 dumux/nonlinear/newtonsolver.hh | 149 ++++++++++++++++++--------------
 1 file changed, 83 insertions(+), 66 deletions(-)

diff --git a/dumux/nonlinear/newtonsolver.hh b/dumux/nonlinear/newtonsolver.hh
index c0def89e05..077d59a6d7 100644
--- a/dumux/nonlinear/newtonsolver.hh
+++ b/dumux/nonlinear/newtonsolver.hh
@@ -28,12 +28,17 @@
 #include <memory>
 #include <iostream>
 #include <type_traits>
+#include <algorithm>
+#include <numeric>
 
 #include <dune/common/timer.hh>
 #include <dune/common/exceptions.hh>
 #include <dune/common/parallel/mpicommunication.hh>
 #include <dune/common/parallel/mpihelper.hh>
 #include <dune/common/std/type_traits.hh>
+#include <dune/common/indices.hh>
+#include <dune/common/hybridutilities.hh>
+
 #include <dune/istl/bvector.hh>
 #include <dune/istl/multitypeblockvector.hh>
 
@@ -69,6 +74,77 @@ using DetectPVSwitch = typename Assembler::GridVariables::VolumeVariables::Prima
 template<class Assembler>
 using GetPVSwitch = Dune::Std::detected_or<int, DetectPVSwitch, Assembler>;
 
+// helpers to implement max relative shift
+template<class C> using dynamicIndexAccess = decltype(std::declval<C>()[0]);
+template<class C> using staticIndexAccess = decltype(std::declval<C>()[Dune::Indices::_0]);
+template<class C> static constexpr auto hasDynamicIndexAccess = Dune::Std::is_detected<dynamicIndexAccess, C>{};
+template<class C> static constexpr auto hasStaticIndexAccess = Dune::Std::is_detected<staticIndexAccess, C>{};
+
+template<class V, class Scalar, class Reduce, class Transform>
+auto hybridInnerProduct(const V& v1, const V& v2, Scalar init, Reduce&& r, Transform&& t)
+-> std::enable_if_t<hasDynamicIndexAccess<V>(), Scalar>
+{
+    return std::inner_product(v1.begin(), v1.end(), v2.begin(), init, std::forward<Reduce>(r), std::forward<Transform>(t));
+}
+
+template<class V, class Scalar, class Reduce, class Transform>
+auto hybridInnerProduct(const V& v1, const V& v2, Scalar init, Reduce&& r, Transform&& t)
+-> std::enable_if_t<hasStaticIndexAccess<V>() && !hasDynamicIndexAccess<V>(), Scalar>
+{
+    using namespace Dune::Hybrid;
+    forEach(std::make_index_sequence<V::N()>{}, [&](auto i){
+        init = r(init, hybridInnerProduct(v1[Dune::index_constant<i>{}], v2[Dune::index_constant<i>{}], init, std::forward<Reduce>(r), std::forward<Transform>(t)));
+    });
+    return init;
+}
+
+// TODO: Document why this is computed like this
+template<class Scalar, class V>
+auto maxRelativeShift(const V& v1, const V& v2)
+-> std::enable_if_t<Dune::IsNumber<V>::value, Scalar>
+{
+    using std::abs; using std::max;
+    return abs(v1 - v2)/max<Scalar>(1.0, abs(v1 + v2)*0.5);
+}
+
+template<class Scalar, class V>
+auto maxRelativeShift(const V& v1, const V& v2)
+-> std::enable_if_t<!Dune::IsNumber<V>::value, Scalar>
+{
+    return hybridInnerProduct(v1, v2, Scalar(0.0),
+        [](const auto& a, const auto& b){ using std::max; return max(a, b); },
+        [](const auto& a, const auto& b){ return maxRelativeShift<Scalar>(a, b); }
+    );
+}
+
+template<class To, class From>
+void assign(To& to, const From& from)
+{
+    if constexpr (std::is_assignable<To&, From>::value)
+        to = from;
+
+    else if constexpr (hasStaticIndexAccess<To>() && hasStaticIndexAccess<To>() && !hasDynamicIndexAccess<From>() && !hasDynamicIndexAccess<From>())
+    {
+        using namespace Dune::Hybrid;
+        forEach(std::make_index_sequence<To::N()>{}, [&](auto i){
+            assign(to[Dune::index_constant<i>{}], from[Dune::index_constant<i>{}]);
+        });
+    }
+
+    else if constexpr (hasDynamicIndexAccess<From>() && hasDynamicIndexAccess<From>())
+        for (decltype(to.size()) i = 0; i < to.size(); ++i)
+            assign(to[i], from[i]);
+
+    else
+        DUNE_THROW(Dune::Exception, "Values are not assignable to each other!");
+}
+
+template<class T, std::enable_if_t<Dune::IsNumber<std::decay_t<T>>::value, int> = 0>
+constexpr std::size_t blockSize() { return 1; }
+
+template<class T, std::enable_if_t<!Dune::IsNumber<std::decay_t<T>>::value, int> = 0>
+constexpr std::size_t blockSize() { return std::decay_t<T>::size(); }
+
 } // end namespace Detail
 
 /*!
@@ -949,44 +1025,14 @@ private:
     virtual void newtonUpdateShift_(const SolutionVector &uLastIter,
                                     const SolutionVector &deltaU)
     {
-        shift_ = 0;
-        newtonUpdateShiftImpl_(uLastIter, deltaU);
+        auto uNew = uLastIter;
+        uNew -= deltaU;
+        shift_ = Detail::maxRelativeShift<Scalar>(uLastIter, uNew);
 
         if (comm_.size() > 1)
             shift_ = comm_.max(shift_);
     }
 
-    template<class SolVec>
-    void newtonUpdateShiftImpl_(const SolVec &uLastIter,
-                                const SolVec &deltaU)
-    {
-        for (int i = 0; i < int(uLastIter.size()); ++i)
-        {
-            auto uNewI = uLastIter[i];
-            uNewI -= deltaU[i];
-
-            Scalar shiftAtDof = relativeShiftAtDof_(uLastIter[i], uNewI);
-            using std::max;
-            shift_ = max(shift_, shiftAtDof);
-        }
-    }
-
-    template<class ...Args>
-    void newtonUpdateShiftImpl_(const Dune::MultiTypeBlockVector<Args...> &uLastIter,
-                                const Dune::MultiTypeBlockVector<Args...> &deltaU)
-    {
-        // There seems to be a bug in g++5 which which prevents compilation when
-        // passing the call to the implementation directly to Dune::Hybrid::forEach.
-        // We therefore store this call in a lambda and pass it to the for loop afterwards.
-        auto doUpdate = [&](const auto subVectorIdx)
-        {
-            this->newtonUpdateShiftImpl_(uLastIter[subVectorIdx], deltaU[subVectorIdx]);
-        };
-
-        using namespace Dune::Hybrid;
-        forEach(integralRange(Dune::Hybrid::size(uLastIter)), doUpdate);
-    }
-
     virtual void lineSearchUpdate_(SolutionVector &uCurrentIter,
                                    const SolutionVector &uLastIter,
                                    const SolutionVector &deltaU)
@@ -1050,19 +1096,14 @@ private:
         //! to this field vector type in Dune ISTL
         //! Could be avoided for vectors that already have the right type using SFINAE
         //! but it shouldn't impact performance too much
-        constexpr auto blockSize = std::decay_t<decltype(b[0])>::dimension;
+        constexpr auto blockSize = Detail::blockSize<decltype(b[0])>();
         using BlockType = Dune::FieldVector<Scalar, blockSize>;
         Dune::BlockVector<BlockType> xTmp; xTmp.resize(b.size());
         Dune::BlockVector<BlockType> bTmp(xTmp);
-        for (unsigned int i = 0; i < b.size(); ++i)
-            for (unsigned int j = 0; j < blockSize; ++j)
-                bTmp[i][j] = b[i][j];
 
+        Detail::assign(bTmp, b);
         const int converged = ls.solve(A, xTmp, bTmp);
-
-        for (unsigned int i = 0; i < x.size(); ++i)
-            for (unsigned int j = 0; j < blockSize; ++j)
-                x[i][j] = xTmp[i][j];
+        Detail::assign(x, xTmp);
 
         return converged;
     }
@@ -1188,7 +1229,7 @@ private:
             nextPriVars -= uDelta[i];
 
             // add the current relative shift for this degree of freedom
-            auto shift = relativeShiftAtDof_(currentPriVars, nextPriVars);
+            auto shift = Detail::maxRelativeShift<Scalar>(currentPriVars, nextPriVars);
             distanceFromLastLinearization_[i] += shift;
         }
     }
@@ -1213,30 +1254,6 @@ private:
         DUNE_THROW(Dune::NotImplemented, "Reassembly for MultiTypeBlockVector");
     }
 
-    /*!
-     * \brief Returns the maximum relative shift between two vectors of
-     *        primary variables.
-     *
-     * \param priVars1 The first vector of primary variables
-     * \param priVars2 The second vector of primary variables
-     */
-    template<class PrimaryVariables>
-    Scalar relativeShiftAtDof_(const PrimaryVariables &priVars1,
-                               const PrimaryVariables &priVars2) const
-    {
-        Scalar result = 0.0;
-        using std::abs;
-        using std::max;
-        // iterate over all primary variables
-        for (int j = 0; j < PrimaryVariables::dimension; ++j) {
-            Scalar eqErr = abs(priVars1[j] - priVars2[j]);
-            eqErr /= max<Scalar>(1.0,abs(priVars1[j] + priVars2[j])/2);
-
-            result = max(result, eqErr);
-        }
-        return result;
-    }
-
     //! The communication object
     Communication comm_;
 
-- 
GitLab