From 7929f9ceb09cbc030208f48f970e7bd917de4041 Mon Sep 17 00:00:00 2001 From: Timo Koch <timokoch@uio.no> Date: Sun, 14 Jul 2024 17:30:29 +0200 Subject: [PATCH] [test][newton] Extend parallel Newton test by testing two types of recoverable solver exceptions --- test/nonlinear/newton/test_newton_parallel.cc | 80 ++++++++++++++++++- 1 file changed, 76 insertions(+), 4 deletions(-) diff --git a/test/nonlinear/newton/test_newton_parallel.cc b/test/nonlinear/newton/test_newton_parallel.cc index a0a1a6067f..e9b6b367b9 100644 --- a/test/nonlinear/newton/test_newton_parallel.cc +++ b/test/nonlinear/newton/test_newton_parallel.cc @@ -8,13 +8,18 @@ #include <cmath> #include <cassert> #include <iomanip> +#include <chrono> +#include <thread> #include <dune/common/exceptions.hh> #include <dune/common/float_cmp.hh> +#include <dune/common/parallel/mpihelper.hh> + #include <dune/istl/bvector.hh> #include <dumux/common/timeloop.hh> #include <dumux/common/initialize.hh> +#include <dumux/common/parameters.hh> #include <dumux/nonlinear/newtonsolver.hh> /* @@ -80,10 +85,24 @@ public: template<class Vector> bool solve(const double& A, Vector& x, const Vector& b) const { - if (rank_ == 1 && timeLoop_->timeStepSize() > 0.4) - DUNE_THROW(Dune::Exception, "Assembly failed (for testing) on process " << rank_); + // solver construction + auto solver = constructSolver_(A, b); + + // error handling: make sure the solver was successfully constructed on all processes + // and throw on all processes if solver construction failed + bool success = static_cast<bool>(solver); + int successRemote = success; + if (Dune::MPIHelper::instance().size() > 1) + successRemote = Dune::MPIHelper::instance().getCommunication().min(success); + + if (!success) + DUNE_THROW(Dune::Exception, "Could not create solver"); + else if (!successRemote) + DUNE_THROW(Dune::Exception, "Could not create solver on remote process"); + + // solver solve (here we assume that either all processes are successful or all fail) + x = solver->solve(); - x = b/A; return true; } @@ -95,6 +114,54 @@ public: private: int rank_; std::shared_ptr<TimeLoop<double>> timeLoop_; + + template<class Vector> + struct Solver + { + Solver(const double& A, const Vector& b, int rank, const std::shared_ptr<TimeLoop<double>>& timeLoop) + : A_(A), b_(b), rank_(rank), timeLoop_(timeLoop) + { + // constructor might fail and failure might be recoverable + // this is what we are testing here + if (rank_ == 1 && timeLoop_->timeStepSize() > 0.4 && timeLoop_->timeStepSize() < 0.9) + { + using namespace std::chrono_literals; + std::this_thread::sleep_for(0.3s); + DUNE_THROW(Dune::Exception, "This is a recoverable test error during solver constructor"); + } + } + + auto solve() const + { + auto x = b_/A_; + + // collective communication emulating parallel solver + x = Dune::MPIHelper::instance().getCommunication().min(x); + + // solver might not converge, this is recoverable and what we are testing here + if (rank_ == 1 && timeLoop_->timeStepSize() > 0.9) + { + using namespace std::chrono_literals; + std::this_thread::sleep_for(0.3s); + DUNE_THROW(Dune::Exception, "This is a recoverable test error during solver solve"); + } + + return x; + } + const double& A_; const Vector& b_; + int rank_; std::shared_ptr<TimeLoop<double>> timeLoop_; + }; + + template<class Vector> + auto constructSolver_(const double& A, const Vector& b) const + { + try { + return std::make_shared<Solver<Vector>>(A, b, rank_, timeLoop_); + } catch (const Dune::Exception& e) { + std::cerr << "Caught exception on solver construction: " << e.what() << std::endl; + return std::decay_t<decltype(std::make_shared<Solver<Vector>>(A, b, rank_, timeLoop_))>(); + } + } }; } // end namespace Dumux @@ -106,6 +173,9 @@ int main(int argc, char* argv[]) // maybe initialize MPI and/or multithreading backend Dumux::initialize(argc, argv); + // initialize parameters + Dumux::Parameters::init(argc, argv); + // use the Newton solver to find a solution to a scalar equation using Assembler = MockScalarAssembler; using LinearSolver = MockScalarLinearSolver; @@ -121,7 +191,9 @@ int main(int argc, char* argv[]) double initialGuess = 0.1; double x = initialGuess; - std::cout << "Solving: x^2 - 5 = 0" << std::endl; + if (rank == 0) + std::cout << "Solving: x^2 - 5 = 0" << std::endl; + solver->solve(x, *timeLoop); if (rank == 0) -- GitLab