diff --git a/dumux/common/math.hh b/dumux/common/math.hh index 330a0a9059755c11fc4c02235eabdb3e7b03fac1..a2ebd0f91ae7569a1d9c102829a26ae7231575c1 100644 --- a/dumux/common/math.hh +++ b/dumux/common/math.hh @@ -26,6 +26,7 @@ #include <algorithm> #include <cmath> +#include <utility> #include <dune/common/typetraits.hh> #include <dune/common/fvector.hh> @@ -489,6 +490,78 @@ bool isBetween(const Dune::FieldVector<Scalar, dim> &pos, return false; } +//! forward declaration of the linear interpolation policy (default) +namespace InterpolationPolicy { struct Linear; } + +/*! + * \ingroup Common + * \brief a generic function to interpolate given a set of parameters and an interpolation point + * \param params the parameters used for interpolation (depends on the policy used) + * \param ip the interpolation point + * \param policy the interpolation policy + */ +template <class Policy = InterpolationPolicy::Linear, class Scalar, class Parameter> +Scalar interpolate(Scalar ip, Parameter&& params) +{ return Policy::interpolate(ip, std::forward<Parameter>(params)); } + +/*! + * \ingroup Common + * \brief Interpolation policies + */ +namespace InterpolationPolicy { + +/*! + * \ingroup Common + * \brief interpolate linearly between two given values + */ +struct Linear +{ + /*! + * \brief interpolate linearly between two given values + * \param ip the interpolation point in [0,1] + * \param array with the lower and upper bound + */ + template<class Scalar> + static constexpr Scalar interpolate(Scalar ip, const std::array<Scalar, 2>& params) + { + return params[0]*(1.0 - ip) + params[1]*ip; + } +}; + +/*! + * \ingroup Common + * \brief interpolate linearly in a piecewise linear function (tabularized function) + */ +struct LinearTable +{ + /*! + * \brief interpolate linearly in a piecewise linear function (tabularized function) + * \param ip the interpolation point + * \param table the table as a pair of sorted vectors (have to be same size) + * \note if the interpolation point is out of bounds this will return the bounds + */ + template<class Scalar, class RandomAccessContainer> + static constexpr Scalar interpolate(Scalar ip, const std::pair<RandomAccessContainer, RandomAccessContainer>& table) + { + const auto& range = table.first; + const auto& values = table.second; + + // check bounds + if (ip > range.back()) return values.back(); + if (ip < range[0]) return values[0]; + + // if we are within bounds find the index of the lower bound + const auto lookUpIndex = std::distance(range.begin(), std::lower_bound(range.begin(), range.end(), ip)); + if (lookUpIndex == 0) + return values[0]; + + const auto ipLinear = (ip - range[lookUpIndex-1])/(range[lookUpIndex] - range[lookUpIndex-1]); + return Dumux::interpolate<Linear>(ipLinear, std::array<Scalar, 2>{{values[lookUpIndex-1], values[lookUpIndex]}}); + } +}; + +} // end namespace InterpolationPolicy + /*! * \ingroup Common diff --git a/test/common/math/test_math.cc b/test/common/math/test_math.cc index 7c62d3e0c3a94dcd2545a5ad96065a5ae5db9e15..d723a8f14319cad3c9a86fea87aca1eea3eeacd3 100644 --- a/test/common/math/test_math.cc +++ b/test/common/math/test_math.cc @@ -31,6 +31,7 @@ #include <config.h> #include <iostream> +#include <utility> #include <dune/common/float_cmp.hh> #include <dune/common/fmatrix.hh> @@ -40,6 +41,18 @@ #include <dumux/common/math.hh> +namespace Test { + +template<class Scalar, class Table> +void checkTableInterpolation(Scalar ip, Scalar expected, const Table& table, Scalar eps = 1e-15) +{ + using namespace Dumux; + auto interp = interpolate<InterpolationPolicy::LinearTable>(ip, table); + if (!Dune::FloatCmp::eq(interp, expected, eps)) + DUNE_THROW(Dune::Exception, "Wrong interpolation, expected " << expected << " got " << interp); +} + +} // end namespace Test int main() try { @@ -129,6 +142,20 @@ int main() try static_assert(Dumux::sign(2.0) == 1, "Wrong sign!"); static_assert(Dumux::sign(-2) == -1, "Wrong sign!"); static_assert(Dumux::sign(-3.5) == -1, "Wrong sign!"); + + ////////////////////////////////////////////////////////////////// + ///// Dumux::interpolate ///////////////////////////////////////// + ////////////////////////////////////////////////////////////////// + std::vector<double> a{0.0, 1.0, 2.0}; + std::vector<double> b{-1.0, 1.0, 3.0}; + const auto table = std::make_pair(a, b); + + Test::checkTableInterpolation(-1.0, -1.0, table); + Test::checkTableInterpolation(+0.0, -1.0, table); + Test::checkTableInterpolation(0.001, -0.998, table); + Test::checkTableInterpolation(1.5, 2.0, table); + Test::checkTableInterpolation(2.0, 3.0, table); + Test::checkTableInterpolation(3.0, 3.0, table); } catch (Dune::Exception& e) { std::cerr << e << std::endl;