diff --git a/dumux/discretization/staggered/staggeredgeometryhelper.hh b/dumux/discretization/staggered/staggeredgeometryhelper.hh
index c74cc94f0fc8292ec5cc40be991ccb86c27b31bf..21f4623493c9b9b66a157ac62dee92f3e96f4d61 100644
--- a/dumux/discretization/staggered/staggeredgeometryhelper.hh
+++ b/dumux/discretization/staggered/staggeredgeometryhelper.hh
@@ -43,11 +43,13 @@ struct PairData
     Scalar normalDistance;
 };
 
-
 //! A class to create sub control volume and sub control volume face geometries per element
-template <class GridView>
-// class StaggeredGeometryHelper<GridView, dim>
+template <class GridView, int dim = GridView::dimension >
 class StaggeredGeometryHelper
+{};
+
+template<class GridView>
+class BaseStaggeredGeometryHelper
 {
     using Scalar = typename GridView::ctype;
     static constexpr int dim = GridView::dimension;
@@ -66,20 +68,22 @@ class StaggeredGeometryHelper
 
     using ReferenceElements = typename Dune::ReferenceElements<Scalar, dim>;
 
-
     //TODO include assert that checks for quad geometry
     static constexpr int codimCommonEntity = 2; //TODO: 3d?
+    static constexpr int numFacetSubEntities = 2; // TODO: 3d?
 
+    using Implementation = typename Dumux::StaggeredGeometryHelper<GridView, dim>;
 
 public:
-
-
-    StaggeredGeometryHelper(const Intersection& intersection, const GridView& gridView)
+    BaseStaggeredGeometryHelper(const Intersection& intersection, const GridView& gridView)
     : intersection_(intersection), element_(intersection.inside()), elementGeometry_(element_.geometry()), gridView_(gridView), offset_(gridView.size(0))
     {
         fillPairData();
     }
 
+     /*!
+     * \brief Returns the global dofIdx of the intersection itself
+     */
     int dofIdxSelf() const
     {
         //TODO: use proper intersection mapper!
@@ -87,20 +91,27 @@ public:
         return gridView_.indexSet().subIndex(intersection_.inside(), inIdx, dim-1) + offset_;
     }
 
+     /*!
+     * \brief Returns the global dofIdx of the opposing intersection
+     */
     int dofIdxOpposite() const
     {
         //TODO: use proper intersection mapper!
         const auto inIdx = intersection_.indexInInside();
-        return gridView_.indexSet().subIndex(intersection_.inside(), localOppositeIdx_(inIdx), dim-1) + offset_;
+        return gridView_.indexSet().subIndex(this->intersection_.inside(), localOppositeIdx_(inIdx), dim-1) + this->offset_;
     }
 
-
+     /*!
+     * \brief Returns a copy of the pair data
+     */
     auto pairData() const
     {
         return pairData_;
     }
 
-
+     /*!
+     * \brief Fills all entries of the pair data
+     */
     void fillPairData()
     {
         const auto& referenceElement = ReferenceElements::general(element_.geometry().type());
@@ -114,10 +125,9 @@ public:
             data.parallelDistance = 0;
         }
 
-
         // set the inner parts of the normal pairs
-        const auto localInnerNormalDofIndices = getLocalInnerNormalDofIndices_(indexInInside);
-        setInnerNormalPairs_(localInnerNormalDofIndices);
+        const auto localInnerNormalDofIndices = asImp_().getLocalInnerNormalDofIndices_(indexInInside);
+        asImp_().setInnerNormalPairs_(localInnerNormalDofIndices);
 
         // get the positions of the faces normal to the intersection within the element
         std::vector<GlobalPosition> innerNormalFacePos;
@@ -134,7 +144,6 @@ public:
             DUNE_THROW(Dune::NotImplemented, "3d not ready yet");
         }
 
