From 7929f9ceb09cbc030208f48f970e7bd917de4041 Mon Sep 17 00:00:00 2001
From: Timo Koch <timokoch@uio.no>
Date: Sun, 14 Jul 2024 17:30:29 +0200
Subject: [PATCH] [test][newton] Extend parallel Newton test by testing two
 types of recoverable solver exceptions

---
 test/nonlinear/newton/test_newton_parallel.cc | 80 ++++++++++++++++++-
 1 file changed, 76 insertions(+), 4 deletions(-)

diff --git a/test/nonlinear/newton/test_newton_parallel.cc b/test/nonlinear/newton/test_newton_parallel.cc
index a0a1a6067f..e9b6b367b9 100644
--- a/test/nonlinear/newton/test_newton_parallel.cc
+++ b/test/nonlinear/newton/test_newton_parallel.cc
@@ -8,13 +8,18 @@
 #include <cmath>
 #include <cassert>
 #include <iomanip>
+#include <chrono>
+#include <thread>
 
 #include <dune/common/exceptions.hh>
 #include <dune/common/float_cmp.hh>
+#include <dune/common/parallel/mpihelper.hh>
+
 #include <dune/istl/bvector.hh>
 
 #include <dumux/common/timeloop.hh>
 #include <dumux/common/initialize.hh>
+#include <dumux/common/parameters.hh>
 #include <dumux/nonlinear/newtonsolver.hh>
 
 /*
@@ -80,10 +85,24 @@ public:
     template<class Vector>
     bool solve(const double& A, Vector& x, const Vector& b) const
     {
-        if (rank_ == 1 && timeLoop_->timeStepSize() > 0.4)
-            DUNE_THROW(Dune::Exception, "Assembly failed (for testing) on process " << rank_);
+        // solver construction
+        auto solver = constructSolver_(A, b);
+
+        // error handling: make sure the solver was successfully constructed on all processes
+        // and throw on all processes if solver construction failed
+        bool success = static_cast<bool>(solver);
+        int successRemote = success;
+        if (Dune::MPIHelper::instance().size() > 1)
+            successRemote = Dune::MPIHelper::instance().getCommunication().min(success);
+
+        if (!success)
+            DUNE_THROW(Dune::Exception, "Could not create solver");
+        else if (!successRemote)
+            DUNE_THROW(Dune::Exception, "Could not create solver on remote process");
+
+        // solver solve (here we assume that either all processes are successful or all fail)
+        x = solver->solve();
 
-        x = b/A;
         return true;
     }
 
@@ -95,6 +114,54 @@ public:
 private:
     int rank_;
     std::shared_ptr<TimeLoop<double>> timeLoop_;
+
+    template<class Vector>
+    struct Solver
+    {
+        Solver(const double& A, const Vector& b, int rank, const std::shared_ptr<TimeLoop<double>>& timeLoop)
+        : A_(A), b_(b), rank_(rank), timeLoop_(timeLoop)
+        {
+            // constructor might fail and failure might be recoverable
+            // this is what we are testing here
+            if (rank_ == 1 && timeLoop_->timeStepSize() > 0.4 && timeLoop_->timeStepSize() < 0.9)
+            {
+                using namespace std::chrono_literals;
+                std::this_thread::sleep_for(0.3s);
+                DUNE_THROW(Dune::Exception, "This is a recoverable test error during solver constructor");
+            }
+        }
+
+        auto solve() const
+        {
+            auto x = b_/A_;
+
+            // collective communication emulating parallel solver
+            x = Dune::MPIHelper::instance().getCommunication().min(x);
+
+            // solver might not converge, this is recoverable and what we are testing here
+            if (rank_ == 1 && timeLoop_->timeStepSize() > 0.9)
+            {
+                using namespace std::chrono_literals;
+                std::this_thread::sleep_for(0.3s);
+                DUNE_THROW(Dune::Exception, "This is a recoverable test error during solver solve");
+            }
+
+            return x;
+        }
+        const double& A_; const Vector& b_;
+        int rank_; std::shared_ptr<TimeLoop<double>> timeLoop_;
+    };
+
+    template<class Vector>
+    auto constructSolver_(const double& A, const Vector& b) const
+    {
+        try {
+            return std::make_shared<Solver<Vector>>(A, b, rank_, timeLoop_);
+        } catch (const Dune::Exception& e) {
+            std::cerr << "Caught exception on solver construction: " << e.what() << std::endl;
+            return std::decay_t<decltype(std::make_shared<Solver<Vector>>(A, b, rank_, timeLoop_))>();
+        }
+    }
 };
 
 } // end namespace Dumux
@@ -106,6 +173,9 @@ int main(int argc, char* argv[])
     // maybe initialize MPI and/or multithreading backend
     Dumux::initialize(argc, argv);
 
+    // initialize parameters
+    Dumux::Parameters::init(argc, argv);
+
     // use the Newton solver to find a solution to a scalar equation
     using Assembler = MockScalarAssembler;
     using LinearSolver = MockScalarLinearSolver;
@@ -121,7 +191,9 @@ int main(int argc, char* argv[])
     double initialGuess = 0.1;
     double x = initialGuess;
 
-    std::cout << "Solving: x^2 - 5 = 0" << std::endl;
+    if (rank == 0)
+        std::cout << "Solving: x^2 - 5 = 0" << std::endl;
+
     solver->solve(x, *timeLoop);
 
     if (rank == 0)
-- 
GitLab