diff --git a/Src/Base/AMReX_FEIntegrator.H b/Src/Base/AMReX_FEIntegrator.H index becd795e742..f8a002ef534 100644 --- a/Src/Base/AMReX_FEIntegrator.H +++ b/Src/Base/AMReX_FEIntegrator.H @@ -15,50 +15,92 @@ private: amrex::Vector > F_nodes; - void initialize_stages (const T& S_data) + // Current (internal) state and time + amrex::Vector > S_current; + amrex::Real time_current; + + void initialize_stages (const T& S_data, const amrex::Real time) { + // Create data for stage RHS IntegratorOps::CreateLike(F_nodes, S_data); + + // Create and initialize data for current state + IntegratorOps::CreateLike(S_current, S_data, true); + IntegratorOps::Copy(*S_current[0], S_data); + + // Set the initial time + time_current = time; } public: FEIntegrator () {} - FEIntegrator (const T& S_data) + FEIntegrator (const T& S_data, const amrex::Real time = 0.0) { - initialize(S_data); + initialize(S_data, time); } virtual ~FEIntegrator () {} - void initialize (const T& S_data) override + void initialize (const T& S_data, const amrex::Real time = 0.0) { - initialize_stages(S_data); + initialize_stages(S_data, time); } - amrex::Real advance (T& S_old, T& S_new, amrex::Real time, const amrex::Real time_step) override + amrex::Real advance (T& S_old, T& S_new, amrex::Real time, const amrex::Real dt) override { - BaseT::timestep = time_step; - // Assume before advance() that S_old is valid data at the current time ("time" argument) + // Assume before step() that S_old is valid data at the current time ("time" argument) // So we initialize S_new by copying the old state. IntegratorOps::Copy(S_new, S_old); + // Call the pre RHS hook + BaseT::pre_rhs_action(S_new, time); + // F = RHS(S, t) T& F = *F_nodes[0]; - BaseT::rhs(F, S_new, time); + BaseT::Rhs(F, S_new, time); // S_new += timestep * dS/dt - IntegratorOps::Saxpy(S_new, BaseT::timestep, F); + IntegratorOps::Saxpy(S_new, dt, F); - // Call the post-update hook for S_new - BaseT::post_update(S_new, time + BaseT::timestep); + // Call the post step hook + BaseT::post_step_action(S_new, time + dt); // Return timestep - return BaseT::timestep; + return dt; + } + + void evolve (T& S_out, const amrex::Real time_out) override + { + amrex::Real dt = BaseT::time_step; + bool stop = false; + + for (int step_number = 0; step_number < BaseT::max_steps && !stop; ++step_number) + { + // Adjust step size to reach output time + if (time_out - time_current < dt) { + dt = time_out - time_current; + stop = true; + } + + // Call the time integrator step + advance(*S_current[0], S_out, time_current, dt); + + // Update current state S_current = S_out + IntegratorOps::Copy(*S_current[0], S_out); + + // Update time + time_current += dt; + + if (step_number == BaseT::max_steps - 1) { + Error("Did not reach output time in max steps."); + } + } } virtual void time_interpolate (const T& /* S_new */, const T& /* S_old */, amrex::Real /* timestep_fraction */, T& /* data */) override { - amrex::Error("Time interpolation not yet supported by forward euler integrator."); + amrex::Error("Time interpolation not yet supported by the forward euler integrator."); } virtual void map_data (std::function Map) override diff --git a/Src/Base/AMReX_IntegratorBase.H b/Src/Base/AMReX_IntegratorBase.H index 568e063bed5..d9af8053d7e 100644 --- a/Src/Base/AMReX_IntegratorBase.H +++ b/Src/Base/AMReX_IntegratorBase.H @@ -161,37 +161,120 @@ struct IntegratorOps > > template class IntegratorBase { -private: - /** - * \brief Fun is the right-hand-side function the integrator will use. - */ - std::function Fun; - - /** - * \brief FastFun is the fast timescale right-hand-side function for a multirate integration problem. - */ - std::function FastFun; - protected: - /** - * \brief Integrator timestep size (Real) - */ - amrex::Real timestep; - - /** - * \brief For multirate problems, the ratio of slow timestep size / fast timestep size (int) - */ - int slow_fast_timestep_ratio = 0; - - /** - * \brief For multirate problems, the fast timestep size (Real) - */ - Real fast_timestep = 0.0; - - /** - * \brief The post_update function is called by the integrator on state data before using it to evaluate a right-hand side. - */ - std::function post_update; + /** + * \brief Rhs is the right-hand-side function the integrator will use. + */ + std::function Rhs; + + /** + * \brief RhsIm is the implicit right-hand-side function an ImEx integrator + * will use. + */ + std::function RhsIm; + + /** + * \brief RhsEx is the explicit right-hand-side function an ImEx integrator + * will use. + */ + std::function RhsEx; + + /** + * \brief RhsFast is the fast timescale right-hand-side function a multirate + * integrator will use. + */ + std::function RhsFast; + + /** + * \brief The pre_rhs_action function is called by the integrator on state + * data before using it to evaluate a right-hand side. + */ + std::function pre_rhs_action; + + /** + * \brief The post_stage_action function is called by the integrator on + * the computed stage just after it is computed + */ + std::function post_stage_action; + + /** + * \brief The post_step_action function is called by the integrator on + * the computed state just after it is computed + */ + std::function post_step_action; + + /** + * \brief The post_stage_action function is called by the integrator on + * the computed stage just after it is computed + */ + std::function post_fast_stage_action; + + /** + * \brief The post_step_action function is called by the integrator on + * the computed state just after it is computed + */ + std::function post_fast_step_action; + + /** + * \brief Flag to enable/disable adaptive time stepping in single rate + * methods or at the slow time scale in multirate methods (bool) + */ + bool use_adaptive_time_step = false; + + /** + * \brief Current integrator time step size (Real) + */ + amrex::Real time_step; + + /** + * \brief Step size of the last completed step (Real) + */ + amrex::Real previous_time_step; + + /** + * \brief Flag to enable/disable adaptive time stepping at the fast time + * scale in multirate methods (bool) + */ + bool use_adaptive_fast_time_step = false; + + /** + * \brief Current integrator fast time scale time step size with multirate + * methods (Real) + */ + amrex::Real fast_time_step; + + /** + * \brief Number of integrator time steps (Long) + */ + amrex::Long num_steps = 0; + + /** + * \brief Max number of internal steps before an error is returned (Long) + */ + int max_steps = 500; + + /** + * \brief Relative tolerance for adaptive time stepping (Real) + */ + amrex::Real rel_tol = 1.0e-4; + + /** + * \brief Absolute tolerance for adaptive time stepping (Real) + */ + amrex::Real abs_tol = 1.0e-9; + + /** + * \brief Relative tolerance for adaptive time stepping at the fast time + * scale (Real) + */ + amrex::Real fast_rel_tol = 1.0e-4; + + /** + * \brief Absolute tolerance for adaptive time stepping at the fast time + * scale (Real) + */ + amrex::Real fast_abs_tol = 1.0e-9; + public: IntegratorBase () = default; @@ -200,71 +283,112 @@ public: virtual ~IntegratorBase () = default; - virtual void initialize (const T& S_data) = 0; - void set_rhs (std::function F) { - Fun = F; + Rhs = F; + } + + void set_imex_rhs (std::function Fi, + std::function Fe) + { + RhsIm = Fi; + RhsEx = Fe; + } + + void set_fast_rhs (std::function F) + { + RhsFast = F; } - void set_fast_rhs (std::function F) + void set_pre_rhs_action (std::function A) { - FastFun = F; + pre_rhs_action = A; } - void set_slow_fast_timestep_ratio (const int timestep_ratio = 1) + void set_post_stage_action (std::function A) { - slow_fast_timestep_ratio = timestep_ratio; + post_stage_action = A; } - void set_fast_timestep (const Real fast_dt = 1.0) + void set_post_step_action (std::function A) { - fast_timestep = fast_dt; + post_step_action = A; } - void set_post_update (std::function F) + void set_post_fast_stage_action (std::function A) { - post_update = F; + post_fast_stage_action = A; } - std::function get_post_update () + void set_post_fast_step_action (std::function A) { - return post_update; + post_fast_step_action = A; } - std::function get_rhs () + void set_post_update (std::function A) { - return Fun; + set_post_stage_action(A); + set_post_step_action(A); } - std::function get_fast_rhs () + amrex::Real get_time_step () { - return FastFun; + return time_step; } - int get_slow_fast_timestep_ratio () + void set_time_step (amrex::Real dt) { - return slow_fast_timestep_ratio; + time_step = dt; + use_adaptive_time_step = false; } - Real get_fast_timestep () + void set_adaptive_step () { - return fast_timestep; + use_adaptive_time_step = true; } - void rhs (T& S_rhs, const T& S_data, const amrex::Real time) + void set_fast_time_step (amrex::Real dt) { - Fun(S_rhs, S_data, time); + fast_time_step = dt; + use_adaptive_fast_time_step = false; } - void fast_rhs (T& S_rhs, T& S_extra, const T& S_data, const amrex::Real time) + void set_adaptive_fast_step () { - FastFun(S_rhs, S_extra, S_data, time); + use_adaptive_fast_time_step = true; } - virtual amrex::Real advance (T& S_old, T& S_new, amrex::Real time, amrex::Real dt) = 0; + void set_max_steps (int steps) + { + max_steps = steps; + } + + void set_tolerances (amrex::Real rtol, amrex::Real atol) + { + rel_tol = rtol; + abs_tol = atol; + } + + void set_fast_tolerances (amrex::Real rtol, amrex::Real atol) + { + fast_rel_tol = rtol; + fast_abs_tol = atol; + } + + /** + * \brief Take a single time step from (time, S_old) to (time + dt, S_new) + * with the given step size. + */ + virtual amrex::Real advance (T& S_old, T& S_new, amrex::Real time, + amrex::Real dt) = 0; + + /** + * \brief Evolve the current (internal) integrator state to time_out + */ + virtual void evolve (T& S_out, const amrex::Real time_out) = 0; - virtual void time_interpolate (const T& S_new, const T& S_old, amrex::Real timestep_fraction, T& data) = 0; + virtual void time_interpolate (const T& S_new, const T& S_old, + amrex::Real timestep_fraction, T& data) = 0; virtual void map_data (std::function Map) = 0; }; diff --git a/Src/Base/AMReX_RKIntegrator.H b/Src/Base/AMReX_RKIntegrator.H index f1bc5c58151..f72890c808e 100644 --- a/Src/Base/AMReX_RKIntegrator.H +++ b/Src/Base/AMReX_RKIntegrator.H @@ -23,15 +23,31 @@ class RKIntegrator : public IntegratorBase private: using BaseT = IntegratorBase; + // Butcher tableau identifiers ButcherTableauTypes tableau_type; + + // Number of RK stages int number_nodes; - bool use_adaptive_timestep; - amrex::Vector > F_nodes; + + // A matrix from Butcher tableau amrex::Vector > tableau; + + // b vector from Butcher tableau amrex::Vector weights; - amrex::Vector extended_weights; + + // c vector from Butcher tableau amrex::Vector nodes; + // RK embedded method b vector + amrex::Vector extended_weights; + + // RK stage right-hand sides + amrex::Vector > F_nodes; + + // Current (internal) state and time + amrex::Vector > S_current; + amrex::Real time_current; + void initialize_preset_tableau () { switch (tableau_type) @@ -85,14 +101,12 @@ private: pp.get("type", _tableau_type); tableau_type = static_cast(_tableau_type); - // By default, define no extended weights and no adaptive timestepping + // By default, define no extended weights extended_weights = {}; - use_adaptive_timestep = false; - pp.queryAdd("use_adaptive_timestep", use_adaptive_timestep); if (tableau_type == ButcherTableauTypes::User) { - // Read weights/nodes/butcher tableau" + // Read weights/nodes/butcher tableau pp.getarr("weights", weights); pp.queryarr("extended_weights", extended_weights); pp.getarr("nodes", nodes); @@ -143,35 +157,41 @@ private: } } - void initialize_stages (const T& S_data) + void initialize_stages (const T& S_data, const amrex::Real time) { // Create data for stage RHS for (int i = 0; i < number_nodes; ++i) { IntegratorOps::CreateLike(F_nodes, S_data); } + + // Create and initialize data for current state + IntegratorOps::CreateLike(S_current, S_data, true); + IntegratorOps::Copy(*S_current[0], S_data); + + // Set the initial time + time_current = time; } public: RKIntegrator () {} - RKIntegrator (const T& S_data) + RKIntegrator (const T& S_data, const amrex::Real time = 0.0) { - initialize(S_data); + initialize(S_data, time); } - void initialize (const T& S_data) override + void initialize (const T& S_data, const amrex::Real time = 0.0) { initialize_parameters(); - initialize_stages(S_data); + initialize_stages(S_data, time); } virtual ~RKIntegrator () {} - amrex::Real advance (T& S_old, T& S_new, amrex::Real time, const amrex::Real time_step) override + amrex::Real advance (T& S_old, T& S_new, amrex::Real time, const amrex::Real dt) override { - BaseT::timestep = time_step; - // Assume before advance() that S_old is valid data at the current time ("time" argument) + // Assume before step() that S_old is valid data at the current time ("time" argument) // And that if data is a MultiFab, both S_old and S_new contain ghost cells for evaluating a stencil based RHS // We need this from S_old. This is convenient for S_new to have so we can use it // as scratch space for stage values without creating a new scratch MultiFab with ghost cells. @@ -180,7 +200,7 @@ public: for (int i = 0; i < number_nodes; ++i) { // Get current stage time, t = t_old + h * Ci - amrex::Real stage_time = time + BaseT::timestep * nodes[i]; + amrex::Real stage_time = time + dt * nodes[i]; // Fill S_new with the solution value for evaluating F at the current stage // Copy S_new = S_old @@ -191,16 +211,18 @@ public: // We should fuse these kernels ... for (int j = 0; j < i; ++j) { - IntegratorOps::Saxpy(S_new, BaseT::timestep * tableau[i][j], *F_nodes[j]); + IntegratorOps::Saxpy(S_new, dt * tableau[i][j], *F_nodes[j]); } - // Call the post-update hook for the stage state value - BaseT::post_update(S_new, stage_time); + BaseT::post_stage_action(S_new, stage_time); } + // Call the update hook for the stage state value + BaseT::pre_rhs_action(S_new, stage_time); + // Fill F[i], the RHS at the current stage // F[i] = RHS(y, t) at y = stage_value, t = stage_time - BaseT::rhs(*F_nodes[i], S_new, stage_time); + BaseT::Rhs(*F_nodes[i], S_new, stage_time); } // Fill new State, starting with S_new = S_old. @@ -209,17 +231,50 @@ public: IntegratorOps::Copy(S_new, S_old); for (int i = 0; i < number_nodes; ++i) { - IntegratorOps::Saxpy(S_new, BaseT::timestep * weights[i], *F_nodes[i]); + IntegratorOps::Saxpy(S_new, dt * weights[i], *F_nodes[i]); } - // Call the post-update hook for S_new - BaseT::post_update(S_new, time + BaseT::timestep); + BaseT::post_step_action(S_new, time + dt); // If we are working with an extended Butcher tableau, we can estimate the error here, - // and then calculate an adaptive timestep. + // and then calculate an adaptive time step. + + // Save last completed step size for time_interpolate + BaseT::previous_time_step = dt; - // Return timestep - return BaseT::timestep; + // Return time step + return dt; + } + + void evolve (T& S_out, const amrex::Real time_out) override + { + amrex::Real dt = BaseT::time_step; + bool stop = false; + + for (int step_number = 0; step_number < BaseT::max_steps && !stop; ++step_number) + { + // Adjust step size to reach output time + if (time_out - time_current < dt) { + dt = time_out - time_current; + stop = true; + } + + // Call the time integrator step + advance(*S_current[0], S_out, time_current, dt); + + // Update current state S_current = S_out + IntegratorOps::Copy(*S_current[0], S_out); + + // Update time + time_current += dt; + + // Save last completed step size for time_interpolate + BaseT::previous_time_step = dt; + + if (step_number == BaseT::max_steps - 1) { + Error("Did not reach output time in max steps."); + } + } } void time_interpolate (const T& /* S_new */, const T& S_old, amrex::Real timestep_fraction, T& data) override @@ -232,7 +287,6 @@ public: IntegratorOps::Saxpy(data, timestep_fraction, S_new); */ - // currently we only do this for 4th order RK AMREX_ASSERT(number_nodes == 4); @@ -244,19 +298,19 @@ public: // data += (chi - 3/2 * chi^2 + 2/3 * chi^3) * k1 c = timestep_fraction - 1.5 * std::pow(timestep_fraction, 2) + 2./3. * std::pow(timestep_fraction, 3); - IntegratorOps::Saxpy(data, c*BaseT::timestep, *F_nodes[0]); + IntegratorOps::Saxpy(data, c*BaseT::previous_time_step, *F_nodes[0]); // data += (chi^2 - 2/3 * chi^3) * k2 c = std::pow(timestep_fraction, 2) - 2./3. * std::pow(timestep_fraction, 3); - IntegratorOps::Saxpy(data, c*BaseT::timestep, *F_nodes[1]); + IntegratorOps::Saxpy(data, c*BaseT::previous_time_step, *F_nodes[1]); // data += (chi^2 - 2/3 * chi^3) * k3 c = std::pow(timestep_fraction, 2) - 2./3. * std::pow(timestep_fraction, 3); - IntegratorOps::Saxpy(data, c*BaseT::timestep, *F_nodes[2]); + IntegratorOps::Saxpy(data, c*BaseT::previous_time_step, *F_nodes[2]); // data += (-1/2 * chi^2 + 2/3 * chi^3) * k4 c = -0.5 * std::pow(timestep_fraction, 2) + 2./3. * std::pow(timestep_fraction, 3); - IntegratorOps::Saxpy(data, c*BaseT::timestep, *F_nodes[3]); + IntegratorOps::Saxpy(data, c*BaseT::previous_time_step, *F_nodes[3]); } diff --git a/Src/Base/AMReX_TimeIntegrator.H b/Src/Base/AMReX_TimeIntegrator.H index 6ac11107f91..10443361533 100644 --- a/Src/Base/AMReX_TimeIntegrator.H +++ b/Src/Base/AMReX_TimeIntegrator.H @@ -25,10 +25,7 @@ template class TimeIntegrator { private: - amrex::Real m_time, m_timestep; - int m_step_number; std::unique_ptr > integrator_ptr; - std::function post_timestep; IntegratorTypes read_parameters () { @@ -67,21 +64,25 @@ private: void set_default_functions () { - // By default, do nothing post-timestep - set_post_timestep([](){}); + // By default, do nothing in the RHS + set_rhs([](T& /* S_rhs */, const T& /* S_data */, const amrex::Real /* time */){}); + set_imex_rhs([](T& /* S_rhs */, const T& /* S_data */, const amrex::Real /* time */){}, + [](T& /* S_rhs */, const T& /* S_data */, const amrex::Real /* time */){}); + set_fast_rhs([](T& /* S_rhs */, const T& /* S_data */, const amrex::Real /* time */){}); - // By default, do nothing after updating the state - // In general, this is where BCs should be filled - set_post_update([](T& /* S_data */, amrex::Real /* S_time */){}); + // In general, the following functions can be used to fill BCs. Which + // function to set will depend on the method type and intended use case - // By default, do nothing - set_rhs([](T& /* S_rhs */, const T& /* S_data */, const amrex::Real /* time */){}); - set_fast_rhs([](T& /* S_rhs */, T& /* S_extra */, const T& /* S_data */, const amrex::Real /* time */){}); + // By default, do nothing before calling the RHS + set_pre_rhs_action([](T& /* S_data */, amrex::Real /* time */){}); + + // By default, do nothing after a stage or step + set_post_stage_action([](T& /* S_data */, const amrex::Real /* time */){}); + set_post_step_action([](T& /* S_data */, const amrex::Real /* time */){}); - // By default, initialize time, timestep, step number to 0's - m_time = 0.0_rt; - m_timestep = 0.0_rt; - m_step_number = 0; + // By default, do nothing after a stage or step + set_post_fast_stage_action([](T& /* S_data */, const amrex::Real /* time */){}); + set_post_fast_step_action([](T& /* S_data */, const amrex::Real /* time */){}); } public: @@ -91,20 +92,20 @@ public: set_default_functions(); } - TimeIntegrator (IntegratorTypes integrator_type, const T& S_data) + TimeIntegrator (IntegratorTypes integrator_type, const T& S_data, const amrex::Real time = 0.0) { // initialize the integrator class corresponding to the desired type - initialize_integrator(integrator_type, S_data); + initialize_integrator(integrator_type, S_data, time); // initialize functions to do nothing set_default_functions(); } - TimeIntegrator (const T& S_data) + TimeIntegrator (const T& S_data, const amrex::Real time = 0.0) { // initialize the integrator class corresponding to the input parameter selection IntegratorTypes integrator_type = read_parameters(); - initialize_integrator(integrator_type, S_data); + initialize_integrator(integrator_type, S_data, time); // initialize functions to do nothing set_default_functions(); @@ -112,19 +113,19 @@ public: virtual ~TimeIntegrator () {} - void initialize_integrator (IntegratorTypes integrator_type, const T& S_data) + void initialize_integrator (IntegratorTypes integrator_type, const T& S_data, const amrex::Real time = 0.0) { switch (integrator_type) { case IntegratorTypes::ForwardEuler: - integrator_ptr = std::make_unique >(S_data); + integrator_ptr = std::make_unique >(S_data, time); break; case IntegratorTypes::ExplicitRungeKutta: - integrator_ptr = std::make_unique >(S_data); + integrator_ptr = std::make_unique >(S_data, time); break; #ifdef AMREX_USE_SUNDIALS case IntegratorTypes::Sundials: - integrator_ptr = std::make_unique >(S_data); + integrator_ptr = std::make_unique >(S_data, time); break; #endif default: @@ -133,97 +134,113 @@ public: } } - void set_post_timestep (std::function F) + void set_rhs (std::function F) { - post_timestep = F; + integrator_ptr->set_rhs(F); } - void set_post_update (std::function F) + void set_imex_rhs (std::function Fi, + std::function Fe) { - integrator_ptr->set_post_update(F); + integrator_ptr->set_imex_rhs(Fi, Fe); } - void set_rhs (std::function F) + void set_fast_rhs (std::function F) { - integrator_ptr->set_rhs(F); + integrator_ptr->set_fast_rhs(F); } - void set_fast_rhs (std::function F) + void set_pre_rhs_action (std::function A) { - integrator_ptr->set_fast_rhs(F); + integrator_ptr->set_pre_rhs_action(A); } - void set_slow_fast_timestep_ratio (const int timestep_ratio = 1) + void set_post_stage_action (std::function A) { - integrator_ptr->set_slow_fast_timestep_ratio(timestep_ratio); + integrator_ptr->set_post_stage_action(A); } - void set_fast_timestep (const Real fast_dt = 1.0) + void set_post_step_action (std::function A) { - integrator_ptr->set_fast_timestep(fast_dt); + integrator_ptr->set_post_step_action(A); } - Real get_fast_timestep () + void set_post_fast_stage_action (std::function A) { - return integrator_ptr->get_fast_timestep(); + integrator_ptr->set_post_fast_stage_action(A); } - int get_step_number () + void set_post_fast_step_action (std::function A) { - return m_step_number; + integrator_ptr->set_post_fast_step_action(A); } - amrex::Real get_time () + void set_post_update (std::function A) { - return m_time; + integrator_ptr->set_post_update(A); } - amrex::Real get_timestep () + amrex::Real get_time_step () { - return m_timestep; + return integrator_ptr->get_time_step(); } - void set_timestep (amrex::Real dt) + void set_time_step (amrex::Real dt) { - m_timestep = dt; + integrator_ptr->set_time_step(dt); } - std::function get_post_timestep () + void set_adaptive_step () { - return post_timestep; + integrator_ptr->set_adaptive_step(); } - std::function get_post_update () + void set_fast_time_step (amrex::Real dt) { - return integrator_ptr->get_post_update(); + integrator_ptr->set_fast_time_step(dt); } - std::function get_rhs () + void set_adaptive_fast_step () { - return integrator_ptr->get_rhs(); + integrator_ptr->set_adaptive_fast_step(); } - std::function get_fast_rhs () + void set_max_steps (int steps) { - return integrator_ptr->get_fast_rhs(); + integrator_ptr->set_max_steps(steps); } - void advance (T& S_old, T& S_new, amrex::Real time, const amrex::Real timestep) + void set_tolerances (amrex::Real rtol, amrex::Real atol) { - integrator_ptr->advance(S_old, S_new, time, timestep); + integrator_ptr->set_tolerances(rtol, atol); + } + + void set_fast_tolerances (amrex::Real rtol, amrex::Real atol) + { + integrator_ptr->set_fast_tolerances(rtol, atol); + } + + void advance (T& S_old, T& S_new, amrex::Real time, const amrex::Real dt) + { + integrator_ptr->advance(S_old, S_new, time, dt); + } + + void evolve (T& S_out, const amrex::Real time_out) + { + integrator_ptr->evolve(S_out, time_out); } void integrate (T& S_old, T& S_new, amrex::Real start_time, const amrex::Real start_timestep, const amrex::Real end_time, const int start_step, const int max_steps) { - m_time = start_time; - m_timestep = start_timestep; - bool stop_advance = false; - for (m_step_number = start_step; m_step_number < max_steps && !stop_advance; ++m_step_number) + amrex::Real m_time = start_time; + amrex::Real m_timestep = start_timestep; + bool stop = false; + for (int m_step_number = start_step; m_step_number < max_steps && !stop; ++m_step_number) { if (end_time - m_time < m_timestep) { m_timestep = end_time - m_time; - stop_advance = true; + stop = true; } if (m_step_number > 0) { @@ -235,9 +252,6 @@ public: // Update our time variable m_time += m_timestep; - - // Call the post-timestep hook - post_timestep(); } } diff --git a/Src/Extern/SUNDIALS/AMReX_SundialsIntegrator.H b/Src/Extern/SUNDIALS/AMReX_SundialsIntegrator.H index 11d73c9920c..30ff30a499b 100644 --- a/Src/Extern/SUNDIALS/AMReX_SundialsIntegrator.H +++ b/Src/Extern/SUNDIALS/AMReX_SundialsIntegrator.H @@ -1,65 +1,84 @@ #ifndef AMREX_SUNDIALS_INTEGRATOR_H #define AMREX_SUNDIALS_INTEGRATOR_H + +#include + #include #include #include #include #include -#include -#include /* prototypes for ERKStep fcts., consts */ -#include /* prototypes for ARKStep fcts., consts */ -#include /* prototypes for MRIStep fcts., consts */ -#include /* access to CVODE solver */ -#include /* manyvector N_Vector types, fcts. etc */ -#include /* MultiFab N_Vector types, fcts., macros */ -#include /* MultiFab N_Vector types, fcts., macros */ -#include /* access to SPGMR SUNLinearSolver */ -#include /* access to SPGMR SUNLinearSolver */ -#include /* access to FixedPoint SUNNonlinearSolver */ -#include /* defs. of sunrealtype, sunindextype, etc */ +#include +#include + +#include +#include +#include +#include namespace amrex { struct SundialsUserData { - std::function f0; - std::function f_fast; - std::function f; - /* std::function StoreStage; */ - std::function ProcessStage; - std::function PostStoreStage; + // ERK or DIRK right-hand side function + // EX-MRI or IM-MRI slow right-hand side function + std::function f; + + // ImEx-RK right-hand side functions + // ImEx-MRI slow right-hand side functions + std::function fi; + std::function fe; + + // MRI fast time scale right-hand side function + std::function ff; + + // Post stage and step actions + std::function post_stage; + std::function post_step; + + // Post fast stage and step actions + std::function post_fast_stage; + std::function post_fast_step; }; namespace SundialsUserFun { - static int f0 (sunrealtype t, N_Vector y, N_Vector ydot, void *user_data) { + static int f (amrex::Real t, N_Vector y_data, N_Vector y_rhs, void *user_data) { SundialsUserData* udata = static_cast(user_data); - return udata->f0(t, y, ydot, user_data); + return udata->f(t, y_data, y_rhs, user_data); } - static int f_fast (sunrealtype t, N_Vector y_data, N_Vector y_rhs, void *user_data) { + static int fi (amrex::Real t, N_Vector y_data, N_Vector y_rhs, void *user_data) { SundialsUserData* udata = static_cast(user_data); - return udata->f_fast(t, y_data, y_rhs, user_data); + return udata->fi(t, y_data, y_rhs, user_data); } - static int f (sunrealtype t, N_Vector y_data, N_Vector y_rhs, void *user_data) { + static int fe (amrex::Real t, N_Vector y_data, N_Vector y_rhs, void *user_data) { SundialsUserData* udata = static_cast(user_data); - return udata->f(t, y_data, y_rhs, user_data); + return udata->fe(t, y_data, y_rhs, user_data); + } + + static int ff (amrex::Real t, N_Vector y_data, N_Vector y_rhs, void *user_data) { + SundialsUserData* udata = static_cast(user_data); + return udata->ff(t, y_data, y_rhs, user_data); } -/* - static int StoreStage (sunrealtype t, N_Vector* f_data, int nvecs, void *user_data) { + static int post_stage (amrex::Real t, N_Vector y_data, void *user_data) { SundialsUserData* udata = static_cast(user_data); - return udata->StoreStage(t, f_data, nvecs, user_data); + return udata->post_stage(t, y_data, user_data); } -*/ - static int ProcessStage (sunrealtype t, N_Vector y_data, void *user_data) { + static int post_step (amrex::Real t, N_Vector y_data, void *user_data) { SundialsUserData* udata = static_cast(user_data); - return udata->ProcessStage(t, y_data, user_data); + return udata->post_step(t, y_data, user_data); } - static int PostStoreStage(sunrealtype t, N_Vector y_data, void *user_data) { + static int post_fast_stage (amrex::Real t, N_Vector y_data, void *user_data) { SundialsUserData* udata = static_cast(user_data); - return udata->PostStoreStage(t, y_data, user_data); + return udata->post_fast_stage(t, y_data, user_data); + } + + static int post_fast_step (amrex::Real t, N_Vector y_data, void *user_data) { + SundialsUserData* udata = static_cast(user_data); + return udata->post_fast_step(t, y_data, user_data); } } @@ -67,736 +86,527 @@ template class SundialsIntegrator : public IntegratorBase { private: - amrex::Real timestep; using BaseT = IntegratorBase; - bool use_erk_strategy; - bool use_mri_strategy; - bool use_mri_strategy_test; - bool use_cvode_strategy; - bool use_implicit_inner; - - SUNNonlinearSolver NLS; /* empty nonlinear solver object */ - SUNLinearSolver LS; /* empty linear solver object */ - void *arkode_mem; /* empty ARKode memory structure */ - void *cvode_mem; /* empty CVODE memory structure */ - SUNNonlinearSolver NLSf; /* empty nonlinear solver object */ - SUNLinearSolver LSf; /* empty linear solver object */ - void *inner_mem; /* empty ARKode memory structure */ - void *mristep_mem; /* empty ARKode memory structure */ - MPI_Comm mpi_comm; /* the MPI communicator */ - SUNContext sunctx; /* SUNDIALS Context object */ - - std::string mri_outer_method, mri_inner_method, erk_method; - - Real reltol; - Real abstol; - Real t; - Real tout; - Real hfixed; - Real hfixed_mri; - - int NVar; // NOTE: expects S_data to be a Vector - N_Vector* nv_many_arr; /* vector array composed of cons, xmom, ymom, zmom component vectors */ - N_Vector nv_S; - N_Vector nv_stage_data; + // Method type: ERK, DIRK, IMEX-RK, EX-MRI, IM-MRI, IMEX-MRI + std::string type = "ERK"; + + // Use SUNDIALS default methods + std::string method = "DEFAULT"; + std::string method_e = "DEFAULT"; + std::string method_i = "DEFAULT"; + + // Fast method type (ERK or DIRK) and method + std::string fast_type = "ERK"; + std::string fast_method = "DEFAULT"; + + // SUNDIALS package flags, set based on type + bool use_ark = false; + bool use_mri = false; + + // structure for interfacing with user-supplied functions + SundialsUserData udata; + + // SUNDIALS context + // + // We should probably use context created by amrex:sundials::Initialize but + // that context is not MPI-aware + ::sundials::Context sunctx; + + // Single rate or slow time scale + void *arkode_mem = nullptr; + SUNLinearSolver LS = nullptr; + + // Fast time scale + void *arkode_fast_mem = nullptr; + MRIStepInnerStepper fast_stepper = nullptr; + SUNLinearSolver fast_LS = nullptr; void initialize_parameters () { - use_erk_strategy=false; - use_mri_strategy=false; - use_mri_strategy_test=false; - use_cvode_strategy=false; - amrex::ParmParse pp("integration.sundials"); - std::string theStrategy; + pp.query("type", type); + pp.query("method", method); + pp.query("method_e", method); + pp.query("method_i", method); - pp.get("strategy", theStrategy); + pp.query("fast_type", fast_type); + pp.query("fast_method", fast_method); - if (theStrategy == "ERK") - { - use_erk_strategy=true; - erk_method = "SSPRK3"; - amrex::ParmParse pp_erk("integration.sundials.erk"); - pp_erk.query("method", erk_method); + if (type == "ERK" || type == "DIRK" || type == "IMEX-RK") { + use_ark = true; } - else if (theStrategy == "MRI") - { - use_mri_strategy=true; + else if (type == "EX-MRI" || type == "IM-MRI" || type == "IMEX-MRI") { + use_mri = true; } - else if (theStrategy == "MRITEST") - { - use_mri_strategy=true; - use_mri_strategy_test=true; - } - else if (theStrategy == "CVODE") - { - use_cvode_strategy=true; - } - else - { - std::string msg("Unknown strategy: "); - msg += theStrategy; + else { + std::string msg("Unknown method type: "); + msg += type; amrex::Error(msg.c_str()); } - - if (theStrategy == "MRI" || theStrategy == "MRITEST") - { - use_implicit_inner = false; - mri_outer_method = "KnothWolke3"; - mri_inner_method = "ForwardEuler"; - amrex::ParmParse pp_mri("integration.sundials.mri"); - pp_mri.query("implicit_inner", use_implicit_inner); - pp_mri.query("outer_method", mri_outer_method); - pp_mri.query("inner_method", mri_inner_method); - } - - // SUNDIALS specific objects - NLS = nullptr; /* empty nonlinear solver object */ - LS = nullptr; /* empty linear solver object */ - arkode_mem = nullptr; /* empty ARKode memory structure */ - cvode_mem = nullptr; /* empty CVODE memory structure */ - NLSf = nullptr; /* empty nonlinear solver object */ - LSf = nullptr; /* empty linear solver object */ - inner_mem = nullptr; /* empty ARKode memory structure */ - mristep_mem = nullptr; /* empty ARKode memory structure */ - - // Arbitrary tolerances - reltol = 1e-4; - abstol = 1e-4; } -public: - SundialsIntegrator () {} - - SundialsIntegrator (const T& /* S_data */) + void SetupRK (amrex::Real time, N_Vector y_data) { - initialize(); - } + // Create integrator and select method + if (type == "ERK") { + amrex::Print() << "SUNDIALS ERK time integrator\n"; + arkode_mem = ARKStepCreate(SundialsUserFun::f, nullptr, time, y_data, sunctx); + + if (method != "DEFAULT") { + amrex::Print() << "SUNDIALS ERK method " << method << "\n"; + ARKStepSetTableName(arkode_mem, "ARKODE_DIRK_NONE", method.c_str()); + } + } + else if (type == "DIRK") { + amrex::Print() << "SUNDIALS DIRK time integrator\n"; + arkode_mem = ARKStepCreate(nullptr, SundialsUserFun::f, time, y_data, sunctx); - void initialize (const T& /* S_data */) override - { - initialize_parameters(); - mpi_comm = ParallelContext::CommunicatorSub(); -#if defined(SUNDIALS_VERSION_MAJOR) && (SUNDIALS_VERSION_MAJOR < 7) - SUNContext_Create(&mpi_comm, &sunctx); -#else -# ifdef AMREX_USE_MPI - SUNContext_Create(mpi_comm, &sunctx); -# else - SUNContext_Create(SUN_COMM_NULL, &sunctx); -# endif -#endif - } + if (method != "DEFAULT") { + amrex::Print() << "SUNDIALS DIRK method " << method << "\n"; + ARKStepSetTableName(arkode_mem, method.c_str(), "ARKODE_ERK_NONE"); + } + } + else if (type == "IMEX-RK") { + amrex::Print() << "SUNDIALS IMEX time integrator\n"; + arkode_mem = ARKStepCreate(SundialsUserFun::fe, SundialsUserFun::fi, time, y_data, sunctx); - void initialize () - { - initialize_parameters(); - mpi_comm = ParallelContext::CommunicatorSub(); -#if defined(SUNDIALS_VERSION_MAJOR) && (SUNDIALS_VERSION_MAJOR < 7) - SUNContext_Create(&mpi_comm, &sunctx); -#else -# ifdef AMREX_USE_MPI - SUNContext_Create(mpi_comm, &sunctx); -# else - SUNContext_Create(SUN_COMM_NULL, &sunctx); -# endif -#endif - } + if (method_e != "DEFAULT" && method_i != "DEFAULT") + { + amrex::Print() << "SUNDIALS IMEX method " << method_i << " and " + << method_e << "\n"; + ARKStepSetTableName(arkode_mem, method_i.c_str(), method_e.c_str()); + } + } - virtual ~SundialsIntegrator () { - SUNContext_Free(&sunctx); - } + // Attach structure with user-supplied function wrappers + ARKStepSetUserData(arkode_mem, &udata); - amrex::Real advance (T& S_old, T& S_new, amrex::Real time, const amrex::Real time_step) override - { - if (use_mri_strategy) { - return advance_mri(S_old, S_new, time, time_step); - } else if (use_erk_strategy) { - return advance_erk(S_old, S_new, time, time_step); - } else if (use_cvode_strategy) { - return advance_cvode(S_old, S_new, time, time_step); - }else { - Error("SUNDIALS integrator backend not specified (ERK, MRI, or CVODE)."); + // Set integrator tolerances + ARKStepSStolerances(arkode_mem, BaseT::rel_tol, BaseT::abs_tol); + + // Create and attach linear solver for implicit methods + if (type == "DIRK" || type == "IMEX-RK") { + LS = SUNLinSol_SPGMR(y_data, SUN_PREC_NONE, 0, sunctx); + ARKStepSetLinearSolver(arkode_mem, LS, nullptr); } - return 0; + // Set post stage and step function + ARKStepSetPostprocessStageFn(arkode_mem, SundialsUserFun::post_stage); + ARKStepSetPostprocessStepFn(arkode_mem, SundialsUserFun::post_step); } - amrex::Real advance_erk (T& S_old, T& S_new, amrex::Real time, const amrex::Real time_step) + void SetupMRI (amrex::Real time, N_Vector y_data) { - t = time; - tout = time+time_step; - hfixed = time_step; - timestep = time_step; - - // We use S_new as our working space, so first copy S_old to S_new - IntegratorOps::Copy(S_new, S_old); - - // Create an N_Vector wrapper for the solution MultiFab - auto get_length = [&](int index) -> sunindextype { - auto* p_mf = &S_new[index]; - return p_mf->nComp() * (p_mf->boxArray()).numPts(); - }; + // Create the fast integrator and select method + if (fast_type == "ERK") { + amrex::Print() << "SUNDIALS ERK time integrator\n"; + arkode_fast_mem = ARKStepCreate(SundialsUserFun::ff, nullptr, time, y_data, sunctx); + + if (method != "DEFAULT") { + amrex::Print() << "SUNDIALS ERK method " << method << "\n"; + ARKStepSetTableName(arkode_fast_mem, "ARKODE_DIRK_NONE", fast_method.c_str()); + } + } + else if (fast_type == "DIRK") { + amrex::Print() << "SUNDIALS DIRK time integrator\n"; + arkode_fast_mem = ARKStepCreate(nullptr, SundialsUserFun::ff, time, y_data, sunctx); - /* Create manyvector for solution using S_new */ - NVar = S_new.size(); // NOTE: expects S_new to be a Vector - nv_many_arr = new N_Vector[NVar]; // vector array composed of cons, xmom, ymom, zmom component vectors */ + if (method != "DEFAULT") { + amrex::Print() << "SUNDIALS DIRK method " << method << "\n"; + ARKStepSetTableName(arkode_fast_mem, fast_method.c_str(), "ARKODE_ERK_NONE"); + } - for (int i = 0; i < NVar; ++i) { - sunindextype length = get_length(i); - N_Vector nvi = amrex::sundials::N_VMake_MultiFab(length, &S_new[i]); - nv_many_arr[i] = nvi; + fast_LS = SUNLinSol_SPGMR(y_data, SUN_PREC_NONE, 0, sunctx); + ARKStepSetLinearSolver(arkode_fast_mem, fast_LS, nullptr); } - nv_S = N_VNew_ManyVector(NVar, nv_many_arr, sunctx); - nv_stage_data = N_VClone(nv_S); + // Attach structure with user-supplied function wrappers + ARKStepSetUserData(arkode_fast_mem, &udata); - /* Create a temporary storage space for MRI */ - Vector > temp_storage; - IntegratorOps::CreateLike(temp_storage, S_old); - T& state_store = *temp_storage.back(); + // Set integrator tolerances + ARKStepSStolerances(arkode_fast_mem, BaseT::fast_rel_tol, BaseT::fast_abs_tol); - SundialsUserData udata; + // Set post stage and step function + ARKStepSetPostprocessStageFn(arkode_fast_mem, SundialsUserFun::post_fast_stage); + ARKStepSetPostprocessStepFn(arkode_fast_mem, SundialsUserFun::post_fast_step); - /* Begin Section: SUNDIALS FUNCTION HOOKS */ - /* f routine to compute the ODE RHS function f(t,y). */ - udata.f = [&](sunrealtype rhs_time, N_Vector y_data, N_Vector y_rhs, void * /* user_data */) -> int { - amrex::Vector S_data; - amrex::Vector S_rhs; + // Wrap fast integrator as an inner stepper + ARKStepCreateMRIStepInnerStepper(arkode_fast_mem, &fast_stepper); - const int num_vecs = N_VGetNumSubvectors_ManyVector(y_data); - S_data.resize(num_vecs); - S_rhs.resize(num_vecs); + // Create slow integrator + if (type == "EX-MRI") { + amrex::Print() << "SUNDIALS ERK time integrator\n"; + arkode_mem = MRIStepCreate(SundialsUserFun::f, nullptr, time, y_data, + fast_stepper, sunctx); + } + else if (type == "IM-MRI") { + amrex::Print() << "SUNDIALS DIRK time integrator\n"; + arkode_mem = MRIStepCreate(nullptr, SundialsUserFun::f, time, y_data, + fast_stepper, sunctx); + } + else if (type == "IMEX-MRI") { + amrex::Print() << "SUNDIALS IMEX time integrator\n"; + arkode_mem = MRIStepCreate(SundialsUserFun::fe, SundialsUserFun::fi, + time, y_data, fast_stepper, sunctx); + } - for(int i=0; inComp()); - S_rhs.at(i)=amrex::MultiFab(*amrex::sundials::getMFptr(N_VGetSubvector_ManyVector(y_rhs, i)),amrex::make_alias,0,amrex::sundials::getMFptr(N_VGetSubvector_ManyVector(y_rhs, i))->nComp()); - } + // Set method + if (method != "DEFAULT") { + MRIStepCoupling MRIC = MRIStepCoupling_LoadTableByName(method.c_str()); + MRIStepSetCoupling(arkode_mem, MRIC); + MRIStepCoupling_Free(MRIC); + } - BaseT::post_update(S_data, rhs_time); - BaseT::rhs(S_rhs, S_data, rhs_time); + // Attach structure with user-supplied function wrappers + MRIStepSetUserData(arkode_mem, &udata); - return 0; - }; + // Set integrator tolerances + MRIStepSStolerances(arkode_mem, BaseT::rel_tol, BaseT::abs_tol); - udata.ProcessStage = [&](sunrealtype rhs_time, N_Vector y_data, void * /* user_data */) -> int { - amrex::Vector S_data; + // Create and attach linear solver + if (type == "IM-MRI" || type == "IMEX-MRI") { + LS = SUNLinSol_SPGMR(y_data, SUN_PREC_NONE, 0, sunctx); + MRIStepSetLinearSolver(arkode_mem, LS, nullptr); + } - const int num_vecs = N_VGetNumSubvectors_ManyVector(y_data); - S_data.resize(num_vecs); + // Set post stage and step function + MRIStepSetPostprocessStageFn(arkode_mem, SundialsUserFun::post_stage); + MRIStepSetPostprocessStepFn(arkode_mem, SundialsUserFun::post_step); + } - for (int i=0; inComp()); - } + // ------------------------------------- + // Vector / N_Vector Utilities + // ------------------------------------- - BaseT::post_update(S_data, rhs_time); + // Utility to unpack a SUNDIALS ManyVector into a vector of MultiFabs + void unpack_vector (N_Vector y_data, amrex::Vector& S_data) + { + const int num_vecs = N_VGetNumSubvectors_ManyVector(y_data); + S_data.resize(num_vecs); - return 0; - }; - /* End Section: SUNDIALS FUNCTION HOOKS */ - - /* Call ERKStepCreate to initialize the inner ARK timestepper module and - specify the right-hand side function in y'=f(t,y), the initial time - T0, and the initial dependent variable vector y. */ - arkode_mem = ERKStepCreate(SundialsUserFun::f, time, nv_S, sunctx); - ERKStepSetUserData(arkode_mem, &udata); /* Pass udata to user functions */ - ERKStepSetPostprocessStageFn(arkode_mem, SundialsUserFun::ProcessStage); - /* Specify tolerances */ - ERKStepSStolerances(arkode_mem, reltol, abstol); - ERKStepSetFixedStep(arkode_mem, hfixed); - - for(int i=0; inComp()); } + }; - auto make_butcher_table = [&](std::string method) -> ARKodeButcherTable { - ARKodeButcherTable B; - if (method == "SSPRK3") { - B = ARKodeButcherTable_Alloc(3, SUNFALSE); - - // 3rd order Strong Stability Preserving RK3 - B->A[1][0] = 1.0; - B->A[2][0] = 0.25; - B->A[2][1] = 0.25; - B->b[0] = 1./6.; - B->b[1] = 1./6.; - B->b[2] = 2./3.; - B->c[1] = 1.0; - B->c[2] = 0.5; - B->q=3; - } else if (method == "Trapezoid") { - B = ARKodeButcherTable_Alloc(2, SUNFALSE); - - // Trapezoidal rule - B->A[1][0] = 1.0; - B->b[0] = 0.5; - B->b[1] = 0.5; - B->c[1] = 1.0; - B->q=2; - B->p=0; - } else if (method == "ForwardEuler") { - B = ARKodeButcherTable_Alloc(1, SUNFALSE); - - // Forward Euler - B->b[0] = 1.0; - B->q=1; - B->p=0; - } else - amrex::Error("ERK method not implemented"); - return B; + // Utility to wrap vector of MultiFabs as a SUNDIALS ManyVector + N_Vector wrap_data (amrex::Vector& S_data) + { + auto get_length = [&](int index) -> sunindextype { + auto* p_mf = &S_data[index]; + return p_mf->nComp() * (p_mf->boxArray()).numPts(); }; - ARKodeButcherTable B = make_butcher_table(erk_method); - - //Set table - ERKStepSetTable(arkode_mem, B); - - // Free the Butcher table - ARKodeButcherTable_Free(B); + sunindextype NV_len = S_data.size(); + N_Vector* NV_array = new N_Vector[NV_len]; - // Use ERKStep to evolve state_old data (wrapped in nv_S) from t to tout=t+dt - auto flag = ERKStepEvolve(arkode_mem, tout, nv_S, &t, ARK_NORMAL); - AMREX_ALWAYS_ASSERT(flag >= 0); - - // Copy the result stored in nv_S to state_new - for(int i=0; i& S_data) { - int mri_time_step_ratio = BaseT::get_slow_fast_timestep_ratio(); - Real mri_fast_time_step = BaseT::get_fast_timestep(); - AMREX_ALWAYS_ASSERT(mri_time_step_ratio >= 1 || mri_fast_time_step >= 0.0); - t = time; - tout = time+time_step; - hfixed = time_step; - hfixed_mri = mri_fast_time_step >= 0.0 ? mri_fast_time_step : time_step / mri_time_step_ratio; - timestep = time_step; - - // NOTE: hardcoded for now ... - bool use_erk3 = !use_implicit_inner; - bool use_linear = false; - - // We use S_new as our working space, so first copy S_old to S_new - IntegratorOps::Copy(S_new, S_old); - - // Create an N_Vector wrapper for the solution MultiFab auto get_length = [&](int index) -> sunindextype { - auto* p_mf = &S_new[index]; + auto* p_mf = &S_data[index]; return p_mf->nComp() * (p_mf->boxArray()).numPts(); }; - /* Create manyvector for solution using S_new */ - NVar = S_new.size(); // NOTE: expects S_new to be a Vector - nv_many_arr = new N_Vector[NVar]; // vector array composed of cons, xmom, ymom, zmom component vectors */ - - for (int i = 0; i < NVar; ++i) { - sunindextype length = get_length(i); - N_Vector nvi = amrex::sundials::N_VMake_MultiFab(length, &S_new[i]); - nv_many_arr[i] = nvi; + sunindextype NV_len = S_data.size(); + N_Vector* NV_array = new N_Vector[NV_len]; + + for (int i = 0; i < NV_len; ++i) { + NV_array[i] = amrex::sundials::N_VNew_MultiFab(get_length(i), + S_data[i].boxArray(), + S_data[i].DistributionMap(), + S_data[i].nComp(), + S_data[i].nGrow(), + &sunctx); + + MultiFab::Copy(*amrex::sundials::getMFptr(NV_array[i]), + S_data[i], + 0, + 0, + S_data[i].nComp(), + S_data[i].nGrow()); } - nv_S = N_VNew_ManyVector(NVar, nv_many_arr, sunctx); - nv_stage_data = N_VClone(nv_S); + N_Vector y_data = N_VNew_ManyVector(NV_len, NV_array, sunctx); - // Copy the initial step data to nv_stage_data - for(int i=0; inComp(), mf_y->nGrow()); - } + delete[] NV_array; - /* Create a temporary storage space for MRI */ - Vector > temp_storage; - IntegratorOps::CreateLike(temp_storage, S_old); - T& state_store = *temp_storage.back(); + return y_data; + }; - SundialsUserData udata; + // ----------------------------- + // MultiFab / N_Vector Utilities + // ----------------------------- - /* Begin Section: SUNDIALS FUNCTION HOOKS */ - /* f0 routine to compute a zero-valued ODE RHS function f(t,y). */ - udata.f0 = [&](sunrealtype /* rhs_time */, N_Vector /* y */, N_Vector ydot, void * /* user_data */) -> int { - // Initialize ydot to zero and return - N_VConst(0.0, ydot); - return 0; - }; + // Utility to unpack a SUNDIALS Vector into a MultiFab + void unpack_vector (N_Vector y_data, amrex::MultiFab& S_data) + { + S_data = amrex::MultiFab(*amrex::sundials::getMFptr(y_data), + amrex::make_alias, + 0, + amrex::sundials::getMFptr(y_data)->nComp()); + }; + + // Utility to wrap a MultiFab as a SUNDIALS Vector + N_Vector wrap_data (amrex::MultiFab& S_data) + { + return amrex::sundials::N_VMake_MultiFab(S_data.nComp() * S_data.boxArray().numPts(), + &S_data, &sunctx); + }; - /* f routine to compute the ODE RHS function f(t,y). */ - udata.f_fast = [&](sunrealtype rhs_time, N_Vector y_data, N_Vector y_rhs, void * /* user_data */) -> int { - amrex::Vector S_data; - amrex::Vector S_rhs; - amrex::Vector S_stage_data; + // Utility to wrap a MultiFab as a SUNDIALS Vector + N_Vector copy_data (const amrex::MultiFab& S_data) + { + N_Vector y_data = amrex::sundials::N_VNew_MultiFab(S_data.nComp() * S_data.boxArray().numPts(), + S_data.boxArray(), + S_data.DistributionMap(), + S_data.nComp(), + S_data.nGrow(), + &sunctx); + + MultiFab::Copy(*amrex::sundials::getMFptr(y_data), + S_data, + 0, + 0, + S_data.nComp(), + S_data.nGrow()); + + return y_data; + }; - N_VConst(0.0, y_rhs); +public: + SundialsIntegrator () {} - const int num_vecs = N_VGetNumSubvectors_ManyVector(y_data); - S_data.resize(num_vecs); - S_rhs.resize(num_vecs); - S_stage_data.resize(num_vecs); + SundialsIntegrator (const T& S_data, const amrex::Real time = 0.0) + { + initialize(S_data, time); + } - for(int i=0; inComp()); - S_rhs.at(i)=amrex::MultiFab(*amrex::sundials::getMFptr(N_VGetSubvector_ManyVector(y_rhs, i)),amrex::make_alias,0,amrex::sundials::getMFptr(N_VGetSubvector_ManyVector(y_rhs, i))->nComp()); - S_stage_data.at(i)=amrex::MultiFab(*amrex::sundials::getMFptr(N_VGetSubvector_ManyVector(nv_stage_data, i)),amrex::make_alias,0,amrex::sundials::getMFptr(N_VGetSubvector_ManyVector(nv_stage_data, i))->nComp()); - } + void initialize (const T& S_data, const amrex::Real time = 0.0) + { + initialize_parameters(); + MPI_Comm mpi_comm = ParallelContext::CommunicatorSub(); +#if defined(SUNDIALS_VERSION_MAJOR) && (SUNDIALS_VERSION_MAJOR < 7) +# ifdef AMREX_USE_MPI + sunctx = ::sundials::Context(&mpi_comm); +# else + sunctx = ::sundials::Context(nullptr); +# endif +#else +# ifdef AMREX_USE_MPI + sunctx = ::sundials::Context(mpi_comm); +# else + sunctx = ::sundials::Context(SUN_COMM_NULL); +# endif +#endif - // NOTE: we can optimize by calling a post_update_fast and only updating the variables the fast integration modifies - BaseT::post_update(S_data, rhs_time); + // Right-hand side function wrappers + udata.f = [&](amrex::Real rhs_time, N_Vector y_data, N_Vector y_rhs, + void * /* user_data */) -> int { - BaseT::fast_rhs(S_rhs, S_stage_data, S_data, rhs_time); + T S_data; + unpack_vector(y_data, S_data); + + T S_rhs; + unpack_vector(y_rhs, S_rhs); + + BaseT::pre_rhs_action(S_data, rhs_time); + BaseT::Rhs(S_rhs, S_data, rhs_time); return 0; }; - /* f routine to compute the ODE RHS function f(t,y). */ - udata.f = [&](sunrealtype rhs_time, N_Vector y_data, N_Vector y_rhs, void * /* user_data */) -> int { - amrex::Vector S_data; - amrex::Vector S_rhs; + udata.fi = [&](amrex::Real rhs_time, N_Vector y_data, N_Vector y_rhs, + void * /* user_data */) -> int { - const int num_vecs = N_VGetNumSubvectors_ManyVector(y_data); - S_data.resize(num_vecs); - S_rhs.resize(num_vecs); + T S_data; + unpack_vector(y_data, S_data); - for(int i=0; inComp()); - S_rhs.at(i)=amrex::MultiFab(*amrex::sundials::getMFptr(N_VGetSubvector_ManyVector(y_rhs, i)),amrex::make_alias,0,amrex::sundials::getMFptr(N_VGetSubvector_ManyVector(y_rhs, i))->nComp()); - } + T S_rhs; + unpack_vector(y_rhs, S_rhs); - BaseT::post_update(S_data, rhs_time); - BaseT::rhs(S_rhs, S_data, rhs_time); + BaseT::pre_rhs_action(S_data, rhs_time); + BaseT::RhsIm(S_rhs, S_data, rhs_time); return 0; }; - udata.ProcessStage = [&](sunrealtype rhs_time, N_Vector y_data, void * /* user_data */) -> int { - amrex::Vector S_data; + udata.fe = [&](amrex::Real rhs_time, N_Vector y_data, N_Vector y_rhs, + void * /* user_data */) -> int { - const int num_vecs = N_VGetNumSubvectors_ManyVector(y_data); - S_data.resize(num_vecs); + T S_data; + unpack_vector(y_data, S_data); - for (int i=0; inComp()); - } + T S_rhs; + unpack_vector(y_rhs, S_rhs); - BaseT::post_update(S_data, rhs_time); + BaseT::pre_rhs_action(S_data, rhs_time); + BaseT::RhsEx(S_rhs, S_data, rhs_time); return 0; }; - udata.PostStoreStage = [&](sunrealtype rhs_time, N_Vector y_data, void *user_data) -> int { - udata.ProcessStage(rhs_time, y_data, user_data); + udata.ff = [&](amrex::Real rhs_time, N_Vector y_data, N_Vector y_rhs, + void * /* user_data */) -> int { - for(int i=0; inComp(), mf_y->nGrow()); - } + T S_data; + unpack_vector(y_data, S_data); + + T S_rhs; + unpack_vector(y_rhs, S_rhs); + + BaseT::pre_rhs_action(S_data, rhs_time); + BaseT::RhsFast(S_rhs, S_data, rhs_time); return 0; }; - /* End Section: SUNDIALS FUNCTION HOOKS */ - if(use_mri_strategy_test) - { - if(use_erk3) { - inner_mem = ARKStepCreate(SundialsUserFun::f0, nullptr, time, nv_S, sunctx); // explicit bc (explicit f, implicit f, time, data) - } else { - inner_mem = ARKStepCreate(nullptr, SundialsUserFun::f0, time, nv_S, sunctx); // implicit - } - } - else - { - if(use_erk3) { - inner_mem = ARKStepCreate(SundialsUserFun::f_fast, nullptr, time, nv_S, sunctx); - } else { - inner_mem = ARKStepCreate(nullptr, SundialsUserFun::f_fast, time, nv_S, sunctx); - } - } + udata.post_stage = [&](amrex::Real time, N_Vector y_data, + void * /* user_data */) -> int { - ARKStepSetFixedStep(inner_mem, hfixed_mri); // Specify fixed time step size + T S_data; + unpack_vector(y_data, S_data); - ARKStepSetUserData(inner_mem, &udata); /* Pass udata to user functions */ - - for(int i=0; i ARKodeButcherTable { - ARKodeButcherTable B; - if (method == "KnothWolke3" || method == "Knoth-Wolke-3-3") { - B = ARKodeButcherTable_Alloc(3, SUNFALSE); - - // 3rd order Knoth-Wolke method - B->A[1][0] = 1.0/3.0; - B->A[2][0] = -3.0/16.0; - B->A[2][1] = 15.0/16.0; - B->b[0] = 1./6.; - B->b[1] = 3./10.; - B->b[2] = 8./15.; - B->c[1] = 1.0/3.0; - B->c[2] = 3.0/4.0; - B->q=3; - B->p=0; - } else if (method == "Trapezoid") { - B = ARKodeButcherTable_Alloc(2, SUNFALSE); - - // Trapezoidal rule - B->A[1][0] = 1.0; - B->b[0] = 0.5; - B->b[1] = 0.5; - B->c[1] = 1.0; - B->q=2; - B->p=0; - } else if (method == "ForwardEuler") { - B = ARKodeButcherTable_Alloc(1, SUNFALSE); - - // Forward Euler - B->b[0] = 1.0; - B->q=1; - B->p=0; - } else { - amrex::Error("MRI method not implemented"); - } - return B; + return 0; }; - ARKodeButcherTable B_outer = make_butcher_table(mri_outer_method); - ARKodeButcherTable B_inner = make_butcher_table(mri_inner_method); + udata.post_step = [&](amrex::Real time, N_Vector y_data, + void * /* user_data */) -> int { - if(use_erk3) - { - ARKStepSetTables(inner_mem, B_inner->q, B_inner->p, nullptr, B_inner); // Specify Butcher table - } else { - ARKodeButcherTable_Free(B_inner); - B_inner = ARKodeButcherTable_Alloc(2, SUNFALSE); - - B_inner->A[1][0] = 1.0; - B_inner->A[2][0] = 1.0; - B_inner->A[2][2] = 0.0; - B_inner->b[0] = 0.5; - B_inner->b[2] = 0.5; - B_inner->c[1] = 1.0; - B_inner->c[2] = 1.0; - B_inner->q=2; - ARKStepSetTables(inner_mem, B_inner->q, B_inner->p, B_inner, nullptr); // Specify Butcher table - } + T S_data; + unpack_vector(y_data, S_data); - //Set table - // Create fast time scale integrator from an ARKStep instance - MRIStepInnerStepper inner_stepper = nullptr; - ARKStepCreateMRIStepInnerStepper(inner_mem, &inner_stepper); + BaseT::post_step_action(S_data, time); - // args: fast RHS, nullptr, initial time, initial state, fast time scale integrator, sundials context - mristep_mem = MRIStepCreate(SundialsUserFun::f, nullptr, time, nv_S, inner_stepper, sunctx); + return 0; + }; - MRIStepSetFixedStep(mristep_mem, hfixed); + udata.post_fast_stage = [&](amrex::Real time, N_Vector y_data, + void * /* user_data */) -> int { - /* Specify tolerances */ - MRIStepSStolerances(mristep_mem, reltol, abstol); + T S_data; + unpack_vector(y_data, S_data); - /* Initialize spgmr solver */ -#if defined(SUNDIALS_VERSION_MAJOR) && (SUNDIALS_VERSION_MAJOR < 7) - LS = SUNLinSol_SPGMR(nv_S, PREC_NONE, 10, sunctx); -#else - LS = SUNLinSol_SPGMR(nv_S, SUN_PREC_NONE, 10, sunctx); -#endif - NLS = SUNNonlinSol_FixedPoint(nv_S, 50, sunctx); + BaseT::post_fast_stage_action(S_data, time); - if (use_implicit_inner) { ARKStepSetNonlinearSolver(inner_mem, NLS); } - if(use_linear) { - MRIStepSetLinearSolver(mristep_mem, LS, nullptr); - } else { - MRIStepSetNonlinearSolver(mristep_mem, NLS); - } + return 0; + }; - MRIStepSetUserData(mristep_mem, &udata); /* Pass udata to user functions */ - MRIStepSetPostprocessStageFn(mristep_mem, SundialsUserFun::ProcessStage); + udata.post_fast_step = [&](amrex::Real time, N_Vector y_data, + void * /* user_data */) -> int { - MRIStepCoupling mri_coupling = MRIStepCoupling_MIStoMRI(B_outer, B_outer->q, B_outer->p); - MRIStepSetCoupling(mristep_mem, mri_coupling); + T S_data; + unpack_vector(y_data, S_data); - // Free the Butcher tables - ARKodeButcherTable_Free(B_outer); - ARKodeButcherTable_Free(B_inner); + BaseT::post_fast_step_action(S_data, time); - // Use MRIStep to evolve state_old data (wrapped in nv_S) from t to tout=t+dt - auto flag = MRIStepEvolve(mristep_mem, tout, nv_S, &t, ARK_NORMAL); - AMREX_ALWAYS_ASSERT(flag >= 0); + return 0; + }; - // Copy the result stored in nv_S to state_new - for(int i=0; i::Copy(S_new, S_old); + amrex::Real tout = time + dt; + amrex::Real tret; - // Create an N_Vector wrapper for the solution MultiFab - auto get_length = [&](int index) -> sunindextype { - auto* p_mf = &S_new[index]; - return p_mf->nComp() * (p_mf->boxArray()).numPts(); - }; - - /* Create manyvector for solution using S_new */ - NVar = S_new.size(); // NOTE: expects S_new to be a Vector - nv_many_arr = new N_Vector[NVar]; // vector array composed of cons, xmom, ymom, zmom component vectors */ + N_Vector y_old = wrap_data(S_old); + N_Vector y_new = wrap_data(S_new); - for (int i = 0; i < NVar; ++i) { - sunindextype length = get_length(i); - N_Vector nvi = amrex::sundials::N_VMake_MultiFab(length, &S_new[i]); - nv_many_arr[i] = nvi; + if (use_ark) { + ARKStepReset(arkode_mem, time, y_old); // should probably resize + ARKStepSetFixedStep(arkode_mem, dt); + int flag = ARKStepEvolve(arkode_mem, tout, y_new, &tret, ARK_ONE_STEP); + AMREX_ALWAYS_ASSERT(flag >= 0); + } + else if (use_mri) { + MRIStepReset(arkode_mem, time, y_old); // should probably resize -- need to resize inner stepper + MRIStepSetFixedStep(arkode_mem, dt); + int flag = MRIStepEvolve(arkode_mem, tout, y_new, &tret, ARK_ONE_STEP); + AMREX_ALWAYS_ASSERT(flag >= 0); + } else { + Error("SUNDIALS integrator type not specified."); } - nv_S = N_VNew_ManyVector(NVar, nv_many_arr, sunctx); - nv_stage_data = N_VClone(nv_S); - - /* Create a temporary storage space for MRI */ - Vector > temp_storage; - IntegratorOps::CreateLike(temp_storage, S_old); - T& state_store = *temp_storage.back(); - - SundialsUserData udata; - - /* Begin Section: SUNDIALS FUNCTION HOOKS */ - /* f routine to compute the ODE RHS function f(t,y). */ - udata.f = [&](sunrealtype rhs_time, N_Vector y_data, N_Vector y_rhs, void * /* user_data */) -> int { - amrex::Vector S_data; - amrex::Vector S_rhs; - - const int num_vecs = N_VGetNumSubvectors_ManyVector(y_data); - S_data.resize(num_vecs); - S_rhs.resize(num_vecs); - - for(int i=0; inComp()); - S_rhs.at(i)=amrex::MultiFab(*amrex::sundials::getMFptr(N_VGetSubvector_ManyVector(y_rhs, i)),amrex::make_alias,0,amrex::sundials::getMFptr(N_VGetSubvector_ManyVector(y_rhs, i))->nComp()); - } - - BaseT::post_update(S_data, rhs_time); - BaseT::rhs(S_rhs, S_data, rhs_time); + N_VDestroy(y_old); + N_VDestroy(y_new); - return 0; - }; + return dt; + } - udata.ProcessStage = [&](sunrealtype rhs_time, N_Vector y_data, void * /* user_data */) -> int { - amrex::Vector S_data; + void evolve (T& S_out, const amrex::Real time_out) override + { + int flag = 0; // SUNDIALS return status + amrex::Real time_ret; // SUNDIALS return time - const int num_vecs = N_VGetNumSubvectors_ManyVector(y_data); - S_data.resize(num_vecs); + N_Vector y_out = wrap_data(S_out); - for (int i=0; inComp()); + if (use_ark) { + if (!BaseT::use_adaptive_time_step) { + ARKStepSetFixedStep(arkode_mem, BaseT::time_step); } - - BaseT::post_update(S_data, rhs_time); - - return 0; - }; - /* End Section: SUNDIALS FUNCTION HOOKS */ - - /* Set up CVODE BDF solver */ - cvode_mem = CVodeCreate(CV_BDF, sunctx); - CVodeSetUserData(cvode_mem, &udata); - CVodeInit(cvode_mem, SundialsUserFun::f, time, nv_S); - CVodeSStolerances(cvode_mem, reltol, abstol); - CVodeSetMaxNumSteps(cvode_mem, 100000); - - for(int i=0; i= 0); } - - // Set up and assign the linear solver (GMRES) - LS = SUNLinSol_SPGMR(nv_S, SUN_PREC_NONE, 0, sunctx); - CVodeSetLinearSolver(cvode_mem, LS, nullptr); - - // Use CVode to evolve state_old data (wrapped in nv_S) from t to tout=t+dt - auto flag = CVode(cvode_mem, tout, nv_S, &t, CV_NORMAL); - AMREX_ALWAYS_ASSERT(flag >= 0); - - // Copy the result stored in nv_S to state_new - for(int i=0; i= 0); + } else { + Error("SUNDIALS integrator type not specified."); } - delete[] nv_many_arr; - N_VDestroy(nv_S); - N_VDestroy(nv_stage_data); - CVodeFree(&cvode_mem); - SUNLinSolFree(LS); - - // Return timestep - return timestep; + N_VDestroy(y_out); } void time_interpolate (const T& /* S_new */, const T& /* S_old */, amrex::Real /* timestep_fraction */, T& /* data */) override {} void map_data (std::function /* Map */) override {} - }; }