diff --git a/dumux/freeflow/rans/problem.hh b/dumux/freeflow/rans/problem.hh
index 50e3c7f82cb53c0cc6cff99753497adca0648b41..47164dcf96509257c989ef7d679776dce1431871 100644
--- a/dumux/freeflow/rans/problem.hh
+++ b/dumux/freeflow/rans/problem.hh
@@ -154,41 +154,29 @@ public:
         for (const auto& element : elements(gridView))
         {
             unsigned int elementID = this->fvGridGeometry().elementMapper().index(element);
-            std::array<std::array<Scalar, 2>, dim> distances;
             for (unsigned int dimIdx = 0; dimIdx < dim; ++dimIdx)
             {
                 neighborIDs_[elementID][dimIdx][0] = elementID;
                 neighborIDs_[elementID][dimIdx][1] = elementID;
-                distances[dimIdx][0] = std::numeric_limits<Scalar>::max();
-                distances[dimIdx][1] = std::numeric_limits<Scalar>::max();
             }
-            for (const auto& neighbor : elements(gridView))
+
+            for (const auto& intersection : intersections(gridView, element))
             {
-                unsigned int neighborID = this->fvGridGeometry().elementMapper().index(neighbor);
-                if (elementID == neighborID)
+                if (intersection.boundary())
                     continue;
 
+                unsigned int neighborID = this->fvGridGeometry().elementMapper().index(intersection.outside());
                 for (unsigned int dimIdx = 0; dimIdx < dim; ++dimIdx)
                 {
-                    GlobalPosition globalTemp = cellCenters_[elementID];
-                    globalTemp -= cellCenters_[neighborID];
-                    Scalar distanceReal = globalTemp.two_norm();
-                    Scalar distanceAxis = abs(cellCenters_[elementID][dimIdx] - cellCenters_[neighborID][dimIdx]);
-
-                    // only use element which are aligned to the one of interest
-                    if (abs(distanceReal - distanceAxis) < 1e-8)
+                    if (abs(cellCenters_[elementID][dimIdx] - cellCenters_[neighborID][dimIdx]) > 1e-8)
                     {
-                        if (cellCenters_[elementID][dimIdx] > cellCenters_[neighborID][dimIdx]
-                            && distanceAxis < distances[dimIdx][0])
+                        if (cellCenters_[elementID][dimIdx] > cellCenters_[neighborID][dimIdx])
                         {
                             neighborIDs_[elementID][dimIdx][0] = neighborID;
-                            distances[dimIdx][0] = distanceAxis;
                         }
-                        if (cellCenters_[elementID][dimIdx] < cellCenters_[neighborID][dimIdx]
-                            && distanceAxis < distances[dimIdx][1])
+                        if (cellCenters_[elementID][dimIdx] < cellCenters_[neighborID][dimIdx])
                         {
                             neighborIDs_[elementID][dimIdx][1] = neighborID;
-                            distances[dimIdx][1] = distanceAxis;
                         }
                     }
                 }