From 5d5d7e6838c7b2d08c6b14a7959cec9f7687ffaa Mon Sep 17 00:00:00 2001
From: Timo Koch <timo.koch@iws.uni-stuttgart.de>
Date: Sat, 1 Jun 2019 17:16:04 +0200
Subject: [PATCH] [integrate] Add support for vector fields

---
 dumux/common/integrate.hh | 95 ++++++++++++++++++++++++++++++++-------
 1 file changed, 78 insertions(+), 17 deletions(-)

diff --git a/dumux/common/integrate.hh b/dumux/common/integrate.hh
index 3a68c6d83c..3d85da7db6 100644
--- a/dumux/common/integrate.hh
+++ b/dumux/common/integrate.hh
@@ -36,8 +36,12 @@
 
 #include <dumux/discretization/evalsolution.hh>
 #include <dumux/discretization/elementsolution.hh>
+#include <dumux/common/typetraits/typetraits.hh>
 
 namespace Dumux {
+
+// implementation details of the integrate functions
+#ifndef DOXYGEN
 namespace Impl {
 
 struct HasLocalFunction
@@ -53,7 +57,59 @@ template<class F>
 static constexpr bool hasLocalFunction()
 { return Dune::models<HasLocalFunction, F>(); }
 
+template<class Error,
+         typename std::enable_if_t<IsIndexable<Error>::value, int> = 0>
+auto localErrorSqrImpl(const Error& error)
+{
+    Error sqr = error; sqr = 0.0;
+    for (int i = 0; i < sqr.size(); ++i)
+        sqr[i] += error[i]*error[i];
+    return sqr;
+}
+
+template<class Error,
+         typename std::enable_if_t<!IsIndexable<Error>::value, int> = 0>
+auto localErrorSqrImpl(const Error& error)
+{
+    return error*error;
+}
+
+template<class Error,
+         typename std::enable_if_t<IsIndexable<Error>::value, int> = 0>
+Error sqrtNorm(const Error& error)
+{
+    using std::sqrt;
+    auto e = error;
+    for (int i = 0; i < error.size(); ++i)
+        e[i] = sqrt(error[i]);
+    return e;
+}
+
+template<class Error,
+         typename std::enable_if_t<!IsIndexable<Error>::value, int> = 0>
+Error sqrtNorm(const Error& error)
+{
+    using std::sqrt;
+    return sqrt(error);
+}
+
+template<class T, typename = int>
+struct FieldTypeImpl
+{
+    using type = T;
+};
+
+template<class T>
+struct FieldTypeImpl<T, typename std::enable_if<(sizeof(std::declval<T>()[0]) > 0), int>::type>
+{
+    using type = typename FieldTypeImpl<std::decay_t<decltype(std::declval<T>()[0])>>::type;
+};
+
+template<class T>
+using FieldType = typename FieldTypeImpl<T>::type;
+
 } // end namespace Impl
+#endif
 
 /*!
  * \brief Integrate a grid function over a grid view
@@ -67,10 +123,11 @@ auto integrateGridFunction(const GridGeometry& gg,
                            SolutionVector&& sol,
                            std::size_t order)
 {
-    using Scalar = std::decay_t<decltype(sol[0][0])>;
+    using PrimaryVariables = std::decay_t<decltype(sol[0])>;
     using GridView = typename GridGeometry::GridView;
+    using Scalar = typename Impl::FieldType<PrimaryVariables>;
 
-    Scalar integral = 0.0;
+    PrimaryVariables integral = 0.0;
     for (const auto& element : elements(gg.gridView()))
     {
         const auto elemSol = elementSolution(element, sol, gg);
@@ -101,10 +158,11 @@ auto integrateL2Error(const GridGeometry& gg,
                       const Sol2& sol2,
                       std::size_t order)
 {
-    using Scalar = std::decay_t<decltype(sol1[0][0])>;
+    using PrimaryVariables = std::decay_t<decltype(sol1[0])>;
     using GridView = typename GridGeometry::GridView;
+    using Scalar = typename Impl::FieldType<PrimaryVariables>;
 
-    Scalar l2norm = 0.0;
+    PrimaryVariables l2norm = 0.0;
     for (const auto& element : elements(gg.gridView()))
     {
         const auto elemSol1 = elementSolution(element, sol1, gg);
@@ -118,11 +176,13 @@ auto integrateL2Error(const GridGeometry& gg,
             const auto value1 = evalSolution(element, geometry, gg, elemSol1, globalPos);
             const auto value2 = evalSolution(element, geometry, gg, elemSol2, globalPos);
             const auto error = (value1 - value2);
-            l2norm += error*error*qp.weight()*geometry.integrationElement(qp.position());
+            auto value = Impl::localErrorSqrImpl(error);
+            value *= qp.weight()*geometry.integrationElement(qp.position());
+            l2norm += value;
         }
     }
-    using std::sqrt;
-    return sqrt(l2norm);
+
+    return Impl::sqrtNorm(l2norm);
 }
 
 #if HAVE_DUNE_FUNCTIONS
@@ -144,9 +204,10 @@ auto integrateGridFunction(const GridView& gv,
 
     using Element = typename GridView::template Codim<0>::Entity;
     using LocalPosition = typename Element::Geometry::LocalCoordinate;
-    using Scalar = std::decay_t<decltype(fLocal(std::declval<LocalPosition>()))>;
+    using PrimaryVariables = std::decay_t<decltype(fLocal(std::declval<LocalPosition>()))>;
+    using Scalar = typename Impl::FieldType<PrimaryVariables>;
 
-    Scalar integral = 0.0;
+    PrimaryVariables integral = 0.0;
     for (const auto& element : elements(gv))
     {
         fLocal.bind(element);
@@ -186,12 +247,10 @@ auto integrateL2Error(const GridView& gv,
 
     using Element = typename GridView::template Codim<0>::Entity;
     using LocalPosition = typename Element::Geometry::LocalCoordinate;
-    using FScalar = std::decay_t<decltype(fLocal(std::declval<LocalPosition>()))>;
-    using GScalar = std::decay_t<decltype(gLocal(std::declval<LocalPosition>()))>;
+    using PrimaryVariables = std::decay_t<decltype(fLocal(std::declval<LocalPosition>()))>;
+    using Scalar = typename Impl::FieldType<PrimaryVariables>;
 
-    using Scalar = std::decay_t<decltype((std::declval<FScalar>()-std::declval<GScalar>())
-                                        *(std::declval<FScalar>()-std::declval<GScalar>()))>;
-    Scalar l2norm = 0.0;
+    PrimaryVariables l2norm = 0.0;
     for (const auto& element : elements(gv))
     {
         fLocal.bind(element);
@@ -202,14 +261,16 @@ auto integrateL2Error(const GridView& gv,
         for (auto&& qp : quad)
         {
             const auto error = fLocal(qp.position()) - gLocal(qp.position());
-            l2norm += error*error*qp.weight()*geometry.integrationElement(qp.position());
+            auto value = Impl::localErrorSqrImpl(error);
+            value *= qp.weight()*geometry.integrationElement(qp.position());
+            l2norm += value;
         }
 
         gLocal.unbind();
         fLocal.unbind();
     }
-    using std::sqrt;
-    return sqrt(l2norm);
+
+    return Impl::sqrtNorm(l2norm);
 }
 #endif
 
-- 
GitLab