-
         // go into the direct neighbor element
         if(intersection_.neighbor())
         {
@@ -147,8 +156,8 @@ public:
                 // skip the directly neighboring face itself and its opposing one
                 if(neighborIntersectionNormalSide_(neighborIsIdx, intersection_.indexInOutside()))
                 {
-                    // iterate over facets sub-entities // TODO: get number correctly
-                    for(int i = 0; i < 2; ++i)
+                    // iterate over facets sub-entities
+                    for(int i = 0; i < numFacetSubEntities; ++i)
                     {
                         int localCommonEntIdx = referenceElement.subEntity(neighborIsIdx, 1, i, dim);
                         int globalCommonEntIdx = localToGlobalEntityIdx_(localCommonEntIdx, directNeighbor);
@@ -166,8 +175,6 @@ public:
                         }
                     }
 
-
-
                     // go into the adjacent neighbor element
                     if(neighborIntersection.neighbor())
                     {
@@ -176,7 +183,7 @@ public:
                         {
                             if(neighborIntersectionNormalSide_(dIs.indexInInside(), neighborIntersection.indexInOutside()))
                             {
-                                for(int i = 0; i < 2; ++i)
+                                for(int i = 0; i < numFacetSubEntities; ++i)
                                 {
                                     int localCommonEntIdx = referenceElement.subEntity(dIs.indexInInside(), 1, i, dim);
                                     int globalCommonEntIdx = localToGlobalEntityIdx_(localCommonEntIdx, diagonalNeighbor);
@@ -202,28 +209,89 @@ public:
     }
 
 private:
-
+     /*!
+     * \brief Returns the local opposing intersection index
+     *
+     * \param idx The local index of the intersection itself
+     */
     int localOppositeIdx_(const int idx) const
     {
         return (idx % 2) ? (idx - 1) : (idx + 1);
     }
 
-    bool neighborIntersectionNormalSide_(const int isIdx, const int neighborIsIdx) const
+     /*!
+     * \brief Returns true if the intersection lies normal to another given intersection
+     *
+     * \param selfIdx The local index of the intersection itself
+     * \param otherIdx The local index of the other intersection
+     */
+    bool neighborIntersectionNormalSide_(const int selfIdx, const int otherIdx) const
     {
-        return !(isIdx == neighborIsIdx || localOppositeIdx_(isIdx) == neighborIsIdx);
+        return !(selfIdx == otherIdx || localOppositeIdx_(selfIdx) == otherIdx);
     };
 
-    int localToGlobalEntityIdx_(const int localIdx, const Element& element)
+     /*!
+     * \brief Returns the global index of the common entity
+     *
+     * \param localIdx The local index of the common entity
+     * \param element The element
+     */
+    int localToGlobalEntityIdx_(const int localIdx, const Element& element) const
     {
-        return gridView_.indexSet().subIndex(element, localIdx, codimCommonEntity);
+        return this->gridView_.indexSet().subIndex(element, localIdx, codimCommonEntity);
     };
 
+protected:
+    const Intersection& intersection_; //! The intersection of interest
+    const Element& element_; //! The respective element
+    const typename Element::Geometry& elementGeometry_; //! Reference to the element geometry
+    const GridView gridView_;
+    const int offset_; //! Offset for intersection dof indexing
+    std::array<PairData<Scalar>, numPairs> pairData_; //! collection of pair information
+
+    //! Returns the implementation of the problem (i.e. static polymorphism)
+    Implementation &asImp_()
+    { return *static_cast<Implementation *>(this); }
 
+    //! \copydoc asImp_()
+    const Implementation &asImp_() const
+    { return *static_cast<const Implementation *>(this); }
+};
 
 
-    // specializations for 2D ***************************************************************************************************
-    template<class G = GridView, typename std::enable_if<G::dimension == 2, int>::type = 0>
-    auto getLocalInnerNormalDofIndices_(const int directNeighborIsIdx)
+template<class GridView>
+class StaggeredGeometryHelper<GridView, 2> : public BaseStaggeredGeometryHelper<GridView>
+{
+    friend class BaseStaggeredGeometryHelper<GridView>;
+    using Scalar = typename GridView::ctype;
+    static constexpr int dim = GridView::dimension;
+    static constexpr int dimWorld = GridView::dimensionworld;
+
+    static constexpr int numPairs = (dimWorld == 2) ? 2 : 4;
+
+    using ScvGeometry = Dune::CachedMultiLinearGeometry<Scalar, dim, dimWorld>;
+    using ScvfGeometry = Dune::CachedMultiLinearGeometry<Scalar, dim-1, dimWorld>;
+
+    using GlobalPosition = typename ScvGeometry::GlobalCoordinate;
+    using PointVector = std::vector<GlobalPosition>;
+
+    using Element = typename GridView::template Codim<0>::Entity;
+    using Intersection = typename GridView::Intersection;
+
+    using ReferenceElements = typename Dune::ReferenceElements<Scalar, dim>;
+
+    //TODO include assert that checks for quad geometry
+    static constexpr int codimCommonEntity = 2; //TODO: 3d?
+
+    using ParentType = BaseStaggeredGeometryHelper<GridView>;
+
+public:
+    StaggeredGeometryHelper(const Intersection& intersection, const GridView& gridView)
+    : ParentType(intersection, gridView)
+    {}
+
+private:
+    static auto getLocalInnerNormalDofIndices_(const int directNeighborIsIdx)
     {
         struct Indices
         {
@@ -267,60 +335,60 @@ private:
         return indices;
     }
 
-    template<class Indices, class G = GridView>
-    typename std::enable_if<G::dimension == 2, void>::type
-    setInnerNormalPairs_(const Indices& indices)
+    template<class Indices>
+    void setInnerNormalPairs_(const Indices& indices)
     {
-        pairData_[0].normalPair.first = gridView_.indexSet().subIndex(intersection_.inside(), indices.normalLocalDofIdx1, dim-1) + offset_;
-        pairData_[1].normalPair.first = gridView_.indexSet().subIndex(intersection_.inside(), indices.normalLocalDofIdx2, dim-1) + offset_;
-        pairData_[0].globalCommonEntIdx = gridView_.indexSet().subIndex(intersection_.inside(), indices.localCommonEntIdx1, codimCommonEntity);
-        pairData_[1].globalCommonEntIdx = gridView_.indexSet().subIndex(intersection_.inside(), indices.localCommonEntIdx2, codimCommonEntity);
+        this->pairData_[0].normalPair.first = this->gridView_.indexSet().subIndex(this->intersection_.inside(), indices.normalLocalDofIdx1, dim-1) + this->offset_;
+        this->pairData_[1].normalPair.first = this->gridView_.indexSet().subIndex(this->intersection_.inside(), indices.normalLocalDofIdx2, dim-1) + this->offset_;
+        this->pairData_[0].globalCommonEntIdx = this->gridView_.indexSet().subIndex(this->intersection_.inside(), indices.localCommonEntIdx1, codimCommonEntity);
+        this->pairData_[1].globalCommonEntIdx = this->gridView_.indexSet().subIndex(this->intersection_.inside(), indices.localCommonEntIdx2, codimCommonEntity);
     }
 
+};
 
-    template<class G = GridView, typename std::enable_if<G::dimension == 3, int>::type = 0>
-    auto getLocalInnerNormalDofIndices_(const int directNeighborIsIdx)
-    {
-        struct Indices
-        {
-            int normalLocalDofIdx1;
-            int normalLocalDofIdx2;
-            int normalLocalDofIdx3;
-            int normalLocalDofIdx4;
-            int localCommonEntIdx1;
-            int localCommonEntIdx2;
-            int localCommonEntIdx3;
-            int localCommonEntIdx4;
-        };
+template<class GridView>
+class StaggeredGeometryHelper<GridView, 3> : public BaseStaggeredGeometryHelper<GridView>
+{
+    friend class BaseStaggeredGeometryHelper<GridView>;
+    using Scalar = typename GridView::ctype;
+    static constexpr int dim = GridView::dimension;
+    static constexpr int dimWorld = GridView::dimensionworld;
 
-        Indices indices;
+    using ScvGeometry = Dune::CachedMultiLinearGeometry<Scalar, dim, dimWorld>;
+    using ScvfGeometry = Dune::CachedMultiLinearGeometry<Scalar, dim-1, dimWorld>;
 
-        switch(directNeighborIsIdx)
-        {
-            default:
-                DUNE_THROW(Dune::NotImplemented, "3d helper not ready yet");
-        }
-        return indices;
-    }
+    using GlobalPosition = typename ScvGeometry::GlobalCoordinate;
+    using PointVector = std::vector<GlobalPosition>;
 
-    template<class Indices, class G = GridView>
-    typename std::enable_if<G::dimension == 3, void>::type
-    setInnerNormalPairs_(const Indices& indices)
-    {
-        // TODO: 3D
-        DUNE_THROW(Dune::NotImplemented, "3d helper not ready yet");
-    }
+    using Element = typename GridView::template Codim<0>::Entity;
+    using Intersection = typename GridView::Intersection;
 
+    using ReferenceElements = typename Dune::ReferenceElements<Scalar, dim>;
 
-    const Intersection& intersection_;
-    const Element& element_;
-    const typename Element::Geometry& elementGeometry_; //! Reference to the element geometry
-    const GridView gridView_;
-    const int offset_;
 
-    std::array<PairData<Scalar>, numPairs> pairData_;
+    //TODO include assert that checks for quad geometry
+    static constexpr int codimCommonEntity = 2; //TODO: 3d?
+
+    using ParentType = BaseStaggeredGeometryHelper<GridView>;
 
+public:
+    StaggeredGeometryHelper(const Intersection& intersection, const GridView& gridView)
+    : ParentType(intersection, gridView)
+    {}
 
+private:
+    auto getLocalInnerNormalDofIndices_(const int directNeighborIsIdx)
+    {
+        // TODO: 3D
+        DUNE_THROW(Dune::NotImplemented, "3d helper not ready yet");
+    }
+
+    template<class Indices>
+    void setInnerNormalPairs_(const Indices& indices)
+    {
+        // TODO: 3D
+        DUNE_THROW(Dune::NotImplemented, "3d helper not ready yet");
+    }
 };