diff --git a/dumux/common/pdesolver.hh b/dumux/common/pdesolver.hh
index 9b02b124bc3fdb7509ac614e826056867afca522..99f61bcaa65deee6bba5fb40996c9f93f472744c 100644
--- a/dumux/common/pdesolver.hh
+++ b/dumux/common/pdesolver.hh
@@ -25,12 +25,18 @@
 #define DUMUX_COMMON_PDESOLVER_HH
 
 #include <memory>
+#include <utility>
 
 #include <dune/common/hybridutilities.hh>
 
-#include <dumux/common/typetraits/matrix.hh>
 #include <dumux/common/timeloop.hh>
 
+// forward declare
+namespace Dune {
+template <class FirstRow, class ... Args>
+class MultiTypeBlockMatrix;
+}
+
 namespace Dumux {
 
 /*!
@@ -107,16 +113,17 @@ protected:
     /*!
      * \brief Helper function to assure the MultiTypeBlockMatrix's sub-blocks have the correct sizes.
      */
-    template<class M>
-    bool checkSizesOfSubMatrices(const M& A) const
+    template <class FirstRow, class ... Args>
+    bool checkSizesOfSubMatrices(const Dune::MultiTypeBlockMatrix<FirstRow, Args...>& matrix) const
     {
-        static_assert(isMultiTypeBlockMatrix<M>::value, "This function can only be used with MultiTypeBlockMatrix");
-
         bool matrixHasCorrectSize = true;
         using namespace Dune::Hybrid;
-        forEach(A, [&](const auto& row){
+        forEach(std::make_index_sequence<Dune::MultiTypeBlockMatrix<FirstRow, Args...>::N()>(), [&](const auto i)
+        {
+            const auto& row = matrix[i];
             const auto numRowsLeftMostBlock = row[Dune::index_constant<0>{}].N();
-            forEach(row, [&](const auto& subBlock){
+            forEach(row, [&](const auto& subBlock)
+            {
                 if (subBlock.N() != numRowsLeftMostBlock)
                     matrixHasCorrectSize = false;
             });