From e7a6e68aeb70ad1457d7eb30291f92d67c4da130 Mon Sep 17 00:00:00 2001 From: "David J. Gardner" Date: Thu, 30 May 2024 00:04:39 -0700 Subject: [PATCH 01/36] simplify sundials interface --- Src/Base/AMReX_FEIntegrator.H | 6 +- Src/Base/AMReX_IntegratorBase.H | 20 +- Src/Base/AMReX_RKIntegrator.H | 6 +- Src/Base/AMReX_TimeIntegrator.H | 28 +- .../SUNDIALS/AMReX_SundialsIntegrator.H | 786 +++--------------- 5 files changed, 151 insertions(+), 695 deletions(-) diff --git a/Src/Base/AMReX_FEIntegrator.H b/Src/Base/AMReX_FEIntegrator.H index becd795e742..5cb61fe8b76 100644 --- a/Src/Base/AMReX_FEIntegrator.H +++ b/Src/Base/AMReX_FEIntegrator.H @@ -23,14 +23,14 @@ private: public: FEIntegrator () {} - FEIntegrator (const T& S_data) + FEIntegrator (const T& S_data, const amrex::Real time = 0.0) { - initialize(S_data); + initialize(S_data, time); } virtual ~FEIntegrator () {} - void initialize (const T& S_data) override + void initialize (const T& S_data, const amrex::Real /* time */) override { initialize_stages(S_data); } diff --git a/Src/Base/AMReX_IntegratorBase.H b/Src/Base/AMReX_IntegratorBase.H index 568e063bed5..4bfe97d8112 100644 --- a/Src/Base/AMReX_IntegratorBase.H +++ b/Src/Base/AMReX_IntegratorBase.H @@ -165,12 +165,12 @@ private: /** * \brief Fun is the right-hand-side function the integrator will use. */ - std::function Fun; + std::function Fun; /** * \brief FastFun is the fast timescale right-hand-side function for a multirate integration problem. */ - std::function FastFun; + std::function FastFun; protected: /** @@ -200,14 +200,14 @@ public: virtual ~IntegratorBase () = default; - virtual void initialize (const T& S_data) = 0; + virtual void initialize (const T& S_data, const amrex::Real time = 0.0) = 0; - void set_rhs (std::function F) + void set_rhs (std::function F) { Fun = F; } - void set_fast_rhs (std::function F) + void set_fast_rhs (std::function F) { FastFun = F; } @@ -232,12 +232,12 @@ public: return post_update; } - std::function get_rhs () + std::function get_rhs () { return Fun; } - std::function get_fast_rhs () + std::function get_fast_rhs () { return FastFun; } @@ -252,14 +252,14 @@ public: return fast_timestep; } - void rhs (T& S_rhs, const T& S_data, const amrex::Real time) + void rhs (T& S_rhs, T& S_data, const amrex::Real time) { Fun(S_rhs, S_data, time); } - void fast_rhs (T& S_rhs, T& S_extra, const T& S_data, const amrex::Real time) + void fast_rhs (T& S_rhs, T& S_data, const amrex::Real time) { - FastFun(S_rhs, S_extra, S_data, time); + FastFun(S_rhs, S_data, time); } virtual amrex::Real advance (T& S_old, T& S_new, amrex::Real time, amrex::Real dt) = 0; diff --git a/Src/Base/AMReX_RKIntegrator.H b/Src/Base/AMReX_RKIntegrator.H index f1bc5c58151..80f3d96df2f 100644 --- a/Src/Base/AMReX_RKIntegrator.H +++ b/Src/Base/AMReX_RKIntegrator.H @@ -155,12 +155,12 @@ private: public: RKIntegrator () {} - RKIntegrator (const T& S_data) + RKIntegrator (const T& S_data, const amrex::Real time = 0.0) { - initialize(S_data); + initialize(S_data, time); } - void initialize (const T& S_data) override + void initialize (const T& S_data, const amrex::Real /* time */) override { initialize_parameters(); initialize_stages(S_data); diff --git a/Src/Base/AMReX_TimeIntegrator.H b/Src/Base/AMReX_TimeIntegrator.H index 6ac11107f91..cf7a12ff6c9 100644 --- a/Src/Base/AMReX_TimeIntegrator.H +++ b/Src/Base/AMReX_TimeIntegrator.H @@ -75,8 +75,8 @@ private: set_post_update([](T& /* S_data */, amrex::Real /* S_time */){}); // By default, do nothing - set_rhs([](T& /* S_rhs */, const T& /* S_data */, const amrex::Real /* time */){}); - set_fast_rhs([](T& /* S_rhs */, T& /* S_extra */, const T& /* S_data */, const amrex::Real /* time */){}); + set_rhs([](T& /* S_rhs */, T& /* S_data */, const amrex::Real /* time */){}); + set_fast_rhs([](T& /* S_rhs */, T& /* S_data */, const amrex::Real /* time */){}); // By default, initialize time, timestep, step number to 0's m_time = 0.0_rt; @@ -91,20 +91,20 @@ public: set_default_functions(); } - TimeIntegrator (IntegratorTypes integrator_type, const T& S_data) + TimeIntegrator (IntegratorTypes integrator_type, const T& S_data, const amrex::Real time = 0.0) { // initialize the integrator class corresponding to the desired type - initialize_integrator(integrator_type, S_data); + initialize_integrator(integrator_type, S_data, time); // initialize functions to do nothing set_default_functions(); } - TimeIntegrator (const T& S_data) + TimeIntegrator (const T& S_data, const amrex::Real time = 0.0) { // initialize the integrator class corresponding to the input parameter selection IntegratorTypes integrator_type = read_parameters(); - initialize_integrator(integrator_type, S_data); + initialize_integrator(integrator_type, S_data, time); // initialize functions to do nothing set_default_functions(); @@ -112,19 +112,19 @@ public: virtual ~TimeIntegrator () {} - void initialize_integrator (IntegratorTypes integrator_type, const T& S_data) + void initialize_integrator (IntegratorTypes integrator_type, const T& S_data, const amrex::Real time = 0.0) { switch (integrator_type) { case IntegratorTypes::ForwardEuler: - integrator_ptr = std::make_unique >(S_data); + integrator_ptr = std::make_unique >(S_data, time); break; case IntegratorTypes::ExplicitRungeKutta: - integrator_ptr = std::make_unique >(S_data); + integrator_ptr = std::make_unique >(S_data, time); break; #ifdef AMREX_USE_SUNDIALS case IntegratorTypes::Sundials: - integrator_ptr = std::make_unique >(S_data); + integrator_ptr = std::make_unique >(S_data, time); break; #endif default: @@ -143,12 +143,12 @@ public: integrator_ptr->set_post_update(F); } - void set_rhs (std::function F) + void set_rhs (std::function F) { integrator_ptr->set_rhs(F); } - void set_fast_rhs (std::function F) + void set_fast_rhs (std::function F) { integrator_ptr->set_fast_rhs(F); } @@ -198,12 +198,12 @@ public: return integrator_ptr->get_post_update(); } - std::function get_rhs () + std::function get_rhs () { return integrator_ptr->get_rhs(); } - std::function get_fast_rhs () + std::function get_fast_rhs () { return integrator_ptr->get_fast_rhs(); } diff --git a/Src/Extern/SUNDIALS/AMReX_SundialsIntegrator.H b/Src/Extern/SUNDIALS/AMReX_SundialsIntegrator.H index 11d73c9920c..4c8a990452c 100644 --- a/Src/Extern/SUNDIALS/AMReX_SundialsIntegrator.H +++ b/Src/Extern/SUNDIALS/AMReX_SundialsIntegrator.H @@ -1,65 +1,42 @@ #ifndef AMREX_SUNDIALS_INTEGRATOR_H #define AMREX_SUNDIALS_INTEGRATOR_H + +#include + #include #include #include #include #include -#include -#include /* prototypes for ERKStep fcts., consts */ -#include /* prototypes for ARKStep fcts., consts */ -#include /* prototypes for MRIStep fcts., consts */ -#include /* access to CVODE solver */ -#include /* manyvector N_Vector types, fcts. etc */ -#include /* MultiFab N_Vector types, fcts., macros */ -#include /* MultiFab N_Vector types, fcts., macros */ -#include /* access to SPGMR SUNLinearSolver */ -#include /* access to SPGMR SUNLinearSolver */ -#include /* access to FixedPoint SUNNonlinearSolver */ -#include /* defs. of sunrealtype, sunindextype, etc */ +#include +#include + +#include +#include +#include namespace amrex { struct SundialsUserData { - std::function f0; - std::function f_fast; std::function f; - /* std::function StoreStage; */ - std::function ProcessStage; - std::function PostStoreStage; + std::function fe; + std::function fi; }; namespace SundialsUserFun { - static int f0 (sunrealtype t, N_Vector y, N_Vector ydot, void *user_data) { - SundialsUserData* udata = static_cast(user_data); - return udata->f0(t, y, ydot, user_data); - } - - static int f_fast (sunrealtype t, N_Vector y_data, N_Vector y_rhs, void *user_data) { - SundialsUserData* udata = static_cast(user_data); - return udata->f_fast(t, y_data, y_rhs, user_data); - } - static int f (sunrealtype t, N_Vector y_data, N_Vector y_rhs, void *user_data) { SundialsUserData* udata = static_cast(user_data); return udata->f(t, y_data, y_rhs, user_data); } -/* - static int StoreStage (sunrealtype t, N_Vector* f_data, int nvecs, void *user_data) { - SundialsUserData* udata = static_cast(user_data); - return udata->StoreStage(t, f_data, nvecs, user_data); - } -*/ - - static int ProcessStage (sunrealtype t, N_Vector y_data, void *user_data) { + static int fe (sunrealtype t, N_Vector y_data, N_Vector y_rhs, void *user_data) { SundialsUserData* udata = static_cast(user_data); - return udata->ProcessStage(t, y_data, user_data); + return udata->fe(t, y_data, y_rhs, user_data); } - static int PostStoreStage(sunrealtype t, N_Vector y_data, void *user_data) { + static int fi (sunrealtype t, N_Vector y_data, N_Vector y_rhs, void *user_data) { SundialsUserData* udata = static_cast(user_data); - return udata->PostStoreStage(t, y_data, user_data); + return udata->fi(t, y_data, y_rhs, user_data); } } @@ -67,133 +44,71 @@ template class SundialsIntegrator : public IntegratorBase { private: - amrex::Real timestep; using BaseT = IntegratorBase; - bool use_erk_strategy; - bool use_mri_strategy; - bool use_mri_strategy_test; - bool use_cvode_strategy; - bool use_implicit_inner; - - SUNNonlinearSolver NLS; /* empty nonlinear solver object */ - SUNLinearSolver LS; /* empty linear solver object */ - void *arkode_mem; /* empty ARKode memory structure */ - void *cvode_mem; /* empty CVODE memory structure */ - SUNNonlinearSolver NLSf; /* empty nonlinear solver object */ - SUNLinearSolver LSf; /* empty linear solver object */ - void *inner_mem; /* empty ARKode memory structure */ - void *mristep_mem; /* empty ARKode memory structure */ - MPI_Comm mpi_comm; /* the MPI communicator */ - SUNContext sunctx; /* SUNDIALS Context object */ - - std::string mri_outer_method, mri_inner_method, erk_method; - - Real reltol; - Real abstol; - Real t; - Real tout; - Real hfixed; - Real hfixed_mri; - - int NVar; // NOTE: expects S_data to be a Vector - N_Vector* nv_many_arr; /* vector array composed of cons, xmom, ymom, zmom component vectors */ + // method type and name + std::string type = "ERK"; + std::string method = "DEFAULT"; + std::string method_e = "DEFAULT"; + std::string method_i = "DEFAULT"; + + // method type flags + bool use_erk = false; + bool use_dirk = false; + bool use_imex = false; + bool use_mri = false; + + // structure for interfacing with user-supplied functions + SundialsUserData udata; + + // SUNDIALS objects + SUNContext sunctx = nullptr; // should get from singleton + void *arkode_mem = nullptr; + SUNLinearSolver LS = nullptr; + + // relative and absolute tolerances + Real reltol = 1.0e-4; + Real abstol = 1.0e-9; + + int NVar; // NOTE: expects S_data to be a Vector + N_Vector* nv_many_arr; /* vector array composed of cons, xmom, ymom, zmom component vectors */ N_Vector nv_S; - N_Vector nv_stage_data; void initialize_parameters () { - use_erk_strategy=false; - use_mri_strategy=false; - use_mri_strategy_test=false; - use_cvode_strategy=false; - amrex::ParmParse pp("integration.sundials"); - std::string theStrategy; - - pp.get("strategy", theStrategy); + pp.query("type", type); + pp.query("method", method); - if (theStrategy == "ERK") - { - use_erk_strategy=true; - erk_method = "SSPRK3"; - amrex::ParmParse pp_erk("integration.sundials.erk"); - pp_erk.query("method", erk_method); + if (type == "ERK") { + use_erk = true; } - else if (theStrategy == "MRI") - { - use_mri_strategy=true; + else if (type == "DIRK") { + use_dirk = true; } - else if (theStrategy == "MRITEST") - { - use_mri_strategy=true; - use_mri_strategy_test=true; - } - else if (theStrategy == "CVODE") - { - use_cvode_strategy=true; + else if (type == "IMEX") { + use_imex = true; } - else - { + else { std::string msg("Unknown strategy: "); - msg += theStrategy; + msg += type; amrex::Error(msg.c_str()); } - - if (theStrategy == "MRI" || theStrategy == "MRITEST") - { - use_implicit_inner = false; - mri_outer_method = "KnothWolke3"; - mri_inner_method = "ForwardEuler"; - amrex::ParmParse pp_mri("integration.sundials.mri"); - pp_mri.query("implicit_inner", use_implicit_inner); - pp_mri.query("outer_method", mri_outer_method); - pp_mri.query("inner_method", mri_inner_method); - } - - // SUNDIALS specific objects - NLS = nullptr; /* empty nonlinear solver object */ - LS = nullptr; /* empty linear solver object */ - arkode_mem = nullptr; /* empty ARKode memory structure */ - cvode_mem = nullptr; /* empty CVODE memory structure */ - NLSf = nullptr; /* empty nonlinear solver object */ - LSf = nullptr; /* empty linear solver object */ - inner_mem = nullptr; /* empty ARKode memory structure */ - mristep_mem = nullptr; /* empty ARKode memory structure */ - - // Arbitrary tolerances - reltol = 1e-4; - abstol = 1e-4; } public: SundialsIntegrator () {} - SundialsIntegrator (const T& /* S_data */) + SundialsIntegrator (const T& S_data, const amrex::Real time = 0.0) { - initialize(); - } - - void initialize (const T& /* S_data */) override - { - initialize_parameters(); - mpi_comm = ParallelContext::CommunicatorSub(); -#if defined(SUNDIALS_VERSION_MAJOR) && (SUNDIALS_VERSION_MAJOR < 7) - SUNContext_Create(&mpi_comm, &sunctx); -#else -# ifdef AMREX_USE_MPI - SUNContext_Create(mpi_comm, &sunctx); -# else - SUNContext_Create(SUN_COMM_NULL, &sunctx); -# endif -#endif + initialize(S_data, time); } - void initialize () + void initialize (const T& S_data, const amrex::Real time = 0.0) override { initialize_parameters(); - mpi_comm = ParallelContext::CommunicatorSub(); + MPI_Comm mpi_comm = ParallelContext::CommunicatorSub(); #if defined(SUNDIALS_VERSION_MAJOR) && (SUNDIALS_VERSION_MAJOR < 7) SUNContext_Create(&mpi_comm, &sunctx); #else @@ -203,64 +118,7 @@ public: SUNContext_Create(SUN_COMM_NULL, &sunctx); # endif #endif - } - - virtual ~SundialsIntegrator () { - SUNContext_Free(&sunctx); - } - - amrex::Real advance (T& S_old, T& S_new, amrex::Real time, const amrex::Real time_step) override - { - if (use_mri_strategy) { - return advance_mri(S_old, S_new, time, time_step); - } else if (use_erk_strategy) { - return advance_erk(S_old, S_new, time, time_step); - } else if (use_cvode_strategy) { - return advance_cvode(S_old, S_new, time, time_step); - }else { - Error("SUNDIALS integrator backend not specified (ERK, MRI, or CVODE)."); - } - - return 0; - } - - amrex::Real advance_erk (T& S_old, T& S_new, amrex::Real time, const amrex::Real time_step) - { - t = time; - tout = time+time_step; - hfixed = time_step; - timestep = time_step; - - // We use S_new as our working space, so first copy S_old to S_new - IntegratorOps::Copy(S_new, S_old); - - // Create an N_Vector wrapper for the solution MultiFab - auto get_length = [&](int index) -> sunindextype { - auto* p_mf = &S_new[index]; - return p_mf->nComp() * (p_mf->boxArray()).numPts(); - }; - - /* Create manyvector for solution using S_new */ - NVar = S_new.size(); // NOTE: expects S_new to be a Vector - nv_many_arr = new N_Vector[NVar]; // vector array composed of cons, xmom, ymom, zmom component vectors */ - - for (int i = 0; i < NVar; ++i) { - sunindextype length = get_length(i); - N_Vector nvi = amrex::sundials::N_VMake_MultiFab(length, &S_new[i]); - nv_many_arr[i] = nvi; - } - - nv_S = N_VNew_ManyVector(NVar, nv_many_arr, sunctx); - nv_stage_data = N_VClone(nv_S); - /* Create a temporary storage space for MRI */ - Vector > temp_storage; - IntegratorOps::CreateLike(temp_storage, S_old); - T& state_store = *temp_storage.back(); - - SundialsUserData udata; - - /* Begin Section: SUNDIALS FUNCTION HOOKS */ /* f routine to compute the ODE RHS function f(t,y). */ udata.f = [&](sunrealtype rhs_time, N_Vector y_data, N_Vector y_rhs, void * /* user_data */) -> int { amrex::Vector S_data; @@ -272,8 +130,15 @@ public: for(int i=0; inComp()); - S_rhs.at(i)=amrex::MultiFab(*amrex::sundials::getMFptr(N_VGetSubvector_ManyVector(y_rhs, i)),amrex::make_alias,0,amrex::sundials::getMFptr(N_VGetSubvector_ManyVector(y_rhs, i))->nComp()); + S_data.at(i)=amrex::MultiFab(*amrex::sundials::getMFptr(N_VGetSubvector_ManyVector(y_data, i)), + amrex::make_alias, + 0, + amrex::sundials::getMFptr(N_VGetSubvector_ManyVector(y_data, i))->nComp()); + + S_rhs.at(i)=amrex::MultiFab(*amrex::sundials::getMFptr(N_VGetSubvector_ManyVector(y_rhs, i)), + amrex::make_alias, + 0, + amrex::sundials::getMFptr(N_VGetSubvector_ManyVector(y_rhs, i))->nComp()); } BaseT::post_update(S_data, rhs_time); @@ -282,495 +147,99 @@ public: return 0; }; - udata.ProcessStage = [&](sunrealtype rhs_time, N_Vector y_data, void * /* user_data */) -> int { - amrex::Vector S_data; - - const int num_vecs = N_VGetNumSubvectors_ManyVector(y_data); - S_data.resize(num_vecs); - - for (int i=0; inComp()); - } - - BaseT::post_update(S_data, rhs_time); - - return 0; - }; - /* End Section: SUNDIALS FUNCTION HOOKS */ - - /* Call ERKStepCreate to initialize the inner ARK timestepper module and - specify the right-hand side function in y'=f(t,y), the initial time - T0, and the initial dependent variable vector y. */ - arkode_mem = ERKStepCreate(SundialsUserFun::f, time, nv_S, sunctx); - ERKStepSetUserData(arkode_mem, &udata); /* Pass udata to user functions */ - ERKStepSetPostprocessStageFn(arkode_mem, SundialsUserFun::ProcessStage); - /* Specify tolerances */ - ERKStepSStolerances(arkode_mem, reltol, abstol); - ERKStepSetFixedStep(arkode_mem, hfixed); - - for(int i=0; i ARKodeButcherTable { - ARKodeButcherTable B; - if (method == "SSPRK3") { - B = ARKodeButcherTable_Alloc(3, SUNFALSE); - - // 3rd order Strong Stability Preserving RK3 - B->A[1][0] = 1.0; - B->A[2][0] = 0.25; - B->A[2][1] = 0.25; - B->b[0] = 1./6.; - B->b[1] = 1./6.; - B->b[2] = 2./3.; - B->c[1] = 1.0; - B->c[2] = 0.5; - B->q=3; - } else if (method == "Trapezoid") { - B = ARKodeButcherTable_Alloc(2, SUNFALSE); - - // Trapezoidal rule - B->A[1][0] = 1.0; - B->b[0] = 0.5; - B->b[1] = 0.5; - B->c[1] = 1.0; - B->q=2; - B->p=0; - } else if (method == "ForwardEuler") { - B = ARKodeButcherTable_Alloc(1, SUNFALSE); - - // Forward Euler - B->b[0] = 1.0; - B->q=1; - B->p=0; - } else - amrex::Error("ERK method not implemented"); - return B; - }; - - ARKodeButcherTable B = make_butcher_table(erk_method); - - //Set table - ERKStepSetTable(arkode_mem, B); - - // Free the Butcher table - ARKodeButcherTable_Free(B); - - // Use ERKStep to evolve state_old data (wrapped in nv_S) from t to tout=t+dt - auto flag = ERKStepEvolve(arkode_mem, tout, nv_S, &t, ARK_NORMAL); - AMREX_ALWAYS_ASSERT(flag >= 0); - - // Copy the result stored in nv_S to state_new - for(int i=0; i= 1 || mri_fast_time_step >= 0.0); - t = time; - tout = time+time_step; - hfixed = time_step; - hfixed_mri = mri_fast_time_step >= 0.0 ? mri_fast_time_step : time_step / mri_time_step_ratio; - timestep = time_step; - - // NOTE: hardcoded for now ... - bool use_erk3 = !use_implicit_inner; - bool use_linear = false; - - // We use S_new as our working space, so first copy S_old to S_new - IntegratorOps::Copy(S_new, S_old); // Create an N_Vector wrapper for the solution MultiFab auto get_length = [&](int index) -> sunindextype { - auto* p_mf = &S_new[index]; + auto* p_mf = &S_data[index]; return p_mf->nComp() * (p_mf->boxArray()).numPts(); }; - /* Create manyvector for solution using S_new */ - NVar = S_new.size(); // NOTE: expects S_new to be a Vector + NVar = S_data.size(); // NOTE: expects S_data to be a Vector nv_many_arr = new N_Vector[NVar]; // vector array composed of cons, xmom, ymom, zmom component vectors */ for (int i = 0; i < NVar; ++i) { - sunindextype length = get_length(i); - N_Vector nvi = amrex::sundials::N_VMake_MultiFab(length, &S_new[i]); - nv_many_arr[i] = nvi; + nv_many_arr[i] = amrex::sundials::N_VNew_MultiFab(get_length(i), + S_data[i].boxArray(), + S_data[i].DistributionMap(), + S_data[i].nComp(), + S_data[i].nGrow()); + MultiFab::Copy(*amrex::sundials::getMFptr(nv_many_arr[i]), S_data[i], 0, 0, S_data[i].nComp(), S_data[i].nGrow()); } - nv_S = N_VNew_ManyVector(NVar, nv_many_arr, sunctx); - nv_stage_data = N_VClone(nv_S); - // Copy the initial step data to nv_stage_data - for(int i=0; inComp(), mf_y->nGrow()); + if (use_erk) { + amrex::Print() << "SUNDIALS ERK time integrator\n"; + arkode_mem = ARKStepCreate(SundialsUserFun::f, nullptr, time, nv_S, sunctx); } - - /* Create a temporary storage space for MRI */ - Vector > temp_storage; - IntegratorOps::CreateLike(temp_storage, S_old); - T& state_store = *temp_storage.back(); - - SundialsUserData udata; - - /* Begin Section: SUNDIALS FUNCTION HOOKS */ - /* f0 routine to compute a zero-valued ODE RHS function f(t,y). */ - udata.f0 = [&](sunrealtype /* rhs_time */, N_Vector /* y */, N_Vector ydot, void * /* user_data */) -> int { - // Initialize ydot to zero and return - N_VConst(0.0, ydot); - return 0; - }; - - /* f routine to compute the ODE RHS function f(t,y). */ - udata.f_fast = [&](sunrealtype rhs_time, N_Vector y_data, N_Vector y_rhs, void * /* user_data */) -> int { - amrex::Vector S_data; - amrex::Vector S_rhs; - amrex::Vector S_stage_data; - - N_VConst(0.0, y_rhs); - - const int num_vecs = N_VGetNumSubvectors_ManyVector(y_data); - S_data.resize(num_vecs); - S_rhs.resize(num_vecs); - S_stage_data.resize(num_vecs); - - for(int i=0; inComp()); - S_rhs.at(i)=amrex::MultiFab(*amrex::sundials::getMFptr(N_VGetSubvector_ManyVector(y_rhs, i)),amrex::make_alias,0,amrex::sundials::getMFptr(N_VGetSubvector_ManyVector(y_rhs, i))->nComp()); - S_stage_data.at(i)=amrex::MultiFab(*amrex::sundials::getMFptr(N_VGetSubvector_ManyVector(nv_stage_data, i)),amrex::make_alias,0,amrex::sundials::getMFptr(N_VGetSubvector_ManyVector(nv_stage_data, i))->nComp()); - } - - // NOTE: we can optimize by calling a post_update_fast and only updating the variables the fast integration modifies - BaseT::post_update(S_data, rhs_time); - - BaseT::fast_rhs(S_rhs, S_stage_data, S_data, rhs_time); - - return 0; - }; - - /* f routine to compute the ODE RHS function f(t,y). */ - udata.f = [&](sunrealtype rhs_time, N_Vector y_data, N_Vector y_rhs, void * /* user_data */) -> int { - amrex::Vector S_data; - amrex::Vector S_rhs; - - const int num_vecs = N_VGetNumSubvectors_ManyVector(y_data); - S_data.resize(num_vecs); - S_rhs.resize(num_vecs); - - for(int i=0; inComp()); - S_rhs.at(i)=amrex::MultiFab(*amrex::sundials::getMFptr(N_VGetSubvector_ManyVector(y_rhs, i)),amrex::make_alias,0,amrex::sundials::getMFptr(N_VGetSubvector_ManyVector(y_rhs, i))->nComp()); - } - - BaseT::post_update(S_data, rhs_time); - BaseT::rhs(S_rhs, S_data, rhs_time); - - return 0; - }; - - udata.ProcessStage = [&](sunrealtype rhs_time, N_Vector y_data, void * /* user_data */) -> int { - amrex::Vector S_data; - - const int num_vecs = N_VGetNumSubvectors_ManyVector(y_data); - S_data.resize(num_vecs); - - for (int i=0; inComp()); - } - - BaseT::post_update(S_data, rhs_time); - - return 0; - }; - - udata.PostStoreStage = [&](sunrealtype rhs_time, N_Vector y_data, void *user_data) -> int { - udata.ProcessStage(rhs_time, y_data, user_data); - - for(int i=0; inComp(), mf_y->nGrow()); - } - - return 0; - }; - /* End Section: SUNDIALS FUNCTION HOOKS */ - - if(use_mri_strategy_test) - { - if(use_erk3) { - inner_mem = ARKStepCreate(SundialsUserFun::f0, nullptr, time, nv_S, sunctx); // explicit bc (explicit f, implicit f, time, data) - } else { - inner_mem = ARKStepCreate(nullptr, SundialsUserFun::f0, time, nv_S, sunctx); // implicit - } + else if (use_dirk) { + amrex::Print() << "SUNDIALS DIRK time integrator\n"; + arkode_mem = ARKStepCreate(nullptr, SundialsUserFun::f, time, nv_S, sunctx); } - else - { - if(use_erk3) { - inner_mem = ARKStepCreate(SundialsUserFun::f_fast, nullptr, time, nv_S, sunctx); - } else { - inner_mem = ARKStepCreate(nullptr, SundialsUserFun::f_fast, time, nv_S, sunctx); - } + else if (use_imex) { + amrex::Print() << "SUNDIALS IMEX time integrator\n"; + arkode_mem = ARKStepCreate(SundialsUserFun::fe, SundialsUserFun::fi, time, nv_S, sunctx); } + ARKStepSetUserData(arkode_mem, &udata); /* Pass udata to user functions */ + ARKStepSStolerances(arkode_mem, reltol, abstol); - ARKStepSetFixedStep(inner_mem, hfixed_mri); // Specify fixed time step size - - ARKStepSetUserData(inner_mem, &udata); /* Pass udata to user functions */ - - for(int i=0; i ARKodeButcherTable { - ARKodeButcherTable B; - if (method == "KnothWolke3" || method == "Knoth-Wolke-3-3") { - B = ARKodeButcherTable_Alloc(3, SUNFALSE); - - // 3rd order Knoth-Wolke method - B->A[1][0] = 1.0/3.0; - B->A[2][0] = -3.0/16.0; - B->A[2][1] = 15.0/16.0; - B->b[0] = 1./6.; - B->b[1] = 3./10.; - B->b[2] = 8./15.; - B->c[1] = 1.0/3.0; - B->c[2] = 3.0/4.0; - B->q=3; - B->p=0; - } else if (method == "Trapezoid") { - B = ARKodeButcherTable_Alloc(2, SUNFALSE); - - // Trapezoidal rule - B->A[1][0] = 1.0; - B->b[0] = 0.5; - B->b[1] = 0.5; - B->c[1] = 1.0; - B->q=2; - B->p=0; - } else if (method == "ForwardEuler") { - B = ARKodeButcherTable_Alloc(1, SUNFALSE); - - // Forward Euler - B->b[0] = 1.0; - B->q=1; - B->p=0; - } else { - amrex::Error("MRI method not implemented"); + if (use_erk) { + amrex::Print() << "SUNDIALS ERK method " << method << "\n"; + ARKStepSetTableName(arkode_mem, "ARKODE_DIRK_NONE", method.c_str()); + } + else if (use_dirk) { + amrex::Print() << "SUNDIALS DIRK method " << method << "\n"; + ARKStepSetTableName(arkode_mem, method.c_str(), "ARKODE_ERK_NONE"); + } + else { + amrex::Print() << "SUNDIALS IMEX method " << method_i << " and " + << method_e << "\n"; + ARKStepSetTableName(arkode_mem, method_i.c_str(), method_e.c_str()); } - return B; - }; - - ARKodeButcherTable B_outer = make_butcher_table(mri_outer_method); - ARKodeButcherTable B_inner = make_butcher_table(mri_inner_method); - - if(use_erk3) - { - ARKStepSetTables(inner_mem, B_inner->q, B_inner->p, nullptr, B_inner); // Specify Butcher table - } else { - ARKodeButcherTable_Free(B_inner); - B_inner = ARKodeButcherTable_Alloc(2, SUNFALSE); - - B_inner->A[1][0] = 1.0; - B_inner->A[2][0] = 1.0; - B_inner->A[2][2] = 0.0; - B_inner->b[0] = 0.5; - B_inner->b[2] = 0.5; - B_inner->c[1] = 1.0; - B_inner->c[2] = 1.0; - B_inner->q=2; - ARKStepSetTables(inner_mem, B_inner->q, B_inner->p, B_inner, nullptr); // Specify Butcher table } - //Set table - // Create fast time scale integrator from an ARKStep instance - MRIStepInnerStepper inner_stepper = nullptr; - ARKStepCreateMRIStepInnerStepper(inner_mem, &inner_stepper); - - // args: fast RHS, nullptr, initial time, initial state, fast time scale integrator, sundials context - mristep_mem = MRIStepCreate(SundialsUserFun::f, nullptr, time, nv_S, inner_stepper, sunctx); - - MRIStepSetFixedStep(mristep_mem, hfixed); - - /* Specify tolerances */ - MRIStepSStolerances(mristep_mem, reltol, abstol); - - /* Initialize spgmr solver */ -#if defined(SUNDIALS_VERSION_MAJOR) && (SUNDIALS_VERSION_MAJOR < 7) - LS = SUNLinSol_SPGMR(nv_S, PREC_NONE, 10, sunctx); -#else - LS = SUNLinSol_SPGMR(nv_S, SUN_PREC_NONE, 10, sunctx); -#endif - NLS = SUNNonlinSol_FixedPoint(nv_S, 50, sunctx); - - if (use_implicit_inner) { ARKStepSetNonlinearSolver(inner_mem, NLS); } - if(use_linear) { - MRIStepSetLinearSolver(mristep_mem, LS, nullptr); - } else { - MRIStepSetNonlinearSolver(mristep_mem, NLS); - } - - MRIStepSetUserData(mristep_mem, &udata); /* Pass udata to user functions */ - MRIStepSetPostprocessStageFn(mristep_mem, SundialsUserFun::ProcessStage); - - MRIStepCoupling mri_coupling = MRIStepCoupling_MIStoMRI(B_outer, B_outer->q, B_outer->p); - MRIStepSetCoupling(mristep_mem, mri_coupling); - - // Free the Butcher tables - ARKodeButcherTable_Free(B_outer); - ARKodeButcherTable_Free(B_inner); - - // Use MRIStep to evolve state_old data (wrapped in nv_S) from t to tout=t+dt - auto flag = MRIStepEvolve(mristep_mem, tout, nv_S, &t, ARK_NORMAL); - AMREX_ALWAYS_ASSERT(flag >= 0); - - // Copy the result stored in nv_S to state_new - for(int i=0; i::Copy(S_new, S_old); - - // Create an N_Vector wrapper for the solution MultiFab - auto get_length = [&](int index) -> sunindextype { - auto* p_mf = &S_new[index]; - return p_mf->nComp() * (p_mf->boxArray()).numPts(); - }; - - /* Create manyvector for solution using S_new */ - NVar = S_new.size(); // NOTE: expects S_new to be a Vector - nv_many_arr = new N_Vector[NVar]; // vector array composed of cons, xmom, ymom, zmom component vectors */ - - for (int i = 0; i < NVar; ++i) { - sunindextype length = get_length(i); - N_Vector nvi = amrex::sundials::N_VMake_MultiFab(length, &S_new[i]); - nv_many_arr[i] = nvi; - } - - nv_S = N_VNew_ManyVector(NVar, nv_many_arr, sunctx); - nv_stage_data = N_VClone(nv_S); - - /* Create a temporary storage space for MRI */ - Vector > temp_storage; - IntegratorOps::CreateLike(temp_storage, S_old); - T& state_store = *temp_storage.back(); - - SundialsUserData udata; - - /* Begin Section: SUNDIALS FUNCTION HOOKS */ - /* f routine to compute the ODE RHS function f(t,y). */ - udata.f = [&](sunrealtype rhs_time, N_Vector y_data, N_Vector y_rhs, void * /* user_data */) -> int { - amrex::Vector S_data; - amrex::Vector S_rhs; - - const int num_vecs = N_VGetNumSubvectors_ManyVector(y_data); - S_data.resize(num_vecs); - S_rhs.resize(num_vecs); - - for(int i=0; inComp()); - S_rhs.at(i)=amrex::MultiFab(*amrex::sundials::getMFptr(N_VGetSubvector_ManyVector(y_rhs, i)),amrex::make_alias,0,amrex::sundials::getMFptr(N_VGetSubvector_ManyVector(y_rhs, i))->nComp()); - } - - BaseT::post_update(S_data, rhs_time); - BaseT::rhs(S_rhs, S_data, rhs_time); - - return 0; - }; - - udata.ProcessStage = [&](sunrealtype rhs_time, N_Vector y_data, void * /* user_data */) -> int { - amrex::Vector S_data; - - const int num_vecs = N_VGetNumSubvectors_ManyVector(y_data); - S_data.resize(num_vecs); - - for (int i=0; inComp()); - } - - BaseT::post_update(S_data, rhs_time); - - return 0; - }; - /* End Section: SUNDIALS FUNCTION HOOKS */ - - /* Set up CVODE BDF solver */ - cvode_mem = CVodeCreate(CV_BDF, sunctx); - CVodeSetUserData(cvode_mem, &udata); - CVodeInit(cvode_mem, SundialsUserFun::f, time, nv_S); - CVodeSStolerances(cvode_mem, reltol, abstol); - CVodeSetMaxNumSteps(cvode_mem, 100000); + amrex::Real tout = time + time_step; + amrex::Real tret; + // Copy the S_old to nv_S for(int i=0; i= 0); + if (use_erk || use_dirk || use_imex) { + ARKStepReset(arkode_mem, time, nv_S); + ARKStepSetFixedStep(arkode_mem, time_step); + int flag = ARKStepEvolve(arkode_mem, tout, nv_S, &tret, ARK_ONE_STEP); + AMREX_ALWAYS_ASSERT(flag >= 0); + } + else if (use_mri) { + Error("SUNDIALS integrator type not implemented, yet."); + } else { + Error("SUNDIALS integrator type not specified."); + } // Copy the result stored in nv_S to state_new for(int i=0; i /* Map */) override {} - }; } From bb788922326e9e99415c4f5ffff1478cf94a2a76 Mon Sep 17 00:00:00 2001 From: "David J. Gardner" Date: Fri, 31 May 2024 00:49:23 -0700 Subject: [PATCH 02/36] add evolve function for adaptive stepping --- Src/Base/AMReX_FEIntegrator.H | 5 ++++ Src/Base/AMReX_IntegratorBase.H | 2 ++ Src/Base/AMReX_RKIntegrator.H | 5 ++++ Src/Base/AMReX_TimeIntegrator.H | 5 ++++ .../SUNDIALS/AMReX_SundialsIntegrator.H | 28 +++++++++++++++++++ 5 files changed, 45 insertions(+) diff --git a/Src/Base/AMReX_FEIntegrator.H b/Src/Base/AMReX_FEIntegrator.H index 5cb61fe8b76..8e504397490 100644 --- a/Src/Base/AMReX_FEIntegrator.H +++ b/Src/Base/AMReX_FEIntegrator.H @@ -56,6 +56,11 @@ public: return BaseT::timestep; } + void evolve (T& S_out, const amrex::Real t_out) override + { + amrex::Error("Not implemented yet"); + } + virtual void time_interpolate (const T& /* S_new */, const T& /* S_old */, amrex::Real /* timestep_fraction */, T& /* data */) override { amrex::Error("Time interpolation not yet supported by forward euler integrator."); diff --git a/Src/Base/AMReX_IntegratorBase.H b/Src/Base/AMReX_IntegratorBase.H index 4bfe97d8112..5804fa40638 100644 --- a/Src/Base/AMReX_IntegratorBase.H +++ b/Src/Base/AMReX_IntegratorBase.H @@ -264,6 +264,8 @@ public: virtual amrex::Real advance (T& S_old, T& S_new, amrex::Real time, amrex::Real dt) = 0; + virtual void evolve (T& S_out, const amrex::Real t_out) = 0; + virtual void time_interpolate (const T& S_new, const T& S_old, amrex::Real timestep_fraction, T& data) = 0; virtual void map_data (std::function Map) = 0; diff --git a/Src/Base/AMReX_RKIntegrator.H b/Src/Base/AMReX_RKIntegrator.H index 80f3d96df2f..951cf7fc9ee 100644 --- a/Src/Base/AMReX_RKIntegrator.H +++ b/Src/Base/AMReX_RKIntegrator.H @@ -222,6 +222,11 @@ public: return BaseT::timestep; } + void evolve (T& S_out, const amrex::Real t_out) override + { + amrex::Error("Not implemented yet"); + } + void time_interpolate (const T& /* S_new */, const T& S_old, amrex::Real timestep_fraction, T& data) override { // data = S_old*(1-time_step_fraction) + S_new*(time_step_fraction) diff --git a/Src/Base/AMReX_TimeIntegrator.H b/Src/Base/AMReX_TimeIntegrator.H index cf7a12ff6c9..4fe1c1329bd 100644 --- a/Src/Base/AMReX_TimeIntegrator.H +++ b/Src/Base/AMReX_TimeIntegrator.H @@ -213,6 +213,11 @@ public: integrator_ptr->advance(S_old, S_new, time, timestep); } + void evolve (T& S_out, const amrex::Real t_out) + { + integrator_ptr->evolve(S_out, t_out); + } + void integrate (T& S_old, T& S_new, amrex::Real start_time, const amrex::Real start_timestep, const amrex::Real end_time, const int start_step, const int max_steps) { diff --git a/Src/Extern/SUNDIALS/AMReX_SundialsIntegrator.H b/Src/Extern/SUNDIALS/AMReX_SundialsIntegrator.H index 4c8a990452c..6a875766272 100644 --- a/Src/Extern/SUNDIALS/AMReX_SundialsIntegrator.H +++ b/Src/Extern/SUNDIALS/AMReX_SundialsIntegrator.H @@ -250,6 +250,34 @@ public: return time_step; } + void evolve (T& S_out, const amrex::Real t_out) override + { + amrex::Real tret; + + if (use_erk || use_dirk || use_imex) { + int flag = ARKStepEvolve(arkode_mem, t_out, nv_S, &tret, ARK_NORMAL); + AMREX_ALWAYS_ASSERT(flag >= 0); + } + else if (use_mri) { + Error("SUNDIALS integrator type not implemented, yet."); + } else { + Error("SUNDIALS integrator type not specified."); + } + + // Should be able to make an alias to S_out and avoid the copy + for(int i=0; i /* Map */) override {} From 1ae97024716eae28f2c7acbd2c9b10b8d8cfe1ae Mon Sep 17 00:00:00 2001 From: "David J. Gardner" Date: Fri, 31 May 2024 07:53:09 -0700 Subject: [PATCH 03/36] add back const, add more ImEx RHS funcs --- Src/Base/AMReX_IntegratorBase.H | 102 ++++++++---------- Src/Base/AMReX_TimeIntegrator.H | 47 ++------ .../SUNDIALS/AMReX_SundialsIntegrator.H | 20 +++- 3 files changed, 70 insertions(+), 99 deletions(-) diff --git a/Src/Base/AMReX_IntegratorBase.H b/Src/Base/AMReX_IntegratorBase.H index 5804fa40638..9affbea24a5 100644 --- a/Src/Base/AMReX_IntegratorBase.H +++ b/Src/Base/AMReX_IntegratorBase.H @@ -162,35 +162,39 @@ template class IntegratorBase { private: - /** - * \brief Fun is the right-hand-side function the integrator will use. - */ - std::function Fun; - - /** - * \brief FastFun is the fast timescale right-hand-side function for a multirate integration problem. - */ - std::function FastFun; + /** + * \brief Fun is the right-hand-side function the integrator will use. + */ + std::function Fun; + + /** + * \brief FunIm is the implicit right-hand-side function an ImEx integrator + * will use. + */ + std::function FunIm; + + /** + * \brief FunEx is the explicit right-hand-side function an ImEx integrator + * will use. + */ + std::function FunEx; + + /** + * \brief FastFun is the fast timescale right-hand-side function a multirate + * integrator will use. + */ + std::function FunFast; protected: - /** - * \brief Integrator timestep size (Real) - */ + /** + * \brief Integrator timestep size (Real) + */ amrex::Real timestep; - /** - * \brief For multirate problems, the ratio of slow timestep size / fast timestep size (int) - */ - int slow_fast_timestep_ratio = 0; - - /** - * \brief For multirate problems, the fast timestep size (Real) - */ - Real fast_timestep = 0.0; - - /** - * \brief The post_update function is called by the integrator on state data before using it to evaluate a right-hand side. - */ + /** + * \brief The post_update function is called by the integrator on state data + * before using it to evaluate a right-hand side. + */ std::function post_update; public: @@ -202,24 +206,21 @@ public: virtual void initialize (const T& S_data, const amrex::Real time = 0.0) = 0; - void set_rhs (std::function F) + void set_rhs (std::function F) { Fun = F; } - void set_fast_rhs (std::function F) + void set_imex_rhs (std::function Fi, + std::function Fe) { - FastFun = F; + FunIm = Fi; + FunEx = Fe; } - void set_slow_fast_timestep_ratio (const int timestep_ratio = 1) + void set_fast_rhs (std::function F) { - slow_fast_timestep_ratio = timestep_ratio; - } - - void set_fast_timestep (const Real fast_dt = 1.0) - { - fast_timestep = fast_dt; + FunFast = F; } void set_post_update (std::function F) @@ -227,39 +228,24 @@ public: post_update = F; } - std::function get_post_update () - { - return post_update; - } - - std::function get_rhs () + void rhs (T& S_rhs, const T& S_data, const amrex::Real time) { - return Fun; - } - - std::function get_fast_rhs () - { - return FastFun; - } - - int get_slow_fast_timestep_ratio () - { - return slow_fast_timestep_ratio; + Fun(S_rhs, S_data, time); } - Real get_fast_timestep () + void rhs_ex (T& S_rhs, const T& S_data, const amrex::Real time) { - return fast_timestep; + FunEx(S_rhs, S_data, time); } - void rhs (T& S_rhs, T& S_data, const amrex::Real time) + void rhs_im (T& S_rhs, const T& S_data, const amrex::Real time) { - Fun(S_rhs, S_data, time); + FunIm(S_rhs, S_data, time); } - void fast_rhs (T& S_rhs, T& S_data, const amrex::Real time) + void fast_rhs (T& S_rhs, const T& S_data, const amrex::Real time) { - FastFun(S_rhs, S_data, time); + FunFast(S_rhs, S_data, time); } virtual amrex::Real advance (T& S_old, T& S_new, amrex::Real time, amrex::Real dt) = 0; diff --git a/Src/Base/AMReX_TimeIntegrator.H b/Src/Base/AMReX_TimeIntegrator.H index 4fe1c1329bd..b69f31204c2 100644 --- a/Src/Base/AMReX_TimeIntegrator.H +++ b/Src/Base/AMReX_TimeIntegrator.H @@ -75,8 +75,10 @@ private: set_post_update([](T& /* S_data */, amrex::Real /* S_time */){}); // By default, do nothing - set_rhs([](T& /* S_rhs */, T& /* S_data */, const amrex::Real /* time */){}); - set_fast_rhs([](T& /* S_rhs */, T& /* S_data */, const amrex::Real /* time */){}); + set_rhs([](T& /* S_rhs */, const T& /* S_data */, const amrex::Real /* time */){}); + set_imex_rhs([](T& /* S_rhs */, const T& /* S_data */, const amrex::Real /* time */){}, + [](T& /* S_rhs */, const T& /* S_data */, const amrex::Real /* time */){}); + set_fast_rhs([](T& /* S_rhs */, const T& /* S_data */, const amrex::Real /* time */){}); // By default, initialize time, timestep, step number to 0's m_time = 0.0_rt; @@ -143,29 +145,20 @@ public: integrator_ptr->set_post_update(F); } - void set_rhs (std::function F) + void set_rhs (std::function F) { integrator_ptr->set_rhs(F); } - void set_fast_rhs (std::function F) + void set_imex_rhs (std::function Fi, + std::function Fe) { - integrator_ptr->set_fast_rhs(F); - } - - void set_slow_fast_timestep_ratio (const int timestep_ratio = 1) - { - integrator_ptr->set_slow_fast_timestep_ratio(timestep_ratio); - } - - void set_fast_timestep (const Real fast_dt = 1.0) - { - integrator_ptr->set_fast_timestep(fast_dt); + integrator_ptr->set_imex_rhs(Fi, Fe); } - Real get_fast_timestep () + void set_fast_rhs (std::function F) { - return integrator_ptr->get_fast_timestep(); + integrator_ptr->set_fast_rhs(F); } int get_step_number () @@ -188,26 +181,6 @@ public: m_timestep = dt; } - std::function get_post_timestep () - { - return post_timestep; - } - - std::function get_post_update () - { - return integrator_ptr->get_post_update(); - } - - std::function get_rhs () - { - return integrator_ptr->get_rhs(); - } - - std::function get_fast_rhs () - { - return integrator_ptr->get_fast_rhs(); - } - void advance (T& S_old, T& S_new, amrex::Real time, const amrex::Real timestep) { integrator_ptr->advance(S_old, S_new, time, timestep); diff --git a/Src/Extern/SUNDIALS/AMReX_SundialsIntegrator.H b/Src/Extern/SUNDIALS/AMReX_SundialsIntegrator.H index 6a875766272..0bca93fa44c 100644 --- a/Src/Extern/SUNDIALS/AMReX_SundialsIntegrator.H +++ b/Src/Extern/SUNDIALS/AMReX_SundialsIntegrator.H @@ -18,9 +18,17 @@ namespace amrex { struct SundialsUserData { + // ERK or DIRK right-hand side function + // ExMRI or ImMRI slow right-hand side function std::function f; - std::function fe; + + // ImEx right-hand side functions + // ImExMRI slow right-hand side functions std::function fi; + std::function fe; + + // MRI fast time scale right-hand side function + std::function ff; }; namespace SundialsUserFun { @@ -29,14 +37,19 @@ namespace SundialsUserFun { return udata->f(t, y_data, y_rhs, user_data); } + static int fi (sunrealtype t, N_Vector y_data, N_Vector y_rhs, void *user_data) { + SundialsUserData* udata = static_cast(user_data); + return udata->fi(t, y_data, y_rhs, user_data); + } + static int fe (sunrealtype t, N_Vector y_data, N_Vector y_rhs, void *user_data) { SundialsUserData* udata = static_cast(user_data); return udata->fe(t, y_data, y_rhs, user_data); } - static int fi (sunrealtype t, N_Vector y_data, N_Vector y_rhs, void *user_data) { + static int ff (sunrealtype t, N_Vector y_data, N_Vector y_rhs, void *user_data) { SundialsUserData* udata = static_cast(user_data); - return udata->fi(t, y_data, y_rhs, user_data); + return udata->ff(t, y_data, y_rhs, user_data); } } @@ -148,7 +161,6 @@ public: }; - // Create an N_Vector wrapper for the solution MultiFab auto get_length = [&](int index) -> sunindextype { auto* p_mf = &S_data[index]; From 6488c633cb7c8a7f94f904629fe4b76eb91b6e1c Mon Sep 17 00:00:00 2001 From: "David J. Gardner" Date: Fri, 31 May 2024 20:49:10 -0700 Subject: [PATCH 04/36] update error messages --- Src/Base/AMReX_FEIntegrator.H | 4 ++-- Src/Base/AMReX_RKIntegrator.H | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/Src/Base/AMReX_FEIntegrator.H b/Src/Base/AMReX_FEIntegrator.H index 8e504397490..b156fb33913 100644 --- a/Src/Base/AMReX_FEIntegrator.H +++ b/Src/Base/AMReX_FEIntegrator.H @@ -58,12 +58,12 @@ public: void evolve (T& S_out, const amrex::Real t_out) override { - amrex::Error("Not implemented yet"); + amrex::Error("Evolve is not yet supported by the forward euler integrator."); } virtual void time_interpolate (const T& /* S_new */, const T& /* S_old */, amrex::Real /* timestep_fraction */, T& /* data */) override { - amrex::Error("Time interpolation not yet supported by forward euler integrator."); + amrex::Error("Time interpolation not yet supported by the forward euler integrator."); } virtual void map_data (std::function Map) override diff --git a/Src/Base/AMReX_RKIntegrator.H b/Src/Base/AMReX_RKIntegrator.H index 951cf7fc9ee..f7467c8bf19 100644 --- a/Src/Base/AMReX_RKIntegrator.H +++ b/Src/Base/AMReX_RKIntegrator.H @@ -224,7 +224,7 @@ public: void evolve (T& S_out, const amrex::Real t_out) override { - amrex::Error("Not implemented yet"); + amrex::Error("Evolve is not yet supported by the RK integrator."); } void time_interpolate (const T& /* S_new */, const T& S_old, amrex::Real timestep_fraction, T& data) override From 3fd83556a7d2197c7acb11cf26f25dd4180b530b Mon Sep 17 00:00:00 2001 From: "David J. Gardner" Date: Fri, 31 May 2024 21:26:42 -0700 Subject: [PATCH 05/36] remove Rhs wrappers --- Src/Base/AMReX_FEIntegrator.H | 2 +- Src/Base/AMReX_IntegratorBase.H | 57 ++++++------------- Src/Base/AMReX_RKIntegrator.H | 2 +- .../SUNDIALS/AMReX_SundialsIntegrator.H | 2 +- 4 files changed, 21 insertions(+), 42 deletions(-) diff --git a/Src/Base/AMReX_FEIntegrator.H b/Src/Base/AMReX_FEIntegrator.H index b156fb33913..f4516220ed2 100644 --- a/Src/Base/AMReX_FEIntegrator.H +++ b/Src/Base/AMReX_FEIntegrator.H @@ -44,7 +44,7 @@ public: // F = RHS(S, t) T& F = *F_nodes[0]; - BaseT::rhs(F, S_new, time); + BaseT::Rhs(F, S_new, time); // S_new += timestep * dS/dt IntegratorOps::Saxpy(S_new, BaseT::timestep, F); diff --git a/Src/Base/AMReX_IntegratorBase.H b/Src/Base/AMReX_IntegratorBase.H index 9affbea24a5..0956bf02a72 100644 --- a/Src/Base/AMReX_IntegratorBase.H +++ b/Src/Base/AMReX_IntegratorBase.H @@ -161,35 +161,29 @@ struct IntegratorOps > > template class IntegratorBase { -private: +protected: /** - * \brief Fun is the right-hand-side function the integrator will use. + * \brief Rhs is the right-hand-side function the integrator will use. */ - std::function Fun; + std::function Rhs; /** - * \brief FunIm is the implicit right-hand-side function an ImEx integrator + * \brief RhsIm is the implicit right-hand-side function an ImEx integrator * will use. */ - std::function FunIm; + std::function RhsIm; /** - * \brief FunEx is the explicit right-hand-side function an ImEx integrator + * \brief RhsEx is the explicit right-hand-side function an ImEx integrator * will use. */ - std::function FunEx; + std::function RhsEx; /** - * \brief FastFun is the fast timescale right-hand-side function a multirate + * \brief RhsFast is the fast timescale right-hand-side function a multirate * integrator will use. */ - std::function FunFast; - -protected: - /** - * \brief Integrator timestep size (Real) - */ - amrex::Real timestep; + std::function RhsFast; /** * \brief The post_update function is called by the integrator on state data @@ -197,6 +191,11 @@ protected: */ std::function post_update; + /** + * \brief Integrator timestep size (Real) + */ + amrex::Real timestep; + public: IntegratorBase () = default; @@ -208,19 +207,19 @@ public: void set_rhs (std::function F) { - Fun = F; + Rhs = F; } void set_imex_rhs (std::function Fi, std::function Fe) { - FunIm = Fi; - FunEx = Fe; + RhsIm = Fi; + RhsEx = Fe; } void set_fast_rhs (std::function F) { - FunFast = F; + RhsFast = F; } void set_post_update (std::function F) @@ -228,26 +227,6 @@ public: post_update = F; } - void rhs (T& S_rhs, const T& S_data, const amrex::Real time) - { - Fun(S_rhs, S_data, time); - } - - void rhs_ex (T& S_rhs, const T& S_data, const amrex::Real time) - { - FunEx(S_rhs, S_data, time); - } - - void rhs_im (T& S_rhs, const T& S_data, const amrex::Real time) - { - FunIm(S_rhs, S_data, time); - } - - void fast_rhs (T& S_rhs, const T& S_data, const amrex::Real time) - { - FunFast(S_rhs, S_data, time); - } - virtual amrex::Real advance (T& S_old, T& S_new, amrex::Real time, amrex::Real dt) = 0; virtual void evolve (T& S_out, const amrex::Real t_out) = 0; diff --git a/Src/Base/AMReX_RKIntegrator.H b/Src/Base/AMReX_RKIntegrator.H index f7467c8bf19..cf33238e287 100644 --- a/Src/Base/AMReX_RKIntegrator.H +++ b/Src/Base/AMReX_RKIntegrator.H @@ -200,7 +200,7 @@ public: // Fill F[i], the RHS at the current stage // F[i] = RHS(y, t) at y = stage_value, t = stage_time - BaseT::rhs(*F_nodes[i], S_new, stage_time); + BaseT::Rhs(*F_nodes[i], S_new, stage_time); } // Fill new State, starting with S_new = S_old. diff --git a/Src/Extern/SUNDIALS/AMReX_SundialsIntegrator.H b/Src/Extern/SUNDIALS/AMReX_SundialsIntegrator.H index 0bca93fa44c..914b0ae7b5c 100644 --- a/Src/Extern/SUNDIALS/AMReX_SundialsIntegrator.H +++ b/Src/Extern/SUNDIALS/AMReX_SundialsIntegrator.H @@ -155,7 +155,7 @@ public: } BaseT::post_update(S_data, rhs_time); - BaseT::rhs(S_rhs, S_data, rhs_time); + BaseT::Rhs(S_rhs, S_data, rhs_time); return 0; }; From 4184fce508676d8fccbb99493e7c1b9aa0549050 Mon Sep 17 00:00:00 2001 From: "David J. Gardner" Date: Fri, 31 May 2024 21:27:51 -0700 Subject: [PATCH 06/36] formatting --- Src/Extern/SUNDIALS/AMReX_SundialsIntegrator.H | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/Src/Extern/SUNDIALS/AMReX_SundialsIntegrator.H b/Src/Extern/SUNDIALS/AMReX_SundialsIntegrator.H index 914b0ae7b5c..df06a569e80 100644 --- a/Src/Extern/SUNDIALS/AMReX_SundialsIntegrator.H +++ b/Src/Extern/SUNDIALS/AMReX_SundialsIntegrator.H @@ -176,7 +176,13 @@ public: S_data[i].DistributionMap(), S_data[i].nComp(), S_data[i].nGrow()); - MultiFab::Copy(*amrex::sundials::getMFptr(nv_many_arr[i]), S_data[i], 0, 0, S_data[i].nComp(), S_data[i].nGrow()); + + MultiFab::Copy(*amrex::sundials::getMFptr(nv_many_arr[i]), + S_data[i], + 0, + 0, + S_data[i].nComp(), + S_data[i].nGrow()); } nv_S = N_VNew_ManyVector(NVar, nv_many_arr, sunctx); From 83ba92460d6dc45328ca28f4e06a8f53bb0b9f37 Mon Sep 17 00:00:00 2001 From: "David J. Gardner" Date: Fri, 31 May 2024 22:08:58 -0700 Subject: [PATCH 07/36] add more SUNDIALS Rhs wrappers --- .../SUNDIALS/AMReX_SundialsIntegrator.H | 56 +++++++++++++++++-- 1 file changed, 51 insertions(+), 5 deletions(-) diff --git a/Src/Extern/SUNDIALS/AMReX_SundialsIntegrator.H b/Src/Extern/SUNDIALS/AMReX_SundialsIntegrator.H index df06a569e80..b1bca47b252 100644 --- a/Src/Extern/SUNDIALS/AMReX_SundialsIntegrator.H +++ b/Src/Extern/SUNDIALS/AMReX_SundialsIntegrator.H @@ -132,16 +132,15 @@ public: # endif #endif - /* f routine to compute the ODE RHS function f(t,y). */ - udata.f = [&](sunrealtype rhs_time, N_Vector y_data, N_Vector y_rhs, void * /* user_data */) -> int { - amrex::Vector S_data; - amrex::Vector S_rhs; + /* Utility to unpack SUNDIALS vectors */ + auto unpack_vectors = [&](N_Vector y_data, amrex::Vector& S_data, + N_Vector y_rhs, amrex::Vector& S_rhs) -> void { const int num_vecs = N_VGetNumSubvectors_ManyVector(y_data); S_data.resize(num_vecs); S_rhs.resize(num_vecs); - for(int i=0; inComp()); } + }; + + /* Right-hand side function wrappers */ + udata.f = [&](sunrealtype rhs_time, N_Vector y_data, N_Vector y_rhs, void * /* user_data */) -> int { + + amrex::Vector S_data; + amrex::Vector S_rhs; + + unpack_vectors(y_data, S_data, y_rhs, S_rhs); BaseT::post_update(S_data, rhs_time); BaseT::Rhs(S_rhs, S_data, rhs_time); @@ -160,6 +168,44 @@ public: return 0; }; + udata.fi = [&](sunrealtype rhs_time, N_Vector y_data, N_Vector y_rhs, void * /* user_data */) -> int { + + amrex::Vector S_data; + amrex::Vector S_rhs; + + unpack_vectors(y_data, S_data, y_rhs, S_rhs); + + BaseT::post_update(S_data, rhs_time); + BaseT::RhsIm(S_rhs, S_data, rhs_time); + + return 0; + }; + + udata.fe = [&](sunrealtype rhs_time, N_Vector y_data, N_Vector y_rhs, void * /* user_data */) -> int { + + amrex::Vector S_data; + amrex::Vector S_rhs; + + unpack_vectors(y_data, S_data, y_rhs, S_rhs); + + BaseT::post_update(S_data, rhs_time); + BaseT::RhsEx(S_rhs, S_data, rhs_time); + + return 0; + }; + + udata.ff = [&](sunrealtype rhs_time, N_Vector y_data, N_Vector y_rhs, void * /* user_data */) -> int { + + amrex::Vector S_data; + amrex::Vector S_rhs; + + unpack_vectors(y_data, S_data, y_rhs, S_rhs); + + BaseT::post_update(S_data, rhs_time); + BaseT::RhsFast(S_rhs, S_data, rhs_time); + + return 0; + }; // Create an N_Vector wrapper for the solution MultiFab auto get_length = [&](int index) -> sunindextype { From 4d8ac2100584e7227afc366b99e55de6a08a46a2 Mon Sep 17 00:00:00 2001 From: "David J. Gardner" Date: Sat, 1 Jun 2024 20:50:49 -0700 Subject: [PATCH 08/36] add mri --- .../SUNDIALS/AMReX_SundialsIntegrator.H | 187 ++++++++++++++---- 1 file changed, 146 insertions(+), 41 deletions(-) diff --git a/Src/Extern/SUNDIALS/AMReX_SundialsIntegrator.H b/Src/Extern/SUNDIALS/AMReX_SundialsIntegrator.H index b1bca47b252..ff3b1444f18 100644 --- a/Src/Extern/SUNDIALS/AMReX_SundialsIntegrator.H +++ b/Src/Extern/SUNDIALS/AMReX_SundialsIntegrator.H @@ -14,6 +14,7 @@ #include #include #include +#include namespace amrex { @@ -22,8 +23,8 @@ struct SundialsUserData { // ExMRI or ImMRI slow right-hand side function std::function f; - // ImEx right-hand side functions - // ImExMRI slow right-hand side functions + // ImEx-RK right-hand side functions + // ImEx-MRI slow right-hand side functions std::function fi; std::function fe; @@ -65,24 +66,33 @@ private: std::string method_e = "DEFAULT"; std::string method_i = "DEFAULT"; - // method type flags - bool use_erk = false; - bool use_dirk = false; - bool use_imex = false; - bool use_mri = false; + std::string fast_type = "ERK"; + std::string fast_method = "DEFAULT"; + + // SUNDIALS package flags, set based on type + bool use_ark = false; + bool use_mri = false; // structure for interfacing with user-supplied functions SundialsUserData udata; // SUNDIALS objects - SUNContext sunctx = nullptr; // should get from singleton + SUNContext sunctx = nullptr; + void *arkode_mem = nullptr; SUNLinearSolver LS = nullptr; + void *arkode_fast_mem = nullptr; + MRIStepInnerStepper fast_stepper = nullptr; + SUNLinearSolver fast_LS = nullptr; + // relative and absolute tolerances Real reltol = 1.0e-4; Real abstol = 1.0e-9; + Real fast_reltol = 1.0e-4; + Real fast_abstol = 1.0e-9; + int NVar; // NOTE: expects S_data to be a Vector N_Vector* nv_many_arr; /* vector array composed of cons, xmom, ymom, zmom component vectors */ N_Vector nv_S; @@ -93,15 +103,17 @@ private: pp.query("type", type); pp.query("method", method); + pp.query("method_e", method); + pp.query("method_i", method); - if (type == "ERK") { - use_erk = true; - } - else if (type == "DIRK") { - use_dirk = true; + pp.query("fast_type", fast_type); + pp.query("fast_method", fast_method); + + if (type == "ERK" || type == "DIRK" || type == "IMEX-RK") { + use_ark = true; } - else if (type == "IMEX") { - use_imex = true; + else if (type == "EX-MRI" || type == "IM-MRI" || type == "IMEX-MRI") { + use_mri = true; } else { std::string msg("Unknown strategy: "); @@ -132,7 +144,7 @@ public: # endif #endif - /* Utility to unpack SUNDIALS vectors */ + // Utility to unpack SUNDIALS vectors auto unpack_vectors = [&](N_Vector y_data, amrex::Vector& S_data, N_Vector y_rhs, amrex::Vector& S_rhs) -> void { @@ -154,7 +166,7 @@ public: } }; - /* Right-hand side function wrappers */ + // Right-hand side function wrappers udata.f = [&](sunrealtype rhs_time, N_Vector y_data, N_Vector y_rhs, void * /* user_data */) -> int { amrex::Vector S_data; @@ -213,7 +225,7 @@ public: return p_mf->nComp() * (p_mf->boxArray()).numPts(); }; - NVar = S_data.size(); // NOTE: expects S_data to be a Vector + NVar = S_data.size(); // NOTE: expects S_data to be a Vector nv_many_arr = new N_Vector[NVar]; // vector array composed of cons, xmom, ymom, zmom component vectors */ for (int i = 0; i < NVar; ++i) { @@ -232,42 +244,129 @@ public: } nv_S = N_VNew_ManyVector(NVar, nv_many_arr, sunctx); - if (use_erk) { + if (use_ark) { + SetupRK(time, nv_S); + } + else if (use_mri) + { + SetupMRI(time, nv_S); + } + } + + void SetupRK(amrex::Real time, N_Vector y_data) + { + // Create integrator and select method + if (type == "ERK") { amrex::Print() << "SUNDIALS ERK time integrator\n"; arkode_mem = ARKStepCreate(SundialsUserFun::f, nullptr, time, nv_S, sunctx); + + if (method != "DEFAULT") { + amrex::Print() << "SUNDIALS ERK method " << method << "\n"; + ARKStepSetTableName(arkode_mem, "ARKODE_DIRK_NONE", method.c_str()); + } } - else if (use_dirk) { + else if (type == "DIRK") { amrex::Print() << "SUNDIALS DIRK time integrator\n"; arkode_mem = ARKStepCreate(nullptr, SundialsUserFun::f, time, nv_S, sunctx); + + if (method != "DEFAULT") { + amrex::Print() << "SUNDIALS DIRK method " << method << "\n"; + ARKStepSetTableName(arkode_mem, method.c_str(), "ARKODE_ERK_NONE"); + } } - else if (use_imex) { + else if (type == "IMEX-RK") { amrex::Print() << "SUNDIALS IMEX time integrator\n"; arkode_mem = ARKStepCreate(SundialsUserFun::fe, SundialsUserFun::fi, time, nv_S, sunctx); + + if (method_e != "DEFAULT" && method_i != "DEFAULT") + { + amrex::Print() << "SUNDIALS IMEX method " << method_i << " and " + << method_e << "\n"; + ARKStepSetTableName(arkode_mem, method_i.c_str(), method_e.c_str()); + } } - ARKStepSetUserData(arkode_mem, &udata); /* Pass udata to user functions */ + + // Attach structure with user-supplied function wrappers + ARKStepSetUserData(arkode_mem, &udata); + + // Set integrator tolerances ARKStepSStolerances(arkode_mem, reltol, abstol); - if (method != "DEFAULT" || (method_e != "DEFAULT" && method_i != "DEFAULT")) - { - if (use_erk) { + // Create and attach linear solver for implicit methods + if (type == "DIRK" || type == "IMEX-RK") { + LS = SUNLinSol_SPGMR(nv_S, SUN_PREC_NONE, 0, sunctx); + ARKStepSetLinearSolver(arkode_mem, LS, nullptr); + } + } + + void SetupMRI(amrex::Real time, N_Vector y_data) + { + // Create the fast integrator and select method + if (fast_type == "ERK") { + amrex::Print() << "SUNDIALS ERK time integrator\n"; + arkode_fast_mem = ARKStepCreate(SundialsUserFun::ff, nullptr, time, nv_S, sunctx); + + if (method != "DEFAULT") { amrex::Print() << "SUNDIALS ERK method " << method << "\n"; - ARKStepSetTableName(arkode_mem, "ARKODE_DIRK_NONE", method.c_str()); + ARKStepSetTableName(arkode_fast_mem, "ARKODE_DIRK_NONE", fast_method.c_str()); } - else if (use_dirk) { + } + else if (fast_type == "DIRK") { + amrex::Print() << "SUNDIALS DIRK time integrator\n"; + arkode_fast_mem = ARKStepCreate(nullptr, SundialsUserFun::ff, time, nv_S, sunctx); + + if (method != "DEFAULT") { amrex::Print() << "SUNDIALS DIRK method " << method << "\n"; - ARKStepSetTableName(arkode_mem, method.c_str(), "ARKODE_ERK_NONE"); - } - else { - amrex::Print() << "SUNDIALS IMEX method " << method_i << " and " - << method_e << "\n"; - ARKStepSetTableName(arkode_mem, method_i.c_str(), method_e.c_str()); + ARKStepSetTableName(arkode_fast_mem, fast_method.c_str(), "ARKODE_ERK_NONE"); } + + fast_LS = SUNLinSol_SPGMR(nv_S, SUN_PREC_NONE, 0, sunctx); + ARKStepSetLinearSolver(arkode_fast_mem, fast_LS, nullptr); } + // Attach structure with user-supplied function wrappers + ARKStepSetUserData(arkode_fast_mem, &udata); + + // Set integrator tolerances + ARKStepSStolerances(arkode_fast_mem, fast_reltol, fast_abstol); + + // Wrap fast integrator as an inner stepper + ARKStepCreateMRIStepInnerStepper(arkode_fast_mem, &fast_stepper); + + // Create slow integrator + if (type == "EX-MRI") { + amrex::Print() << "SUNDIALS ERK time integrator\n"; + arkode_mem = MRIStepCreate(SundialsUserFun::f, nullptr, time, nv_S, + fast_stepper, sunctx); + } + else if (type == "IM-MRI") { + amrex::Print() << "SUNDIALS DIRK time integrator\n"; + arkode_mem = MRIStepCreate(nullptr, SundialsUserFun::f, time, nv_S, + fast_stepper, sunctx); + } + else if (type == "IMEX-MRI") { + amrex::Print() << "SUNDIALS IMEX time integrator\n"; + arkode_mem = MRIStepCreate(SundialsUserFun::fe, SundialsUserFun::fi, + time, nv_S, fast_stepper, sunctx); + } + + // Set method + if (method != "DEFAULT") { + MRIStepCoupling MRIC = MRIStepCoupling_LoadTableByName(method.c_str()); + MRIStepSetCoupling(arkode_mem, MRIC); + MRIStepCoupling_Free(MRIC); + } + + // Attach structure with user-supplied function wrappers + MRIStepSetUserData(arkode_mem, &udata); + + // Set integrator tolerances + MRIStepSStolerances(arkode_mem, reltol, abstol); + // Create and attach linear solver - if (use_dirk || use_imex) { + if (type == "IM-MRI" || type == "IMEX-MRI") { LS = SUNLinSol_SPGMR(nv_S, SUN_PREC_NONE, 0, sunctx); - ARKStepSetLinearSolver(arkode_mem, LS, nullptr); + MRIStepSetLinearSolver(arkode_mem, LS, nullptr); } } @@ -278,6 +377,10 @@ public: } delete[] nv_many_arr; N_VDestroy(nv_S); + SUNLinSolFree(LS); + SUNLinSolFree(fast_LS); + MRIStepInnerStepper_Free(&fast_stepper); + MRIStepFree(&arkode_fast_mem); ARKStepFree(&arkode_mem); SUNContext_Free(&sunctx); } @@ -293,14 +396,17 @@ public: MultiFab::Copy(*amrex::sundials::getMFptr(N_VGetSubvector_ManyVector(nv_S, i)), S_old[i], 0, 0, S_old[i].nComp(), S_old[i].nGrow()); } - if (use_erk || use_dirk || use_imex) { + if (use_ark) { ARKStepReset(arkode_mem, time, nv_S); ARKStepSetFixedStep(arkode_mem, time_step); int flag = ARKStepEvolve(arkode_mem, tout, nv_S, &tret, ARK_ONE_STEP); AMREX_ALWAYS_ASSERT(flag >= 0); } else if (use_mri) { - Error("SUNDIALS integrator type not implemented, yet."); + MRIStepReset(arkode_mem, time, nv_S); + MRIStepSetFixedStep(arkode_mem, time_step); + int flag = MRIStepEvolve(arkode_mem, tout, nv_S, &tret, ARK_ONE_STEP); + AMREX_ALWAYS_ASSERT(flag >= 0); } else { Error("SUNDIALS integrator type not specified."); } @@ -318,12 +424,13 @@ public: { amrex::Real tret; - if (use_erk || use_dirk || use_imex) { + if (use_ark) { int flag = ARKStepEvolve(arkode_mem, t_out, nv_S, &tret, ARK_NORMAL); AMREX_ALWAYS_ASSERT(flag >= 0); } else if (use_mri) { - Error("SUNDIALS integrator type not implemented, yet."); + int flag = MRIStepEvolve(arkode_mem, t_out, nv_S, &tret, ARK_NORMAL); + AMREX_ALWAYS_ASSERT(flag >= 0); } else { Error("SUNDIALS integrator type not specified."); } @@ -340,8 +447,6 @@ public: } } - - void time_interpolate (const T& /* S_new */, const T& /* S_old */, amrex::Real /* timestep_fraction */, T& /* data */) override {} void map_data (std::function /* Map */) override {} From b8eb59d381acce17edb847cdff0f82d25366667f Mon Sep 17 00:00:00 2001 From: "David J. Gardner" Date: Sat, 1 Jun 2024 21:22:17 -0700 Subject: [PATCH 09/36] change post_update to pre_rhs_update --- Src/Base/AMReX_FEIntegrator.H | 6 +++--- Src/Base/AMReX_IntegratorBase.H | 10 +++++----- Src/Base/AMReX_RKIntegrator.H | 9 +++------ Src/Base/AMReX_TimeIntegrator.H | 6 +++--- Src/Extern/SUNDIALS/AMReX_SundialsIntegrator.H | 8 ++++---- 5 files changed, 18 insertions(+), 21 deletions(-) diff --git a/Src/Base/AMReX_FEIntegrator.H b/Src/Base/AMReX_FEIntegrator.H index f4516220ed2..5fdf3fc5b4b 100644 --- a/Src/Base/AMReX_FEIntegrator.H +++ b/Src/Base/AMReX_FEIntegrator.H @@ -42,6 +42,9 @@ public: // So we initialize S_new by copying the old state. IntegratorOps::Copy(S_new, S_old); + // Call the update hook for S_new + BaseT::pre_rhs_update(S_new, time + BaseT::timestep); + // F = RHS(S, t) T& F = *F_nodes[0]; BaseT::Rhs(F, S_new, time); @@ -49,9 +52,6 @@ public: // S_new += timestep * dS/dt IntegratorOps::Saxpy(S_new, BaseT::timestep, F); - // Call the post-update hook for S_new - BaseT::post_update(S_new, time + BaseT::timestep); - // Return timestep return BaseT::timestep; } diff --git a/Src/Base/AMReX_IntegratorBase.H b/Src/Base/AMReX_IntegratorBase.H index 0956bf02a72..4ea7f656f3f 100644 --- a/Src/Base/AMReX_IntegratorBase.H +++ b/Src/Base/AMReX_IntegratorBase.H @@ -186,10 +186,10 @@ protected: std::function RhsFast; /** - * \brief The post_update function is called by the integrator on state data - * before using it to evaluate a right-hand side. + * \brief The pre_rhs_update function is called by the integrator on state + * data before using it to evaluate a right-hand side. */ - std::function post_update; + std::function pre_rhs_update; /** * \brief Integrator timestep size (Real) @@ -222,9 +222,9 @@ public: RhsFast = F; } - void set_post_update (std::function F) + void set_pre_rhs_update (std::function F) { - post_update = F; + pre_rhs_update = F; } virtual amrex::Real advance (T& S_old, T& S_new, amrex::Real time, amrex::Real dt) = 0; diff --git a/Src/Base/AMReX_RKIntegrator.H b/Src/Base/AMReX_RKIntegrator.H index cf33238e287..b6835744a1b 100644 --- a/Src/Base/AMReX_RKIntegrator.H +++ b/Src/Base/AMReX_RKIntegrator.H @@ -193,11 +193,11 @@ public: { IntegratorOps::Saxpy(S_new, BaseT::timestep * tableau[i][j], *F_nodes[j]); } - - // Call the post-update hook for the stage state value - BaseT::post_update(S_new, stage_time); } + // Call the update hook for the stage state value + BaseT::pre_rhs_update(S_new, stage_time); + // Fill F[i], the RHS at the current stage // F[i] = RHS(y, t) at y = stage_value, t = stage_time BaseT::Rhs(*F_nodes[i], S_new, stage_time); @@ -212,9 +212,6 @@ public: IntegratorOps::Saxpy(S_new, BaseT::timestep * weights[i], *F_nodes[i]); } - // Call the post-update hook for S_new - BaseT::post_update(S_new, time + BaseT::timestep); - // If we are working with an extended Butcher tableau, we can estimate the error here, // and then calculate an adaptive timestep. diff --git a/Src/Base/AMReX_TimeIntegrator.H b/Src/Base/AMReX_TimeIntegrator.H index b69f31204c2..25027973759 100644 --- a/Src/Base/AMReX_TimeIntegrator.H +++ b/Src/Base/AMReX_TimeIntegrator.H @@ -72,7 +72,7 @@ private: // By default, do nothing after updating the state // In general, this is where BCs should be filled - set_post_update([](T& /* S_data */, amrex::Real /* S_time */){}); + set_pre_rhs_update([](T& /* S_data */, amrex::Real /* S_time */){}); // By default, do nothing set_rhs([](T& /* S_rhs */, const T& /* S_data */, const amrex::Real /* time */){}); @@ -140,9 +140,9 @@ public: post_timestep = F; } - void set_post_update (std::function F) + void set_pre_rhs_update (std::function F) { - integrator_ptr->set_post_update(F); + integrator_ptr->set_pre_rhs_update(F); } void set_rhs (std::function F) diff --git a/Src/Extern/SUNDIALS/AMReX_SundialsIntegrator.H b/Src/Extern/SUNDIALS/AMReX_SundialsIntegrator.H index ff3b1444f18..ee07c9cfead 100644 --- a/Src/Extern/SUNDIALS/AMReX_SundialsIntegrator.H +++ b/Src/Extern/SUNDIALS/AMReX_SundialsIntegrator.H @@ -174,7 +174,7 @@ public: unpack_vectors(y_data, S_data, y_rhs, S_rhs); - BaseT::post_update(S_data, rhs_time); + BaseT::pre_rhs_update(S_data, rhs_time); BaseT::Rhs(S_rhs, S_data, rhs_time); return 0; @@ -187,7 +187,7 @@ public: unpack_vectors(y_data, S_data, y_rhs, S_rhs); - BaseT::post_update(S_data, rhs_time); + BaseT::pre_rhs_update(S_data, rhs_time); BaseT::RhsIm(S_rhs, S_data, rhs_time); return 0; @@ -200,7 +200,7 @@ public: unpack_vectors(y_data, S_data, y_rhs, S_rhs); - BaseT::post_update(S_data, rhs_time); + BaseT::pre_rhs_update(S_data, rhs_time); BaseT::RhsEx(S_rhs, S_data, rhs_time); return 0; @@ -213,7 +213,7 @@ public: unpack_vectors(y_data, S_data, y_rhs, S_rhs); - BaseT::post_update(S_data, rhs_time); + BaseT::pre_rhs_update(S_data, rhs_time); BaseT::RhsFast(S_rhs, S_data, rhs_time); return 0; From a9f544a058fce17d3f199bd7f0f55c8f6af1fd80 Mon Sep 17 00:00:00 2001 From: "David J. Gardner" Date: Mon, 3 Jun 2024 22:09:29 -0700 Subject: [PATCH 10/36] rename update to action --- Src/Base/AMReX_FEIntegrator.H | 2 +- Src/Base/AMReX_IntegratorBase.H | 6 +++--- Src/Base/AMReX_RKIntegrator.H | 2 +- Src/Base/AMReX_TimeIntegrator.H | 6 +++--- Src/Extern/SUNDIALS/AMReX_SundialsIntegrator.H | 8 ++++---- 5 files changed, 12 insertions(+), 12 deletions(-) diff --git a/Src/Base/AMReX_FEIntegrator.H b/Src/Base/AMReX_FEIntegrator.H index 5fdf3fc5b4b..6efbe2a1afd 100644 --- a/Src/Base/AMReX_FEIntegrator.H +++ b/Src/Base/AMReX_FEIntegrator.H @@ -43,7 +43,7 @@ public: IntegratorOps::Copy(S_new, S_old); // Call the update hook for S_new - BaseT::pre_rhs_update(S_new, time + BaseT::timestep); + BaseT::pre_rhs_action(S_new, time + BaseT::timestep); // F = RHS(S, t) T& F = *F_nodes[0]; diff --git a/Src/Base/AMReX_IntegratorBase.H b/Src/Base/AMReX_IntegratorBase.H index 4ea7f656f3f..406086c7942 100644 --- a/Src/Base/AMReX_IntegratorBase.H +++ b/Src/Base/AMReX_IntegratorBase.H @@ -189,7 +189,7 @@ protected: * \brief The pre_rhs_update function is called by the integrator on state * data before using it to evaluate a right-hand side. */ - std::function pre_rhs_update; + std::function pre_rhs_action; /** * \brief Integrator timestep size (Real) @@ -222,9 +222,9 @@ public: RhsFast = F; } - void set_pre_rhs_update (std::function F) + void set_pre_rhs_action (std::function F) { - pre_rhs_update = F; + pre_rhs_action = F; } virtual amrex::Real advance (T& S_old, T& S_new, amrex::Real time, amrex::Real dt) = 0; diff --git a/Src/Base/AMReX_RKIntegrator.H b/Src/Base/AMReX_RKIntegrator.H index b6835744a1b..50dc1678d0e 100644 --- a/Src/Base/AMReX_RKIntegrator.H +++ b/Src/Base/AMReX_RKIntegrator.H @@ -196,7 +196,7 @@ public: } // Call the update hook for the stage state value - BaseT::pre_rhs_update(S_new, stage_time); + BaseT::pre_rhs_action(S_new, stage_time); // Fill F[i], the RHS at the current stage // F[i] = RHS(y, t) at y = stage_value, t = stage_time diff --git a/Src/Base/AMReX_TimeIntegrator.H b/Src/Base/AMReX_TimeIntegrator.H index 25027973759..a5e92a14efe 100644 --- a/Src/Base/AMReX_TimeIntegrator.H +++ b/Src/Base/AMReX_TimeIntegrator.H @@ -72,7 +72,7 @@ private: // By default, do nothing after updating the state // In general, this is where BCs should be filled - set_pre_rhs_update([](T& /* S_data */, amrex::Real /* S_time */){}); + set_pre_rhs_action([](T& /* S_data */, amrex::Real /* S_time */){}); // By default, do nothing set_rhs([](T& /* S_rhs */, const T& /* S_data */, const amrex::Real /* time */){}); @@ -140,9 +140,9 @@ public: post_timestep = F; } - void set_pre_rhs_update (std::function F) + void set_pre_rhs_action (std::function F) { - integrator_ptr->set_pre_rhs_update(F); + integrator_ptr->set_pre_rhs_action(F); } void set_rhs (std::function F) diff --git a/Src/Extern/SUNDIALS/AMReX_SundialsIntegrator.H b/Src/Extern/SUNDIALS/AMReX_SundialsIntegrator.H index ee07c9cfead..f0cd521a285 100644 --- a/Src/Extern/SUNDIALS/AMReX_SundialsIntegrator.H +++ b/Src/Extern/SUNDIALS/AMReX_SundialsIntegrator.H @@ -174,7 +174,7 @@ public: unpack_vectors(y_data, S_data, y_rhs, S_rhs); - BaseT::pre_rhs_update(S_data, rhs_time); + BaseT::pre_rhs_action(S_data, rhs_time); BaseT::Rhs(S_rhs, S_data, rhs_time); return 0; @@ -187,7 +187,7 @@ public: unpack_vectors(y_data, S_data, y_rhs, S_rhs); - BaseT::pre_rhs_update(S_data, rhs_time); + BaseT::pre_rhs_action(S_data, rhs_time); BaseT::RhsIm(S_rhs, S_data, rhs_time); return 0; @@ -200,7 +200,7 @@ public: unpack_vectors(y_data, S_data, y_rhs, S_rhs); - BaseT::pre_rhs_update(S_data, rhs_time); + BaseT::pre_rhs_action(S_data, rhs_time); BaseT::RhsEx(S_rhs, S_data, rhs_time); return 0; @@ -213,7 +213,7 @@ public: unpack_vectors(y_data, S_data, y_rhs, S_rhs); - BaseT::pre_rhs_update(S_data, rhs_time); + BaseT::pre_rhs_action(S_data, rhs_time); BaseT::RhsFast(S_rhs, S_data, rhs_time); return 0; From e9960909a622b785671f781922f4b3f99932c317 Mon Sep 17 00:00:00 2001 From: "David J. Gardner" Date: Mon, 3 Jun 2024 22:41:17 -0700 Subject: [PATCH 11/36] add post step action --- Src/Base/AMReX_FEIntegrator.H | 12 ++-- Src/Base/AMReX_IntegratorBase.H | 17 ++++- Src/Base/AMReX_RKIntegrator.H | 47 ++++++------ Src/Base/AMReX_TimeIntegrator.H | 22 ++++-- .../SUNDIALS/AMReX_SundialsIntegrator.H | 71 +++++++++++-------- 5 files changed, 104 insertions(+), 65 deletions(-) diff --git a/Src/Base/AMReX_FEIntegrator.H b/Src/Base/AMReX_FEIntegrator.H index 6efbe2a1afd..7857268611c 100644 --- a/Src/Base/AMReX_FEIntegrator.H +++ b/Src/Base/AMReX_FEIntegrator.H @@ -37,23 +37,25 @@ public: amrex::Real advance (T& S_old, T& S_new, amrex::Real time, const amrex::Real time_step) override { - BaseT::timestep = time_step; // Assume before advance() that S_old is valid data at the current time ("time" argument) // So we initialize S_new by copying the old state. IntegratorOps::Copy(S_new, S_old); - // Call the update hook for S_new - BaseT::pre_rhs_action(S_new, time + BaseT::timestep); + // Call the pre RHS hook + BaseT::pre_rhs_action(S_new, time); // F = RHS(S, t) T& F = *F_nodes[0]; BaseT::Rhs(F, S_new, time); // S_new += timestep * dS/dt - IntegratorOps::Saxpy(S_new, BaseT::timestep, F); + IntegratorOps::Saxpy(S_new, time_step, F); + + // Call the post step hook + BaseT::post_step_action(S_new, time + time_step); // Return timestep - return BaseT::timestep; + return time_step; } void evolve (T& S_out, const amrex::Real t_out) override diff --git a/Src/Base/AMReX_IntegratorBase.H b/Src/Base/AMReX_IntegratorBase.H index 406086c7942..e6ad56b9492 100644 --- a/Src/Base/AMReX_IntegratorBase.H +++ b/Src/Base/AMReX_IntegratorBase.H @@ -186,15 +186,21 @@ protected: std::function RhsFast; /** - * \brief The pre_rhs_update function is called by the integrator on state + * \brief The pre_rhs_action function is called by the integrator on state * data before using it to evaluate a right-hand side. */ std::function pre_rhs_action; /** - * \brief Integrator timestep size (Real) + * \brief The post_step_action function is called by the integrator on + * computed state just after it is computed */ - amrex::Real timestep; + std::function post_step_action; + + // /** + // * \brief Integrator timestep size (Real) + // */ + // amrex::Real timestep; public: IntegratorBase () = default; @@ -227,6 +233,11 @@ public: pre_rhs_action = F; } + void set_post_step_action (std::function F) + { + post_step_action = F; + } + virtual amrex::Real advance (T& S_old, T& S_new, amrex::Real time, amrex::Real dt) = 0; virtual void evolve (T& S_out, const amrex::Real t_out) = 0; diff --git a/Src/Base/AMReX_RKIntegrator.H b/Src/Base/AMReX_RKIntegrator.H index 50dc1678d0e..b5622eef776 100644 --- a/Src/Base/AMReX_RKIntegrator.H +++ b/Src/Base/AMReX_RKIntegrator.H @@ -170,7 +170,6 @@ public: amrex::Real advance (T& S_old, T& S_new, amrex::Real time, const amrex::Real time_step) override { - BaseT::timestep = time_step; // Assume before advance() that S_old is valid data at the current time ("time" argument) // And that if data is a MultiFab, both S_old and S_new contain ghost cells for evaluating a stencil based RHS // We need this from S_old. This is convenient for S_new to have so we can use it @@ -180,7 +179,7 @@ public: for (int i = 0; i < number_nodes; ++i) { // Get current stage time, t = t_old + h * Ci - amrex::Real stage_time = time + BaseT::timestep * nodes[i]; + amrex::Real stage_time = time + time_step * nodes[i]; // Fill S_new with the solution value for evaluating F at the current stage // Copy S_new = S_old @@ -191,7 +190,7 @@ public: // We should fuse these kernels ... for (int j = 0; j < i; ++j) { - IntegratorOps::Saxpy(S_new, BaseT::timestep * tableau[i][j], *F_nodes[j]); + IntegratorOps::Saxpy(S_new, time_step * tableau[i][j], *F_nodes[j]); } } @@ -209,14 +208,16 @@ public: IntegratorOps::Copy(S_new, S_old); for (int i = 0; i < number_nodes; ++i) { - IntegratorOps::Saxpy(S_new, BaseT::timestep * weights[i], *F_nodes[i]); + IntegratorOps::Saxpy(S_new, time_step * weights[i], *F_nodes[i]); } + BaseT::post_step_action(S_new, time + time_step); + // If we are working with an extended Butcher tableau, we can estimate the error here, // and then calculate an adaptive timestep. // Return timestep - return BaseT::timestep; + return time_step; } void evolve (T& S_out, const amrex::Real t_out) override @@ -235,30 +236,32 @@ public: */ + // Update to use the last time step taken + // currently we only do this for 4th order RK - AMREX_ASSERT(number_nodes == 4); + // AMREX_ASSERT(number_nodes == 4); - // fill data using MC Equation 39 at time + timestep_fraction * dt - amrex::Real c = 0; + // // fill data using MC Equation 39 at time + timestep_fraction * dt + // amrex::Real c = 0; - // data = S_old - IntegratorOps::Copy(data, S_old); + // // data = S_old + // IntegratorOps::Copy(data, S_old); - // data += (chi - 3/2 * chi^2 + 2/3 * chi^3) * k1 - c = timestep_fraction - 1.5 * std::pow(timestep_fraction, 2) + 2./3. * std::pow(timestep_fraction, 3); - IntegratorOps::Saxpy(data, c*BaseT::timestep, *F_nodes[0]); + // // data += (chi - 3/2 * chi^2 + 2/3 * chi^3) * k1 + // c = timestep_fraction - 1.5 * std::pow(timestep_fraction, 2) + 2./3. * std::pow(timestep_fraction, 3); + // IntegratorOps::Saxpy(data, c*BaseT::timestep, *F_nodes[0]); - // data += (chi^2 - 2/3 * chi^3) * k2 - c = std::pow(timestep_fraction, 2) - 2./3. * std::pow(timestep_fraction, 3); - IntegratorOps::Saxpy(data, c*BaseT::timestep, *F_nodes[1]); + // // data += (chi^2 - 2/3 * chi^3) * k2 + // c = std::pow(timestep_fraction, 2) - 2./3. * std::pow(timestep_fraction, 3); + // IntegratorOps::Saxpy(data, c*BaseT::timestep, *F_nodes[1]); - // data += (chi^2 - 2/3 * chi^3) * k3 - c = std::pow(timestep_fraction, 2) - 2./3. * std::pow(timestep_fraction, 3); - IntegratorOps::Saxpy(data, c*BaseT::timestep, *F_nodes[2]); + // // data += (chi^2 - 2/3 * chi^3) * k3 + // c = std::pow(timestep_fraction, 2) - 2./3. * std::pow(timestep_fraction, 3); + // IntegratorOps::Saxpy(data, c*BaseT::timestep, *F_nodes[2]); - // data += (-1/2 * chi^2 + 2/3 * chi^3) * k4 - c = -0.5 * std::pow(timestep_fraction, 2) + 2./3. * std::pow(timestep_fraction, 3); - IntegratorOps::Saxpy(data, c*BaseT::timestep, *F_nodes[3]); + // // data += (-1/2 * chi^2 + 2/3 * chi^3) * k4 + // c = -0.5 * std::pow(timestep_fraction, 2) + 2./3. * std::pow(timestep_fraction, 3); + // IntegratorOps::Saxpy(data, c*BaseT::timestep, *F_nodes[3]); } diff --git a/Src/Base/AMReX_TimeIntegrator.H b/Src/Base/AMReX_TimeIntegrator.H index a5e92a14efe..92ef53d8883 100644 --- a/Src/Base/AMReX_TimeIntegrator.H +++ b/Src/Base/AMReX_TimeIntegrator.H @@ -70,16 +70,19 @@ private: // By default, do nothing post-timestep set_post_timestep([](){}); - // By default, do nothing after updating the state + // By default, do nothing after before calling the RHS // In general, this is where BCs should be filled set_pre_rhs_action([](T& /* S_data */, amrex::Real /* S_time */){}); - // By default, do nothing + // By default, do nothing in the RHS set_rhs([](T& /* S_rhs */, const T& /* S_data */, const amrex::Real /* time */){}); set_imex_rhs([](T& /* S_rhs */, const T& /* S_data */, const amrex::Real /* time */){}, [](T& /* S_rhs */, const T& /* S_data */, const amrex::Real /* time */){}); set_fast_rhs([](T& /* S_rhs */, const T& /* S_data */, const amrex::Real /* time */){}); + // By default, do nothing after a step + set_post_step_action([](T& /* S_rhs */, const amrex::Real /* time */){}); + // By default, initialize time, timestep, step number to 0's m_time = 0.0_rt; m_timestep = 0.0_rt; @@ -140,11 +143,6 @@ public: post_timestep = F; } - void set_pre_rhs_action (std::function F) - { - integrator_ptr->set_pre_rhs_action(F); - } - void set_rhs (std::function F) { integrator_ptr->set_rhs(F); @@ -161,6 +159,16 @@ public: integrator_ptr->set_fast_rhs(F); } + void set_pre_rhs_action (std::function F) + { + integrator_ptr->set_pre_rhs_action(F); + } + + void set_post_step_action (std::function F) + { + integrator_ptr->set_post_step_action(F); + } + int get_step_number () { return m_step_number; diff --git a/Src/Extern/SUNDIALS/AMReX_SundialsIntegrator.H b/Src/Extern/SUNDIALS/AMReX_SundialsIntegrator.H index f0cd521a285..e808e43bdd8 100644 --- a/Src/Extern/SUNDIALS/AMReX_SundialsIntegrator.H +++ b/Src/Extern/SUNDIALS/AMReX_SundialsIntegrator.H @@ -21,37 +21,45 @@ namespace amrex { struct SundialsUserData { // ERK or DIRK right-hand side function // ExMRI or ImMRI slow right-hand side function - std::function f; + std::function f; // ImEx-RK right-hand side functions // ImEx-MRI slow right-hand side functions - std::function fi; - std::function fe; + std::function fi; + std::function fe; // MRI fast time scale right-hand side function - std::function ff; + std::function ff; + + // Post step actions + std::function post_step; }; namespace SundialsUserFun { - static int f (sunrealtype t, N_Vector y_data, N_Vector y_rhs, void *user_data) { + static int f (amrex::Real t, N_Vector y_data, N_Vector y_rhs, void *user_data) { SundialsUserData* udata = static_cast(user_data); return udata->f(t, y_data, y_rhs, user_data); } - static int fi (sunrealtype t, N_Vector y_data, N_Vector y_rhs, void *user_data) { + static int fi (amrex::Real t, N_Vector y_data, N_Vector y_rhs, void *user_data) { SundialsUserData* udata = static_cast(user_data); return udata->fi(t, y_data, y_rhs, user_data); } - static int fe (sunrealtype t, N_Vector y_data, N_Vector y_rhs, void *user_data) { + static int fe (amrex::Real t, N_Vector y_data, N_Vector y_rhs, void *user_data) { SundialsUserData* udata = static_cast(user_data); return udata->fe(t, y_data, y_rhs, user_data); } - static int ff (sunrealtype t, N_Vector y_data, N_Vector y_rhs, void *user_data) { + static int ff (amrex::Real t, N_Vector y_data, N_Vector y_rhs, void *user_data) { SundialsUserData* udata = static_cast(user_data); return udata->ff(t, y_data, y_rhs, user_data); } + + static int post_step (amrex::Real t, N_Vector y_data, void *user_data) { + SundialsUserData* udata = static_cast(user_data); + return udata->post_step(t, y_data, user_data); + } } template @@ -145,12 +153,10 @@ public: #endif // Utility to unpack SUNDIALS vectors - auto unpack_vectors = [&](N_Vector y_data, amrex::Vector& S_data, - N_Vector y_rhs, amrex::Vector& S_rhs) -> void { + auto unpack_vector = [&](N_Vector y_data, amrex::Vector& S_data) -> void { const int num_vecs = N_VGetNumSubvectors_ManyVector(y_data); S_data.resize(num_vecs); - S_rhs.resize(num_vecs); for(int i = 0; i < num_vecs; i++) { @@ -158,21 +164,17 @@ public: amrex::make_alias, 0, amrex::sundials::getMFptr(N_VGetSubvector_ManyVector(y_data, i))->nComp()); - - S_rhs.at(i)=amrex::MultiFab(*amrex::sundials::getMFptr(N_VGetSubvector_ManyVector(y_rhs, i)), - amrex::make_alias, - 0, - amrex::sundials::getMFptr(N_VGetSubvector_ManyVector(y_rhs, i))->nComp()); } }; // Right-hand side function wrappers - udata.f = [&](sunrealtype rhs_time, N_Vector y_data, N_Vector y_rhs, void * /* user_data */) -> int { + udata.f = [&](amrex::Real rhs_time, N_Vector y_data, N_Vector y_rhs, void * /* user_data */) -> int { amrex::Vector S_data; - amrex::Vector S_rhs; + unpack_vector(y_data, S_data); - unpack_vectors(y_data, S_data, y_rhs, S_rhs); + amrex::Vector S_rhs; + unpack_vector(y_rhs, S_rhs); BaseT::pre_rhs_action(S_data, rhs_time); BaseT::Rhs(S_rhs, S_data, rhs_time); @@ -180,12 +182,13 @@ public: return 0; }; - udata.fi = [&](sunrealtype rhs_time, N_Vector y_data, N_Vector y_rhs, void * /* user_data */) -> int { + udata.fi = [&](amrex::Real rhs_time, N_Vector y_data, N_Vector y_rhs, void * /* user_data */) -> int { amrex::Vector S_data; - amrex::Vector S_rhs; + unpack_vector(y_data, S_data); - unpack_vectors(y_data, S_data, y_rhs, S_rhs); + amrex::Vector S_rhs; + unpack_vector(y_rhs, S_rhs); BaseT::pre_rhs_action(S_data, rhs_time); BaseT::RhsIm(S_rhs, S_data, rhs_time); @@ -193,12 +196,13 @@ public: return 0; }; - udata.fe = [&](sunrealtype rhs_time, N_Vector y_data, N_Vector y_rhs, void * /* user_data */) -> int { + udata.fe = [&](amrex::Real rhs_time, N_Vector y_data, N_Vector y_rhs, void * /* user_data */) -> int { amrex::Vector S_data; - amrex::Vector S_rhs; + unpack_vector(y_data, S_data); - unpack_vectors(y_data, S_data, y_rhs, S_rhs); + amrex::Vector S_rhs; + unpack_vector(y_rhs, S_rhs); BaseT::pre_rhs_action(S_data, rhs_time); BaseT::RhsEx(S_rhs, S_data, rhs_time); @@ -206,12 +210,13 @@ public: return 0; }; - udata.ff = [&](sunrealtype rhs_time, N_Vector y_data, N_Vector y_rhs, void * /* user_data */) -> int { + udata.ff = [&](amrex::Real rhs_time, N_Vector y_data, N_Vector y_rhs, void * /* user_data */) -> int { amrex::Vector S_data; - amrex::Vector S_rhs; + unpack_vector(y_data, S_data); - unpack_vectors(y_data, S_data, y_rhs, S_rhs); + amrex::Vector S_rhs; + unpack_vector(y_rhs, S_rhs); BaseT::pre_rhs_action(S_data, rhs_time); BaseT::RhsFast(S_rhs, S_data, rhs_time); @@ -219,6 +224,16 @@ public: return 0; }; + udata.post_step = [&](amrex::Real time, N_Vector y_data, void * /* user_data */) -> int { + + amrex::Vector S_data; + unpack_vector(y_data, S_data); + + BaseT::post_step_action(S_data, time); + + return 0; + }; + // Create an N_Vector wrapper for the solution MultiFab auto get_length = [&](int index) -> sunindextype { auto* p_mf = &S_data[index]; From 5713a0fdb7ad9210aeabba35b05d401ec4f1929e Mon Sep 17 00:00:00 2001 From: "David J. Gardner" Date: Mon, 3 Jun 2024 22:47:18 -0700 Subject: [PATCH 12/36] attach post_step functions --- Src/Extern/SUNDIALS/AMReX_SundialsIntegrator.H | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/Src/Extern/SUNDIALS/AMReX_SundialsIntegrator.H b/Src/Extern/SUNDIALS/AMReX_SundialsIntegrator.H index e808e43bdd8..23440ed51df 100644 --- a/Src/Extern/SUNDIALS/AMReX_SundialsIntegrator.H +++ b/Src/Extern/SUNDIALS/AMReX_SundialsIntegrator.H @@ -312,6 +312,9 @@ public: LS = SUNLinSol_SPGMR(nv_S, SUN_PREC_NONE, 0, sunctx); ARKStepSetLinearSolver(arkode_mem, LS, nullptr); } + + // Set post step function + ARKStepSetPostprocessStepFn(arkode_mem, SundialsUserFun::post_step); } void SetupMRI(amrex::Real time, N_Vector y_data) @@ -383,6 +386,9 @@ public: LS = SUNLinSol_SPGMR(nv_S, SUN_PREC_NONE, 0, sunctx); MRIStepSetLinearSolver(arkode_mem, LS, nullptr); } + + // Set post step function (only on slow integrator) + MRIStepSetPostprocessStepFn(arkode_mem, SundialsUserFun::post_step); } virtual ~SundialsIntegrator () { From f4bf7622772a22838b73224d5b2d4df6dd4c002d Mon Sep 17 00:00:00 2001 From: "David J. Gardner" Date: Mon, 3 Jun 2024 22:50:49 -0700 Subject: [PATCH 13/36] make init utilities private --- .../SUNDIALS/AMReX_SundialsIntegrator.H | 246 +++++++++--------- 1 file changed, 123 insertions(+), 123 deletions(-) diff --git a/Src/Extern/SUNDIALS/AMReX_SundialsIntegrator.H b/Src/Extern/SUNDIALS/AMReX_SundialsIntegrator.H index 23440ed51df..e2e26afccbb 100644 --- a/Src/Extern/SUNDIALS/AMReX_SundialsIntegrator.H +++ b/Src/Extern/SUNDIALS/AMReX_SundialsIntegrator.H @@ -130,6 +130,129 @@ private: } } + void SetupRK(amrex::Real time, N_Vector y_data) + { + // Create integrator and select method + if (type == "ERK") { + amrex::Print() << "SUNDIALS ERK time integrator\n"; + arkode_mem = ARKStepCreate(SundialsUserFun::f, nullptr, time, nv_S, sunctx); + + if (method != "DEFAULT") { + amrex::Print() << "SUNDIALS ERK method " << method << "\n"; + ARKStepSetTableName(arkode_mem, "ARKODE_DIRK_NONE", method.c_str()); + } + } + else if (type == "DIRK") { + amrex::Print() << "SUNDIALS DIRK time integrator\n"; + arkode_mem = ARKStepCreate(nullptr, SundialsUserFun::f, time, nv_S, sunctx); + + if (method != "DEFAULT") { + amrex::Print() << "SUNDIALS DIRK method " << method << "\n"; + ARKStepSetTableName(arkode_mem, method.c_str(), "ARKODE_ERK_NONE"); + } + } + else if (type == "IMEX-RK") { + amrex::Print() << "SUNDIALS IMEX time integrator\n"; + arkode_mem = ARKStepCreate(SundialsUserFun::fe, SundialsUserFun::fi, time, nv_S, sunctx); + + if (method_e != "DEFAULT" && method_i != "DEFAULT") + { + amrex::Print() << "SUNDIALS IMEX method " << method_i << " and " + << method_e << "\n"; + ARKStepSetTableName(arkode_mem, method_i.c_str(), method_e.c_str()); + } + } + + // Attach structure with user-supplied function wrappers + ARKStepSetUserData(arkode_mem, &udata); + + // Set integrator tolerances + ARKStepSStolerances(arkode_mem, reltol, abstol); + + // Create and attach linear solver for implicit methods + if (type == "DIRK" || type == "IMEX-RK") { + LS = SUNLinSol_SPGMR(nv_S, SUN_PREC_NONE, 0, sunctx); + ARKStepSetLinearSolver(arkode_mem, LS, nullptr); + } + + // Set post step function + ARKStepSetPostprocessStepFn(arkode_mem, SundialsUserFun::post_step); + } + + void SetupMRI(amrex::Real time, N_Vector y_data) + { + // Create the fast integrator and select method + if (fast_type == "ERK") { + amrex::Print() << "SUNDIALS ERK time integrator\n"; + arkode_fast_mem = ARKStepCreate(SundialsUserFun::ff, nullptr, time, nv_S, sunctx); + + if (method != "DEFAULT") { + amrex::Print() << "SUNDIALS ERK method " << method << "\n"; + ARKStepSetTableName(arkode_fast_mem, "ARKODE_DIRK_NONE", fast_method.c_str()); + } + } + else if (fast_type == "DIRK") { + amrex::Print() << "SUNDIALS DIRK time integrator\n"; + arkode_fast_mem = ARKStepCreate(nullptr, SundialsUserFun::ff, time, nv_S, sunctx); + + if (method != "DEFAULT") { + amrex::Print() << "SUNDIALS DIRK method " << method << "\n"; + ARKStepSetTableName(arkode_fast_mem, fast_method.c_str(), "ARKODE_ERK_NONE"); + } + + fast_LS = SUNLinSol_SPGMR(nv_S, SUN_PREC_NONE, 0, sunctx); + ARKStepSetLinearSolver(arkode_fast_mem, fast_LS, nullptr); + } + + // Attach structure with user-supplied function wrappers + ARKStepSetUserData(arkode_fast_mem, &udata); + + // Set integrator tolerances + ARKStepSStolerances(arkode_fast_mem, fast_reltol, fast_abstol); + + // Wrap fast integrator as an inner stepper + ARKStepCreateMRIStepInnerStepper(arkode_fast_mem, &fast_stepper); + + // Create slow integrator + if (type == "EX-MRI") { + amrex::Print() << "SUNDIALS ERK time integrator\n"; + arkode_mem = MRIStepCreate(SundialsUserFun::f, nullptr, time, nv_S, + fast_stepper, sunctx); + } + else if (type == "IM-MRI") { + amrex::Print() << "SUNDIALS DIRK time integrator\n"; + arkode_mem = MRIStepCreate(nullptr, SundialsUserFun::f, time, nv_S, + fast_stepper, sunctx); + } + else if (type == "IMEX-MRI") { + amrex::Print() << "SUNDIALS IMEX time integrator\n"; + arkode_mem = MRIStepCreate(SundialsUserFun::fe, SundialsUserFun::fi, + time, nv_S, fast_stepper, sunctx); + } + + // Set method + if (method != "DEFAULT") { + MRIStepCoupling MRIC = MRIStepCoupling_LoadTableByName(method.c_str()); + MRIStepSetCoupling(arkode_mem, MRIC); + MRIStepCoupling_Free(MRIC); + } + + // Attach structure with user-supplied function wrappers + MRIStepSetUserData(arkode_mem, &udata); + + // Set integrator tolerances + MRIStepSStolerances(arkode_mem, reltol, abstol); + + // Create and attach linear solver + if (type == "IM-MRI" || type == "IMEX-MRI") { + LS = SUNLinSol_SPGMR(nv_S, SUN_PREC_NONE, 0, sunctx); + MRIStepSetLinearSolver(arkode_mem, LS, nullptr); + } + + // Set post step function (only on slow integrator) + MRIStepSetPostprocessStepFn(arkode_mem, SundialsUserFun::post_step); + } + public: SundialsIntegrator () {} @@ -268,129 +391,6 @@ public: } } - void SetupRK(amrex::Real time, N_Vector y_data) - { - // Create integrator and select method - if (type == "ERK") { - amrex::Print() << "SUNDIALS ERK time integrator\n"; - arkode_mem = ARKStepCreate(SundialsUserFun::f, nullptr, time, nv_S, sunctx); - - if (method != "DEFAULT") { - amrex::Print() << "SUNDIALS ERK method " << method << "\n"; - ARKStepSetTableName(arkode_mem, "ARKODE_DIRK_NONE", method.c_str()); - } - } - else if (type == "DIRK") { - amrex::Print() << "SUNDIALS DIRK time integrator\n"; - arkode_mem = ARKStepCreate(nullptr, SundialsUserFun::f, time, nv_S, sunctx); - - if (method != "DEFAULT") { - amrex::Print() << "SUNDIALS DIRK method " << method << "\n"; - ARKStepSetTableName(arkode_mem, method.c_str(), "ARKODE_ERK_NONE"); - } - } - else if (type == "IMEX-RK") { - amrex::Print() << "SUNDIALS IMEX time integrator\n"; - arkode_mem = ARKStepCreate(SundialsUserFun::fe, SundialsUserFun::fi, time, nv_S, sunctx); - - if (method_e != "DEFAULT" && method_i != "DEFAULT") - { - amrex::Print() << "SUNDIALS IMEX method " << method_i << " and " - << method_e << "\n"; - ARKStepSetTableName(arkode_mem, method_i.c_str(), method_e.c_str()); - } - } - - // Attach structure with user-supplied function wrappers - ARKStepSetUserData(arkode_mem, &udata); - - // Set integrator tolerances - ARKStepSStolerances(arkode_mem, reltol, abstol); - - // Create and attach linear solver for implicit methods - if (type == "DIRK" || type == "IMEX-RK") { - LS = SUNLinSol_SPGMR(nv_S, SUN_PREC_NONE, 0, sunctx); - ARKStepSetLinearSolver(arkode_mem, LS, nullptr); - } - - // Set post step function - ARKStepSetPostprocessStepFn(arkode_mem, SundialsUserFun::post_step); - } - - void SetupMRI(amrex::Real time, N_Vector y_data) - { - // Create the fast integrator and select method - if (fast_type == "ERK") { - amrex::Print() << "SUNDIALS ERK time integrator\n"; - arkode_fast_mem = ARKStepCreate(SundialsUserFun::ff, nullptr, time, nv_S, sunctx); - - if (method != "DEFAULT") { - amrex::Print() << "SUNDIALS ERK method " << method << "\n"; - ARKStepSetTableName(arkode_fast_mem, "ARKODE_DIRK_NONE", fast_method.c_str()); - } - } - else if (fast_type == "DIRK") { - amrex::Print() << "SUNDIALS DIRK time integrator\n"; - arkode_fast_mem = ARKStepCreate(nullptr, SundialsUserFun::ff, time, nv_S, sunctx); - - if (method != "DEFAULT") { - amrex::Print() << "SUNDIALS DIRK method " << method << "\n"; - ARKStepSetTableName(arkode_fast_mem, fast_method.c_str(), "ARKODE_ERK_NONE"); - } - - fast_LS = SUNLinSol_SPGMR(nv_S, SUN_PREC_NONE, 0, sunctx); - ARKStepSetLinearSolver(arkode_fast_mem, fast_LS, nullptr); - } - - // Attach structure with user-supplied function wrappers - ARKStepSetUserData(arkode_fast_mem, &udata); - - // Set integrator tolerances - ARKStepSStolerances(arkode_fast_mem, fast_reltol, fast_abstol); - - // Wrap fast integrator as an inner stepper - ARKStepCreateMRIStepInnerStepper(arkode_fast_mem, &fast_stepper); - - // Create slow integrator - if (type == "EX-MRI") { - amrex::Print() << "SUNDIALS ERK time integrator\n"; - arkode_mem = MRIStepCreate(SundialsUserFun::f, nullptr, time, nv_S, - fast_stepper, sunctx); - } - else if (type == "IM-MRI") { - amrex::Print() << "SUNDIALS DIRK time integrator\n"; - arkode_mem = MRIStepCreate(nullptr, SundialsUserFun::f, time, nv_S, - fast_stepper, sunctx); - } - else if (type == "IMEX-MRI") { - amrex::Print() << "SUNDIALS IMEX time integrator\n"; - arkode_mem = MRIStepCreate(SundialsUserFun::fe, SundialsUserFun::fi, - time, nv_S, fast_stepper, sunctx); - } - - // Set method - if (method != "DEFAULT") { - MRIStepCoupling MRIC = MRIStepCoupling_LoadTableByName(method.c_str()); - MRIStepSetCoupling(arkode_mem, MRIC); - MRIStepCoupling_Free(MRIC); - } - - // Attach structure with user-supplied function wrappers - MRIStepSetUserData(arkode_mem, &udata); - - // Set integrator tolerances - MRIStepSStolerances(arkode_mem, reltol, abstol); - - // Create and attach linear solver - if (type == "IM-MRI" || type == "IMEX-MRI") { - LS = SUNLinSol_SPGMR(nv_S, SUN_PREC_NONE, 0, sunctx); - MRIStepSetLinearSolver(arkode_mem, LS, nullptr); - } - - // Set post step function (only on slow integrator) - MRIStepSetPostprocessStepFn(arkode_mem, SundialsUserFun::post_step); - } - virtual ~SundialsIntegrator () { // Clean up allocated memory for (int i = 0; i < NVar; ++i) { From 9185992e8e15923de67dc0269a916994f7775910 Mon Sep 17 00:00:00 2001 From: "David J. Gardner" Date: Tue, 4 Jun 2024 00:19:15 -0700 Subject: [PATCH 14/36] add utilities to unpack/wrap nvectors/multifabs --- .../SUNDIALS/AMReX_SundialsIntegrator.H | 250 +++++++++++------- 1 file changed, 153 insertions(+), 97 deletions(-) diff --git a/Src/Extern/SUNDIALS/AMReX_SundialsIntegrator.H b/Src/Extern/SUNDIALS/AMReX_SundialsIntegrator.H index e2e26afccbb..113643d8cd6 100644 --- a/Src/Extern/SUNDIALS/AMReX_SundialsIntegrator.H +++ b/Src/Extern/SUNDIALS/AMReX_SundialsIntegrator.H @@ -85,7 +85,7 @@ private: SundialsUserData udata; // SUNDIALS objects - SUNContext sunctx = nullptr; + SUNContext sunctx = nullptr; // use context created by amrex:sundials::Initialize void *arkode_mem = nullptr; SUNLinearSolver LS = nullptr; @@ -101,10 +101,6 @@ private: Real fast_reltol = 1.0e-4; Real fast_abstol = 1.0e-9; - int NVar; // NOTE: expects S_data to be a Vector - N_Vector* nv_many_arr; /* vector array composed of cons, xmom, ymom, zmom component vectors */ - N_Vector nv_S; - void initialize_parameters () { amrex::ParmParse pp("integration.sundials"); @@ -135,7 +131,7 @@ private: // Create integrator and select method if (type == "ERK") { amrex::Print() << "SUNDIALS ERK time integrator\n"; - arkode_mem = ARKStepCreate(SundialsUserFun::f, nullptr, time, nv_S, sunctx); + arkode_mem = ARKStepCreate(SundialsUserFun::f, nullptr, time, y_data, sunctx); if (method != "DEFAULT") { amrex::Print() << "SUNDIALS ERK method " << method << "\n"; @@ -144,7 +140,7 @@ private: } else if (type == "DIRK") { amrex::Print() << "SUNDIALS DIRK time integrator\n"; - arkode_mem = ARKStepCreate(nullptr, SundialsUserFun::f, time, nv_S, sunctx); + arkode_mem = ARKStepCreate(nullptr, SundialsUserFun::f, time, y_data, sunctx); if (method != "DEFAULT") { amrex::Print() << "SUNDIALS DIRK method " << method << "\n"; @@ -153,7 +149,7 @@ private: } else if (type == "IMEX-RK") { amrex::Print() << "SUNDIALS IMEX time integrator\n"; - arkode_mem = ARKStepCreate(SundialsUserFun::fe, SundialsUserFun::fi, time, nv_S, sunctx); + arkode_mem = ARKStepCreate(SundialsUserFun::fe, SundialsUserFun::fi, time, y_data, sunctx); if (method_e != "DEFAULT" && method_i != "DEFAULT") { @@ -171,7 +167,7 @@ private: // Create and attach linear solver for implicit methods if (type == "DIRK" || type == "IMEX-RK") { - LS = SUNLinSol_SPGMR(nv_S, SUN_PREC_NONE, 0, sunctx); + LS = SUNLinSol_SPGMR(y_data, SUN_PREC_NONE, 0, sunctx); ARKStepSetLinearSolver(arkode_mem, LS, nullptr); } @@ -184,7 +180,7 @@ private: // Create the fast integrator and select method if (fast_type == "ERK") { amrex::Print() << "SUNDIALS ERK time integrator\n"; - arkode_fast_mem = ARKStepCreate(SundialsUserFun::ff, nullptr, time, nv_S, sunctx); + arkode_fast_mem = ARKStepCreate(SundialsUserFun::ff, nullptr, time, y_data, sunctx); if (method != "DEFAULT") { amrex::Print() << "SUNDIALS ERK method " << method << "\n"; @@ -193,14 +189,14 @@ private: } else if (fast_type == "DIRK") { amrex::Print() << "SUNDIALS DIRK time integrator\n"; - arkode_fast_mem = ARKStepCreate(nullptr, SundialsUserFun::ff, time, nv_S, sunctx); + arkode_fast_mem = ARKStepCreate(nullptr, SundialsUserFun::ff, time, y_data, sunctx); if (method != "DEFAULT") { amrex::Print() << "SUNDIALS DIRK method " << method << "\n"; ARKStepSetTableName(arkode_fast_mem, fast_method.c_str(), "ARKODE_ERK_NONE"); } - fast_LS = SUNLinSol_SPGMR(nv_S, SUN_PREC_NONE, 0, sunctx); + fast_LS = SUNLinSol_SPGMR(y_data, SUN_PREC_NONE, 0, sunctx); ARKStepSetLinearSolver(arkode_fast_mem, fast_LS, nullptr); } @@ -216,18 +212,18 @@ private: // Create slow integrator if (type == "EX-MRI") { amrex::Print() << "SUNDIALS ERK time integrator\n"; - arkode_mem = MRIStepCreate(SundialsUserFun::f, nullptr, time, nv_S, + arkode_mem = MRIStepCreate(SundialsUserFun::f, nullptr, time, y_data, fast_stepper, sunctx); } else if (type == "IM-MRI") { amrex::Print() << "SUNDIALS DIRK time integrator\n"; - arkode_mem = MRIStepCreate(nullptr, SundialsUserFun::f, time, nv_S, + arkode_mem = MRIStepCreate(nullptr, SundialsUserFun::f, time, y_data, fast_stepper, sunctx); } else if (type == "IMEX-MRI") { amrex::Print() << "SUNDIALS IMEX time integrator\n"; arkode_mem = MRIStepCreate(SundialsUserFun::fe, SundialsUserFun::fi, - time, nv_S, fast_stepper, sunctx); + time, y_data, fast_stepper, sunctx); } // Set method @@ -245,7 +241,7 @@ private: // Create and attach linear solver if (type == "IM-MRI" || type == "IMEX-MRI") { - LS = SUNLinSol_SPGMR(nv_S, SUN_PREC_NONE, 0, sunctx); + LS = SUNLinSol_SPGMR(y_data, SUN_PREC_NONE, 0, sunctx); MRIStepSetLinearSolver(arkode_mem, LS, nullptr); } @@ -253,6 +249,120 @@ private: MRIStepSetPostprocessStepFn(arkode_mem, SundialsUserFun::post_step); } + // ------------------------------------- + // Vector / N_Vector Utilities + // ------------------------------------- + + // Utility to unpack a SUNDIALS ManyVector into a vector of MultiFabs + void unpack_vector(N_Vector y_data, amrex::Vector& S_data) + { + const int num_vecs = N_VGetNumSubvectors_ManyVector(y_data); + S_data.resize(num_vecs); + + for(int i = 0; i < num_vecs; i++) + { + S_data.at(i) = amrex::MultiFab(*amrex::sundials::getMFptr(N_VGetSubvector_ManyVector(y_data, i)), + amrex::make_alias, + 0, + amrex::sundials::getMFptr(N_VGetSubvector_ManyVector(y_data, i))->nComp()); + } + }; + + // Utility to wrap vector of MultiFabs as a SUNDIALS ManyVector + N_Vector wrap_data(amrex::Vector& S_data) + { + auto get_length = [&](int index) -> sunindextype { + auto* p_mf = &S_data[index]; + return p_mf->nComp() * (p_mf->boxArray()).numPts(); + }; + + sunindextype NV_len = S_data.size(); + N_Vector* NV_array = new N_Vector[NV_len]; + + for (int i = 0; i < NV_len; ++i) { + NV_array[i] = amrex::sundials::N_VMake_MultiFab(get_length(i), + &S_data[i]); // correct context + } + + N_Vector y_data = N_VNew_ManyVector(NV_len, NV_array, sunctx); + + delete[] NV_array; + + return y_data; + }; + + // Utility to wrap vector of MultiFabs as a SUNDIALS ManyVector + N_Vector copy_data(const amrex::Vector& S_data) + { + auto get_length = [&](int index) -> sunindextype { + auto* p_mf = &S_data[index]; + return p_mf->nComp() * (p_mf->boxArray()).numPts(); + }; + + sunindextype NV_len = S_data.size(); + N_Vector* NV_array = new N_Vector[NV_len]; + + for (int i = 0; i < NV_len; ++i) { + NV_array[i] = amrex::sundials::N_VNew_MultiFab(get_length(i), + S_data[i].boxArray(), + S_data[i].DistributionMap(), + S_data[i].nComp(), + S_data[i].nGrow()); // correct context + + MultiFab::Copy(*amrex::sundials::getMFptr(NV_array[i]), + S_data[i], + 0, + 0, + S_data[i].nComp(), + S_data[i].nGrow()); + } + + N_Vector y_data = N_VNew_ManyVector(NV_len, NV_array, sunctx); + + delete[] NV_array; + + return y_data; + }; + + // ----------------------------- + // MultiFab / N_Vector Utilities + // ----------------------------- + + // Utility to unpack a SUNDIALS Vector into a MultiFab + void unpack_vector(N_Vector y_data, amrex::MultiFab& S_data) + { + S_data = amrex::MultiFab(*amrex::sundials::getMFptr(y_data), + amrex::make_alias, + 0, + amrex::sundials::getMFptr(y_data)->nComp()); + }; + + // Utility to wrap a MultiFab as a SUNDIALS Vector + N_Vector wrap_data(amrex::MultiFab& S_data) + { + return amrex::sundials::N_VMake_MultiFab(S_data.nComp() * S_data.boxArray().numPts(), + &S_data); // correct context + }; + + // Utility to wrap a MultiFab as a SUNDIALS Vector + N_Vector copy_data(const amrex::MultiFab& S_data) + { + N_Vector y_data = amrex::sundials::N_VNew_MultiFab(S_data.nComp() * S_data.boxArray().numPts(), + S_data.boxArray(), + S_data.DistributionMap(), + S_data.nComp(), + S_data.nGrow()); // correct context + + MultiFab::Copy(*amrex::sundials::getMFptr(y_data), + S_data, + 0, + 0, + S_data.nComp(), + S_data.nGrow()); + + return y_data; + }; + public: SundialsIntegrator () {} @@ -275,28 +385,13 @@ public: # endif #endif - // Utility to unpack SUNDIALS vectors - auto unpack_vector = [&](N_Vector y_data, amrex::Vector& S_data) -> void { - - const int num_vecs = N_VGetNumSubvectors_ManyVector(y_data); - S_data.resize(num_vecs); - - for(int i = 0; i < num_vecs; i++) - { - S_data.at(i)=amrex::MultiFab(*amrex::sundials::getMFptr(N_VGetSubvector_ManyVector(y_data, i)), - amrex::make_alias, - 0, - amrex::sundials::getMFptr(N_VGetSubvector_ManyVector(y_data, i))->nComp()); - } - }; - // Right-hand side function wrappers udata.f = [&](amrex::Real rhs_time, N_Vector y_data, N_Vector y_rhs, void * /* user_data */) -> int { - amrex::Vector S_data; + T S_data; unpack_vector(y_data, S_data); - amrex::Vector S_rhs; + T S_rhs; unpack_vector(y_rhs, S_rhs); BaseT::pre_rhs_action(S_data, rhs_time); @@ -307,10 +402,10 @@ public: udata.fi = [&](amrex::Real rhs_time, N_Vector y_data, N_Vector y_rhs, void * /* user_data */) -> int { - amrex::Vector S_data; + T S_data; unpack_vector(y_data, S_data); - amrex::Vector S_rhs; + T S_rhs; unpack_vector(y_rhs, S_rhs); BaseT::pre_rhs_action(S_data, rhs_time); @@ -321,10 +416,10 @@ public: udata.fe = [&](amrex::Real rhs_time, N_Vector y_data, N_Vector y_rhs, void * /* user_data */) -> int { - amrex::Vector S_data; + T S_data; unpack_vector(y_data, S_data); - amrex::Vector S_rhs; + T S_rhs; unpack_vector(y_rhs, S_rhs); BaseT::pre_rhs_action(S_data, rhs_time); @@ -335,10 +430,10 @@ public: udata.ff = [&](amrex::Real rhs_time, N_Vector y_data, N_Vector y_rhs, void * /* user_data */) -> int { - amrex::Vector S_data; + T S_data; unpack_vector(y_data, S_data); - amrex::Vector S_rhs; + T S_rhs; unpack_vector(y_rhs, S_rhs); BaseT::pre_rhs_action(S_data, rhs_time); @@ -349,7 +444,7 @@ public: udata.post_step = [&](amrex::Real time, N_Vector y_data, void * /* user_data */) -> int { - amrex::Vector S_data; + T S_data; unpack_vector(y_data, S_data); BaseT::post_step_action(S_data, time); @@ -357,47 +452,21 @@ public: return 0; }; - // Create an N_Vector wrapper for the solution MultiFab - auto get_length = [&](int index) -> sunindextype { - auto* p_mf = &S_data[index]; - return p_mf->nComp() * (p_mf->boxArray()).numPts(); - }; - - NVar = S_data.size(); // NOTE: expects S_data to be a Vector - nv_many_arr = new N_Vector[NVar]; // vector array composed of cons, xmom, ymom, zmom component vectors */ - - for (int i = 0; i < NVar; ++i) { - nv_many_arr[i] = amrex::sundials::N_VNew_MultiFab(get_length(i), - S_data[i].boxArray(), - S_data[i].DistributionMap(), - S_data[i].nComp(), - S_data[i].nGrow()); - - MultiFab::Copy(*amrex::sundials::getMFptr(nv_many_arr[i]), - S_data[i], - 0, - 0, - S_data[i].nComp(), - S_data[i].nGrow()); - } - nv_S = N_VNew_ManyVector(NVar, nv_many_arr, sunctx); + N_Vector y_data = copy_data(S_data); // ideally just wrap and ignore const if (use_ark) { - SetupRK(time, nv_S); + SetupRK(time, y_data); } else if (use_mri) { - SetupMRI(time, nv_S); + SetupMRI(time, y_data); } + + N_VDestroy(y_data); } virtual ~SundialsIntegrator () { // Clean up allocated memory - for (int i = 0; i < NVar; ++i) { - N_VDestroy(nv_many_arr[i]); - } - delete[] nv_many_arr; - N_VDestroy(nv_S); SUNLinSolFree(LS); SUNLinSolFree(fast_LS); MRIStepInnerStepper_Free(&fast_stepper); @@ -411,32 +480,26 @@ public: amrex::Real tout = time + time_step; amrex::Real tret; - // Copy the S_old to nv_S - for(int i=0; i= 0); } else if (use_mri) { - MRIStepReset(arkode_mem, time, nv_S); + MRIStepReset(arkode_mem, time, y_old); MRIStepSetFixedStep(arkode_mem, time_step); - int flag = MRIStepEvolve(arkode_mem, tout, nv_S, &tret, ARK_ONE_STEP); + int flag = MRIStepEvolve(arkode_mem, tout, y_new, &tret, ARK_ONE_STEP); AMREX_ALWAYS_ASSERT(flag >= 0); } else { Error("SUNDIALS integrator type not specified."); } - // Copy the result stored in nv_S to state_new - for(int i=0; i= 0); } else if (use_mri) { - int flag = MRIStepEvolve(arkode_mem, t_out, nv_S, &tret, ARK_NORMAL); + int flag = MRIStepEvolve(arkode_mem, t_out, y_out, &tret, ARK_NORMAL); AMREX_ALWAYS_ASSERT(flag >= 0); } else { Error("SUNDIALS integrator type not specified."); } - // Should be able to make an alias to S_out and avoid the copy - for(int i=0; i Date: Tue, 4 Jun 2024 01:16:28 -0700 Subject: [PATCH 15/36] cleanup --- Src/Base/AMReX_FEIntegrator.H | 8 +-- Src/Base/AMReX_IntegratorBase.H | 37 +++++++++--- Src/Base/AMReX_RKIntegrator.H | 12 ++-- Src/Base/AMReX_TimeIntegrator.H | 59 ++++++------------- .../SUNDIALS/AMReX_SundialsIntegrator.H | 53 +++++++++-------- 5 files changed, 86 insertions(+), 83 deletions(-) diff --git a/Src/Base/AMReX_FEIntegrator.H b/Src/Base/AMReX_FEIntegrator.H index 7857268611c..0620f028591 100644 --- a/Src/Base/AMReX_FEIntegrator.H +++ b/Src/Base/AMReX_FEIntegrator.H @@ -35,7 +35,7 @@ public: initialize_stages(S_data); } - amrex::Real advance (T& S_old, T& S_new, amrex::Real time, const amrex::Real time_step) override + amrex::Real advance (T& S_old, T& S_new, amrex::Real time, const amrex::Real dt) override { // Assume before advance() that S_old is valid data at the current time ("time" argument) // So we initialize S_new by copying the old state. @@ -49,13 +49,13 @@ public: BaseT::Rhs(F, S_new, time); // S_new += timestep * dS/dt - IntegratorOps::Saxpy(S_new, time_step, F); + IntegratorOps::Saxpy(S_new, dt, F); // Call the post step hook - BaseT::post_step_action(S_new, time + time_step); + BaseT::post_step_action(S_new, time + dt); // Return timestep - return time_step; + return dt; } void evolve (T& S_out, const amrex::Real t_out) override diff --git a/Src/Base/AMReX_IntegratorBase.H b/Src/Base/AMReX_IntegratorBase.H index e6ad56b9492..9613f3c7b18 100644 --- a/Src/Base/AMReX_IntegratorBase.H +++ b/Src/Base/AMReX_IntegratorBase.H @@ -197,10 +197,21 @@ protected: */ std::function post_step_action; - // /** - // * \brief Integrator timestep size (Real) - // */ - // amrex::Real timestep; + /** + * \brief Current time reached by the integrator (Real) + */ + amrex::Real cur_time; + + /** + * \brief Current integrator time step size (Real) + */ + amrex::Real time_step; + + /** + * \brief Number of Integrator time steps (Long) + */ + amrex::Long num_steps; + public: IntegratorBase () = default; @@ -228,14 +239,24 @@ public: RhsFast = F; } - void set_pre_rhs_action (std::function F) + void set_pre_rhs_action (std::function A) + { + pre_rhs_action = A; + } + + void set_post_step_action (std::function A) + { + post_step_action = A; + } + + amrex::Real get_time_step() { - pre_rhs_action = F; + return time_step; } - void set_post_step_action (std::function F) + void set_time_step(amrex::Real dt) { - post_step_action = F; + time_step = dt; } virtual amrex::Real advance (T& S_old, T& S_new, amrex::Real time, amrex::Real dt) = 0; diff --git a/Src/Base/AMReX_RKIntegrator.H b/Src/Base/AMReX_RKIntegrator.H index b5622eef776..ab5f67a2c63 100644 --- a/Src/Base/AMReX_RKIntegrator.H +++ b/Src/Base/AMReX_RKIntegrator.H @@ -168,7 +168,7 @@ public: virtual ~RKIntegrator () {} - amrex::Real advance (T& S_old, T& S_new, amrex::Real time, const amrex::Real time_step) override + amrex::Real advance (T& S_old, T& S_new, amrex::Real time, const amrex::Real dt) override { // Assume before advance() that S_old is valid data at the current time ("time" argument) // And that if data is a MultiFab, both S_old and S_new contain ghost cells for evaluating a stencil based RHS @@ -179,7 +179,7 @@ public: for (int i = 0; i < number_nodes; ++i) { // Get current stage time, t = t_old + h * Ci - amrex::Real stage_time = time + time_step * nodes[i]; + amrex::Real stage_time = time + dt * nodes[i]; // Fill S_new with the solution value for evaluating F at the current stage // Copy S_new = S_old @@ -190,7 +190,7 @@ public: // We should fuse these kernels ... for (int j = 0; j < i; ++j) { - IntegratorOps::Saxpy(S_new, time_step * tableau[i][j], *F_nodes[j]); + IntegratorOps::Saxpy(S_new, dt * tableau[i][j], *F_nodes[j]); } } @@ -208,16 +208,16 @@ public: IntegratorOps::Copy(S_new, S_old); for (int i = 0; i < number_nodes; ++i) { - IntegratorOps::Saxpy(S_new, time_step * weights[i], *F_nodes[i]); + IntegratorOps::Saxpy(S_new, dt * weights[i], *F_nodes[i]); } - BaseT::post_step_action(S_new, time + time_step); + BaseT::post_step_action(S_new, time + dt); // If we are working with an extended Butcher tableau, we can estimate the error here, // and then calculate an adaptive timestep. // Return timestep - return time_step; + return dt; } void evolve (T& S_out, const amrex::Real t_out) override diff --git a/Src/Base/AMReX_TimeIntegrator.H b/Src/Base/AMReX_TimeIntegrator.H index 92ef53d8883..2bc626c6876 100644 --- a/Src/Base/AMReX_TimeIntegrator.H +++ b/Src/Base/AMReX_TimeIntegrator.H @@ -25,10 +25,7 @@ template class TimeIntegrator { private: - amrex::Real m_time, m_timestep; - int m_step_number; std::unique_ptr > integrator_ptr; - std::function post_timestep; IntegratorTypes read_parameters () { @@ -67,9 +64,6 @@ private: void set_default_functions () { - // By default, do nothing post-timestep - set_post_timestep([](){}); - // By default, do nothing after before calling the RHS // In general, this is where BCs should be filled set_pre_rhs_action([](T& /* S_data */, amrex::Real /* S_time */){}); @@ -82,11 +76,6 @@ private: // By default, do nothing after a step set_post_step_action([](T& /* S_rhs */, const amrex::Real /* time */){}); - - // By default, initialize time, timestep, step number to 0's - m_time = 0.0_rt; - m_timestep = 0.0_rt; - m_step_number = 0; } public: @@ -138,11 +127,6 @@ public: } } - void set_post_timestep (std::function F) - { - post_timestep = F; - } - void set_rhs (std::function F) { integrator_ptr->set_rhs(F); @@ -159,53 +143,46 @@ public: integrator_ptr->set_fast_rhs(F); } - void set_pre_rhs_action (std::function F) - { - integrator_ptr->set_pre_rhs_action(F); - } - - void set_post_step_action (std::function F) - { - integrator_ptr->set_post_step_action(F); - } - - int get_step_number () + void set_pre_rhs_action (std::function A) { - return m_step_number; + integrator_ptr->set_pre_rhs_action(A); } - amrex::Real get_time () + void set_post_step_action (std::function A) { - return m_time; + integrator_ptr->set_post_step_action(A); } - amrex::Real get_timestep () + amrex::Real get_time_step () { - return m_timestep; + return integrator_ptr->get_time_step(); } - void set_timestep (amrex::Real dt) + void set_time_step (amrex::Real dt) { - m_timestep = dt; + integrator_ptr->set_time_step(); } - void advance (T& S_old, T& S_new, amrex::Real time, const amrex::Real timestep) + // TODO: Change to step + void advance (T& S_old, T& S_new, amrex::Real time, const amrex::Real dt) { - integrator_ptr->advance(S_old, S_new, time, timestep); + integrator_ptr->advance(S_old, S_new, time, dt); } + // TODO: Add to FE and RK integrators void evolve (T& S_out, const amrex::Real t_out) { integrator_ptr->evolve(S_out, t_out); } + // TODO: Change to advance void integrate (T& S_old, T& S_new, amrex::Real start_time, const amrex::Real start_timestep, const amrex::Real end_time, const int start_step, const int max_steps) { - m_time = start_time; - m_timestep = start_timestep; + amrex::Real m_time = start_time; + amrex::Real m_timestep = start_timestep; bool stop_advance = false; - for (m_step_number = start_step; m_step_number < max_steps && !stop_advance; ++m_step_number) + for (int m_step_number = start_step; m_step_number < max_steps && !stop_advance; ++m_step_number) { if (end_time - m_time < m_timestep) { m_timestep = end_time - m_time; @@ -221,12 +198,10 @@ public: // Update our time variable m_time += m_timestep; - - // Call the post-timestep hook - post_timestep(); } } + // TODO: Update to use prev_time_step void time_interpolate (const T& S_new, const T& S_old, amrex::Real timestep_fraction, T& data) { integrator_ptr->time_interpolate(S_new, S_old, timestep_fraction, data); diff --git a/Src/Extern/SUNDIALS/AMReX_SundialsIntegrator.H b/Src/Extern/SUNDIALS/AMReX_SundialsIntegrator.H index 113643d8cd6..c82b56bebfd 100644 --- a/Src/Extern/SUNDIALS/AMReX_SundialsIntegrator.H +++ b/Src/Extern/SUNDIALS/AMReX_SundialsIntegrator.H @@ -68,12 +68,15 @@ class SundialsIntegrator : public IntegratorBase private: using BaseT = IntegratorBase; - // method type and name - std::string type = "ERK"; + // Method type: ERK, DIRK, IMEX-RK, EX-MRI, IM-MRI, IMEX-MRI + std::string type = "ERK"; + + // Use SUNDIALS default methods std::string method = "DEFAULT"; std::string method_e = "DEFAULT"; std::string method_i = "DEFAULT"; + // Fast method type (ERK or DIRK) and method std::string fast_type = "ERK"; std::string fast_method = "DEFAULT"; @@ -84,22 +87,21 @@ private: // structure for interfacing with user-supplied functions SundialsUserData udata; - // SUNDIALS objects - SUNContext sunctx = nullptr; // use context created by amrex:sundials::Initialize + // SUNDIALS context -- should use context created by amrex:sundials::Initialize + SUNContext sunctx = nullptr; + // Single rate or slow time scale void *arkode_mem = nullptr; SUNLinearSolver LS = nullptr; + Real reltol = 1.0e-4; + Real abstol = 1.0e-9; + // Fast time scale void *arkode_fast_mem = nullptr; MRIStepInnerStepper fast_stepper = nullptr; SUNLinearSolver fast_LS = nullptr; - - // relative and absolute tolerances - Real reltol = 1.0e-4; - Real abstol = 1.0e-9; - - Real fast_reltol = 1.0e-4; - Real fast_abstol = 1.0e-9; + Real fast_reltol = 1.0e-4; + Real fast_abstol = 1.0e-9; void initialize_parameters () { @@ -386,7 +388,8 @@ public: #endif // Right-hand side function wrappers - udata.f = [&](amrex::Real rhs_time, N_Vector y_data, N_Vector y_rhs, void * /* user_data */) -> int { + udata.f = [&](amrex::Real rhs_time, N_Vector y_data, N_Vector y_rhs, + void * /* user_data */) -> int { T S_data; unpack_vector(y_data, S_data); @@ -400,7 +403,8 @@ public: return 0; }; - udata.fi = [&](amrex::Real rhs_time, N_Vector y_data, N_Vector y_rhs, void * /* user_data */) -> int { + udata.fi = [&](amrex::Real rhs_time, N_Vector y_data, N_Vector y_rhs, + void * /* user_data */) -> int { T S_data; unpack_vector(y_data, S_data); @@ -414,7 +418,8 @@ public: return 0; }; - udata.fe = [&](amrex::Real rhs_time, N_Vector y_data, N_Vector y_rhs, void * /* user_data */) -> int { + udata.fe = [&](amrex::Real rhs_time, N_Vector y_data, N_Vector y_rhs, + void * /* user_data */) -> int { T S_data; unpack_vector(y_data, S_data); @@ -428,7 +433,8 @@ public: return 0; }; - udata.ff = [&](amrex::Real rhs_time, N_Vector y_data, N_Vector y_rhs, void * /* user_data */) -> int { + udata.ff = [&](amrex::Real rhs_time, N_Vector y_data, N_Vector y_rhs, + void * /* user_data */) -> int { T S_data; unpack_vector(y_data, S_data); @@ -442,7 +448,8 @@ public: return 0; }; - udata.post_step = [&](amrex::Real time, N_Vector y_data, void * /* user_data */) -> int { + udata.post_step = [&](amrex::Real time, N_Vector y_data, + void * /* user_data */) -> int { T S_data; unpack_vector(y_data, S_data); @@ -475,23 +482,23 @@ public: SUNContext_Free(&sunctx); } - amrex::Real advance (T& S_old, T& S_new, amrex::Real time, const amrex::Real time_step) override + amrex::Real advance (T& S_old, T& S_new, amrex::Real time, const amrex::Real dt) override { - amrex::Real tout = time + time_step; + amrex::Real tout = time + dt; amrex::Real tret; N_Vector y_old = wrap_data(S_old); N_Vector y_new = wrap_data(S_new); if (use_ark) { - ARKStepReset(arkode_mem, time, y_old); - ARKStepSetFixedStep(arkode_mem, time_step); + ARKStepReset(arkode_mem, time, y_old); // should probably resize + ARKStepSetFixedStep(arkode_mem, dt); int flag = ARKStepEvolve(arkode_mem, tout, y_new, &tret, ARK_ONE_STEP); AMREX_ALWAYS_ASSERT(flag >= 0); } else if (use_mri) { - MRIStepReset(arkode_mem, time, y_old); - MRIStepSetFixedStep(arkode_mem, time_step); + MRIStepReset(arkode_mem, time, y_old); // should probably resize -- need to resize inner stepper + MRIStepSetFixedStep(arkode_mem, dt); int flag = MRIStepEvolve(arkode_mem, tout, y_new, &tret, ARK_ONE_STEP); AMREX_ALWAYS_ASSERT(flag >= 0); } else { @@ -501,7 +508,7 @@ public: N_VDestroy(y_old); N_VDestroy(y_new); - return time_step; + return dt; } void evolve (T& S_out, const amrex::Real t_out) override From d2d2ae3af530b53f08c36c0d9376db053e95fd07 Mon Sep 17 00:00:00 2001 From: "David J. Gardner" Date: Tue, 4 Jun 2024 01:18:03 -0700 Subject: [PATCH 16/36] formatting --- Src/Extern/SUNDIALS/AMReX_SundialsIntegrator.H | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/Src/Extern/SUNDIALS/AMReX_SundialsIntegrator.H b/Src/Extern/SUNDIALS/AMReX_SundialsIntegrator.H index c82b56bebfd..801e74aa4dd 100644 --- a/Src/Extern/SUNDIALS/AMReX_SundialsIntegrator.H +++ b/Src/Extern/SUNDIALS/AMReX_SundialsIntegrator.H @@ -128,7 +128,7 @@ private: } } - void SetupRK(amrex::Real time, N_Vector y_data) + void SetupRK (amrex::Real time, N_Vector y_data) { // Create integrator and select method if (type == "ERK") { @@ -177,7 +177,7 @@ private: ARKStepSetPostprocessStepFn(arkode_mem, SundialsUserFun::post_step); } - void SetupMRI(amrex::Real time, N_Vector y_data) + void SetupMRI (amrex::Real time, N_Vector y_data) { // Create the fast integrator and select method if (fast_type == "ERK") { @@ -256,7 +256,7 @@ private: // ------------------------------------- // Utility to unpack a SUNDIALS ManyVector into a vector of MultiFabs - void unpack_vector(N_Vector y_data, amrex::Vector& S_data) + void unpack_vector (N_Vector y_data, amrex::Vector& S_data) { const int num_vecs = N_VGetNumSubvectors_ManyVector(y_data); S_data.resize(num_vecs); @@ -271,7 +271,7 @@ private: }; // Utility to wrap vector of MultiFabs as a SUNDIALS ManyVector - N_Vector wrap_data(amrex::Vector& S_data) + N_Vector wrap_data (amrex::Vector& S_data) { auto get_length = [&](int index) -> sunindextype { auto* p_mf = &S_data[index]; @@ -294,7 +294,7 @@ private: }; // Utility to wrap vector of MultiFabs as a SUNDIALS ManyVector - N_Vector copy_data(const amrex::Vector& S_data) + N_Vector copy_data (const amrex::Vector& S_data) { auto get_length = [&](int index) -> sunindextype { auto* p_mf = &S_data[index]; @@ -331,7 +331,7 @@ private: // ----------------------------- // Utility to unpack a SUNDIALS Vector into a MultiFab - void unpack_vector(N_Vector y_data, amrex::MultiFab& S_data) + void unpack_vector (N_Vector y_data, amrex::MultiFab& S_data) { S_data = amrex::MultiFab(*amrex::sundials::getMFptr(y_data), amrex::make_alias, @@ -340,14 +340,14 @@ private: }; // Utility to wrap a MultiFab as a SUNDIALS Vector - N_Vector wrap_data(amrex::MultiFab& S_data) + N_Vector wrap_data (amrex::MultiFab& S_data) { return amrex::sundials::N_VMake_MultiFab(S_data.nComp() * S_data.boxArray().numPts(), &S_data); // correct context }; // Utility to wrap a MultiFab as a SUNDIALS Vector - N_Vector copy_data(const amrex::MultiFab& S_data) + N_Vector copy_data (const amrex::MultiFab& S_data) { N_Vector y_data = amrex::sundials::N_VNew_MultiFab(S_data.nComp() * S_data.boxArray().numPts(), S_data.boxArray(), From 3cccc64b80b3307c951f9db10b48e4a18bd5c483 Mon Sep 17 00:00:00 2001 From: "David J. Gardner" Date: Tue, 4 Jun 2024 08:54:56 -0700 Subject: [PATCH 17/36] change advance to step --- Src/Base/AMReX_FEIntegrator.H | 4 ++-- Src/Base/AMReX_IntegratorBase.H | 2 +- Src/Base/AMReX_RKIntegrator.H | 4 ++-- Src/Base/AMReX_TimeIntegrator.H | 15 +++++++-------- Src/Extern/SUNDIALS/AMReX_SundialsIntegrator.H | 2 +- 5 files changed, 13 insertions(+), 14 deletions(-) diff --git a/Src/Base/AMReX_FEIntegrator.H b/Src/Base/AMReX_FEIntegrator.H index 0620f028591..c04195d3c73 100644 --- a/Src/Base/AMReX_FEIntegrator.H +++ b/Src/Base/AMReX_FEIntegrator.H @@ -35,9 +35,9 @@ public: initialize_stages(S_data); } - amrex::Real advance (T& S_old, T& S_new, amrex::Real time, const amrex::Real dt) override + amrex::Real step (T& S_old, T& S_new, amrex::Real time, const amrex::Real dt) override { - // Assume before advance() that S_old is valid data at the current time ("time" argument) + // Assume before step() that S_old is valid data at the current time ("time" argument) // So we initialize S_new by copying the old state. IntegratorOps::Copy(S_new, S_old); diff --git a/Src/Base/AMReX_IntegratorBase.H b/Src/Base/AMReX_IntegratorBase.H index 9613f3c7b18..f4ed5134b8e 100644 --- a/Src/Base/AMReX_IntegratorBase.H +++ b/Src/Base/AMReX_IntegratorBase.H @@ -259,7 +259,7 @@ public: time_step = dt; } - virtual amrex::Real advance (T& S_old, T& S_new, amrex::Real time, amrex::Real dt) = 0; + virtual amrex::Real step (T& S_old, T& S_new, amrex::Real time, amrex::Real dt) = 0; virtual void evolve (T& S_out, const amrex::Real t_out) = 0; diff --git a/Src/Base/AMReX_RKIntegrator.H b/Src/Base/AMReX_RKIntegrator.H index ab5f67a2c63..615b2fb55fb 100644 --- a/Src/Base/AMReX_RKIntegrator.H +++ b/Src/Base/AMReX_RKIntegrator.H @@ -168,9 +168,9 @@ public: virtual ~RKIntegrator () {} - amrex::Real advance (T& S_old, T& S_new, amrex::Real time, const amrex::Real dt) override + amrex::Real step (T& S_old, T& S_new, amrex::Real time, const amrex::Real dt) override { - // Assume before advance() that S_old is valid data at the current time ("time" argument) + // Assume before step() that S_old is valid data at the current time ("time" argument) // And that if data is a MultiFab, both S_old and S_new contain ghost cells for evaluating a stencil based RHS // We need this from S_old. This is convenient for S_new to have so we can use it // as scratch space for stage values without creating a new scratch MultiFab with ghost cells. diff --git a/Src/Base/AMReX_TimeIntegrator.H b/Src/Base/AMReX_TimeIntegrator.H index 2bc626c6876..00d3059b0b1 100644 --- a/Src/Base/AMReX_TimeIntegrator.H +++ b/Src/Base/AMReX_TimeIntegrator.H @@ -163,10 +163,9 @@ public: integrator_ptr->set_time_step(); } - // TODO: Change to step - void advance (T& S_old, T& S_new, amrex::Real time, const amrex::Real dt) + void step (T& S_old, T& S_new, amrex::Real time, const amrex::Real dt) { - integrator_ptr->advance(S_old, S_new, time, dt); + integrator_ptr->step(S_old, S_new, time, dt); } // TODO: Add to FE and RK integrators @@ -181,20 +180,20 @@ public: { amrex::Real m_time = start_time; amrex::Real m_timestep = start_timestep; - bool stop_advance = false; - for (int m_step_number = start_step; m_step_number < max_steps && !stop_advance; ++m_step_number) + bool stop = false; + for (int m_step_number = start_step; m_step_number < max_steps && !stop; ++m_step_number) { if (end_time - m_time < m_timestep) { m_timestep = end_time - m_time; - stop_advance = true; + stop = true; } if (m_step_number > 0) { std::swap(S_old, S_new); } - // Call the time integrator advance - integrator_ptr->advance(S_old, S_new, m_time, m_timestep); + // Call the time integrator step + integrator_ptr->step(S_old, S_new, m_time, m_timestep); // Update our time variable m_time += m_timestep; diff --git a/Src/Extern/SUNDIALS/AMReX_SundialsIntegrator.H b/Src/Extern/SUNDIALS/AMReX_SundialsIntegrator.H index 801e74aa4dd..0e81f45b7a8 100644 --- a/Src/Extern/SUNDIALS/AMReX_SundialsIntegrator.H +++ b/Src/Extern/SUNDIALS/AMReX_SundialsIntegrator.H @@ -482,7 +482,7 @@ public: SUNContext_Free(&sunctx); } - amrex::Real advance (T& S_old, T& S_new, amrex::Real time, const amrex::Real dt) override + amrex::Real step (T& S_old, T& S_new, amrex::Real time, const amrex::Real dt) override { amrex::Real tout = time + dt; amrex::Real tret; From 6c9a29ac7cd217eebde842442252e843b36ced1a Mon Sep 17 00:00:00 2001 From: "David J. Gardner" Date: Tue, 4 Jun 2024 10:07:24 -0700 Subject: [PATCH 18/36] wrap long lines --- Src/Base/AMReX_IntegratorBase.H | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/Src/Base/AMReX_IntegratorBase.H b/Src/Base/AMReX_IntegratorBase.H index f4ed5134b8e..24dc0a1f2bc 100644 --- a/Src/Base/AMReX_IntegratorBase.H +++ b/Src/Base/AMReX_IntegratorBase.H @@ -259,11 +259,13 @@ public: time_step = dt; } - virtual amrex::Real step (T& S_old, T& S_new, amrex::Real time, amrex::Real dt) = 0; + virtual amrex::Real step (T& S_old, T& S_new, amrex::Real time, + amrex::Real dt) = 0; virtual void evolve (T& S_out, const amrex::Real t_out) = 0; - virtual void time_interpolate (const T& S_new, const T& S_old, amrex::Real timestep_fraction, T& data) = 0; + virtual void time_interpolate (const T& S_new, const T& S_old, + amrex::Real timestep_fraction, T& data) = 0; virtual void map_data (std::function Map) = 0; }; From d49999169028208d83f6a044cec31ee2249352ed Mon Sep 17 00:00:00 2001 From: "David J. Gardner" Date: Tue, 4 Jun 2024 11:10:45 -0700 Subject: [PATCH 19/36] fix comment --- Src/Base/AMReX_IntegratorBase.H | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Src/Base/AMReX_IntegratorBase.H b/Src/Base/AMReX_IntegratorBase.H index 24dc0a1f2bc..1f9a150e6c7 100644 --- a/Src/Base/AMReX_IntegratorBase.H +++ b/Src/Base/AMReX_IntegratorBase.H @@ -208,7 +208,7 @@ protected: amrex::Real time_step; /** - * \brief Number of Integrator time steps (Long) + * \brief Number of integrator time steps (Long) */ amrex::Long num_steps; From 7b9d542a3ea724580eba06df68d37315210a035a Mon Sep 17 00:00:00 2001 From: "David J. Gardner" Date: Tue, 4 Jun 2024 11:21:19 -0700 Subject: [PATCH 20/36] clean up naming and comments --- Src/Base/AMReX_IntegratorBase.H | 9 ++++++++- Src/Base/AMReX_TimeIntegrator.H | 4 ++-- Src/Extern/SUNDIALS/AMReX_SundialsIntegrator.H | 9 +++++---- 3 files changed, 15 insertions(+), 7 deletions(-) diff --git a/Src/Base/AMReX_IntegratorBase.H b/Src/Base/AMReX_IntegratorBase.H index 1f9a150e6c7..894ba63d82f 100644 --- a/Src/Base/AMReX_IntegratorBase.H +++ b/Src/Base/AMReX_IntegratorBase.H @@ -259,10 +259,17 @@ public: time_step = dt; } + /** + * \brief Take a single time step from (time, S_old) to (time + dt, S_new) + * with the given step size. + */ virtual amrex::Real step (T& S_old, T& S_new, amrex::Real time, amrex::Real dt) = 0; - virtual void evolve (T& S_out, const amrex::Real t_out) = 0; + /** + * \brief Evolve the current (internal) integrator state to time_out + */ + virtual void evolve (T& S_out, const amrex::Real time_out) = 0; virtual void time_interpolate (const T& S_new, const T& S_old, amrex::Real timestep_fraction, T& data) = 0; diff --git a/Src/Base/AMReX_TimeIntegrator.H b/Src/Base/AMReX_TimeIntegrator.H index 00d3059b0b1..86b48926b82 100644 --- a/Src/Base/AMReX_TimeIntegrator.H +++ b/Src/Base/AMReX_TimeIntegrator.H @@ -169,9 +169,9 @@ public: } // TODO: Add to FE and RK integrators - void evolve (T& S_out, const amrex::Real t_out) + void evolve (T& S_out, const amrex::Real time_out) { - integrator_ptr->evolve(S_out, t_out); + integrator_ptr->evolve(S_out, time_out); } // TODO: Change to advance diff --git a/Src/Extern/SUNDIALS/AMReX_SundialsIntegrator.H b/Src/Extern/SUNDIALS/AMReX_SundialsIntegrator.H index 0e81f45b7a8..bfb74ff9084 100644 --- a/Src/Extern/SUNDIALS/AMReX_SundialsIntegrator.H +++ b/Src/Extern/SUNDIALS/AMReX_SundialsIntegrator.H @@ -511,18 +511,19 @@ public: return dt; } - void evolve (T& S_out, const amrex::Real t_out) override + void evolve (T& S_out, const amrex::Real time_out) override { - amrex::Real tret; + int flag = 0; // SUNDIALS return status + amrex::Real time_ret; // SUNDIALS return time N_Vector y_out = wrap_data(S_out); if (use_ark) { - int flag = ARKStepEvolve(arkode_mem, t_out, y_out, &tret, ARK_NORMAL); + flag = ARKStepEvolve(arkode_mem, time_out, y_out, &time_ret, ARK_NORMAL); AMREX_ALWAYS_ASSERT(flag >= 0); } else if (use_mri) { - int flag = MRIStepEvolve(arkode_mem, t_out, y_out, &tret, ARK_NORMAL); + flag = MRIStepEvolve(arkode_mem, time_out, y_out, &time_ret, ARK_NORMAL); AMREX_ALWAYS_ASSERT(flag >= 0); } else { Error("SUNDIALS integrator type not specified."); From 5ba02319323ce45ee2df8f9685c1f895e383c44e Mon Sep 17 00:00:00 2001 From: "David J. Gardner" Date: Tue, 4 Jun 2024 13:43:29 -0700 Subject: [PATCH 21/36] add evolve for FE and RK --- Src/Base/AMReX_FEIntegrator.H | 47 ++++++++++-- Src/Base/AMReX_IntegratorBase.H | 34 ++++++++- Src/Base/AMReX_RKIntegrator.H | 71 ++++++++++++++++--- Src/Base/AMReX_TimeIntegrator.H | 17 ++++- .../SUNDIALS/AMReX_SundialsIntegrator.H | 11 ++- 5 files changed, 161 insertions(+), 19 deletions(-) diff --git a/Src/Base/AMReX_FEIntegrator.H b/Src/Base/AMReX_FEIntegrator.H index c04195d3c73..981532fabe0 100644 --- a/Src/Base/AMReX_FEIntegrator.H +++ b/Src/Base/AMReX_FEIntegrator.H @@ -15,9 +15,23 @@ private: amrex::Vector > F_nodes; - void initialize_stages (const T& S_data) + // Current (internal) state and time + amrex::Vector > S_current; + amrex::Real time_current; + + int max_steps = 500; + + void initialize_stages (const T& S_data, const amrex::Real time) { + // Create data for stage RHS IntegratorOps::CreateLike(F_nodes, S_data); + + // Create and initialize data for current state + IntegratorOps::CreateLike(S_current, S_data, true); + IntegratorOps::Copy(*S_current[0], S_data); + + // Set the initial time + time_current = time; } public: @@ -30,9 +44,9 @@ public: virtual ~FEIntegrator () {} - void initialize (const T& S_data, const amrex::Real /* time */) override + void initialize (const T& S_data, const amrex::Real time = 0.0) override { - initialize_stages(S_data); + initialize_stages(S_data, time); } amrex::Real step (T& S_old, T& S_new, amrex::Real time, const amrex::Real dt) override @@ -58,9 +72,32 @@ public: return dt; } - void evolve (T& S_out, const amrex::Real t_out) override + void evolve (T& S_out, const amrex::Real time_out) override { - amrex::Error("Evolve is not yet supported by the forward euler integrator."); + amrex::Real dt = BaseT::time_step; + bool stop = false; + + for (int step_number = 0; step_number < max_steps && !stop; ++step_number) + { + // Adjust step size to reach output time + if (time_out - time_current < dt) { + dt = time_out - time_current; + stop = true; + } + + // Call the time integrator step + step(*S_current[0], S_out, time_current, dt); + + // Update current state S_current = S_out + IntegratorOps::Copy(*S_current[0], S_out); + + // Update time + time_current += dt; + + if (step_number == max_steps - 1) { + Error("Did not reach output time in max steps."); + } + } } virtual void time_interpolate (const T& /* S_new */, const T& /* S_old */, amrex::Real /* timestep_fraction */, T& /* data */) override diff --git a/Src/Base/AMReX_IntegratorBase.H b/Src/Base/AMReX_IntegratorBase.H index 894ba63d82f..b187f321616 100644 --- a/Src/Base/AMReX_IntegratorBase.H +++ b/Src/Base/AMReX_IntegratorBase.H @@ -198,15 +198,28 @@ protected: std::function post_step_action; /** - * \brief Current time reached by the integrator (Real) + * \brief Flag to enable/disable adaptive time stepping in single rate + * methods or at the slow time scale in multirate methods (bool) */ - amrex::Real cur_time; + bool use_adaptive_time_step; /** * \brief Current integrator time step size (Real) */ amrex::Real time_step; + /** + * \brief Flag to enable/disable adaptive time stepping at the fast time + * scale in multirate methods (bool) + */ + bool use_adaptive_fast_time_step; + + /** + * \brief Current integrator fast time scale time step size with multirate + * methods (Real) + */ + amrex::Real fast_time_step; + /** * \brief Number of integrator time steps (Long) */ @@ -257,6 +270,23 @@ public: void set_time_step(amrex::Real dt) { time_step = dt; + use_adaptive_time_step = false; + } + + void set_adaptive_step() + { + use_adaptive_time_step = true; + } + + void set_fast_time_step(amrex::Real dt) + { + fast_time_step = dt; + use_adaptive_fast_time_step = false; + } + + void set_adaptive_fast_step() + { + use_adaptive_fast_time_step = true; } /** diff --git a/Src/Base/AMReX_RKIntegrator.H b/Src/Base/AMReX_RKIntegrator.H index 615b2fb55fb..b8dd527a76d 100644 --- a/Src/Base/AMReX_RKIntegrator.H +++ b/Src/Base/AMReX_RKIntegrator.H @@ -23,15 +23,36 @@ class RKIntegrator : public IntegratorBase private: using BaseT = IntegratorBase; + // Butcher tableau identifiers ButcherTableauTypes tableau_type; + + // Number of RK stages int number_nodes; - bool use_adaptive_timestep; - amrex::Vector > F_nodes; + + // A matrix from Butcher tableau amrex::Vector > tableau; + + // b vector from Butcher tableau amrex::Vector weights; - amrex::Vector extended_weights; + + // c vector from Butcher tableau amrex::Vector nodes; + // RK embedded method b vector + amrex::Vector extended_weights; + + // Flag to enable adaptive stepping + bool use_adaptive_timestep; + + // RK stage right-hand sides + amrex::Vector > F_nodes; + + // Current (internal) state and time + amrex::Vector > S_current; + amrex::Real time_current; + + int max_steps = 500; + void initialize_preset_tableau () { switch (tableau_type) @@ -85,14 +106,14 @@ private: pp.get("type", _tableau_type); tableau_type = static_cast(_tableau_type); - // By default, define no extended weights and no adaptive timestepping + // By default, define no extended weights and no adaptive time stepping extended_weights = {}; use_adaptive_timestep = false; pp.queryAdd("use_adaptive_timestep", use_adaptive_timestep); if (tableau_type == ButcherTableauTypes::User) { - // Read weights/nodes/butcher tableau" + // Read weights/nodes/butcher tableau pp.getarr("weights", weights); pp.queryarr("extended_weights", extended_weights); pp.getarr("nodes", nodes); @@ -143,13 +164,20 @@ private: } } - void initialize_stages (const T& S_data) + void initialize_stages (const T& S_data, const amrex::Real time) { // Create data for stage RHS for (int i = 0; i < number_nodes; ++i) { IntegratorOps::CreateLike(F_nodes, S_data); } + + // Create and initialize data for current state + IntegratorOps::CreateLike(S_current, S_data, true); + IntegratorOps::Copy(*S_current[0], S_data); + + // Set the initial time + time_current = time; } public: @@ -160,10 +188,10 @@ public: initialize(S_data, time); } - void initialize (const T& S_data, const amrex::Real /* time */) override + void initialize (const T& S_data, const amrex::Real time = 0.0) override { initialize_parameters(); - initialize_stages(S_data); + initialize_stages(S_data, time); } virtual ~RKIntegrator () {} @@ -220,9 +248,32 @@ public: return dt; } - void evolve (T& S_out, const amrex::Real t_out) override + void evolve (T& S_out, const amrex::Real time_out) override { - amrex::Error("Evolve is not yet supported by the RK integrator."); + amrex::Real dt = BaseT::time_step; + bool stop = false; + + for (int step_number = 0; step_number < max_steps && !stop; ++step_number) + { + // Adjust step size to reach output time + if (time_out - time_current < dt) { + dt = time_out - time_current; + stop = true; + } + + // Call the time integrator step + step(*S_current[0], S_out, time_current, dt); + + // Update current state S_current = S_out + IntegratorOps::Copy(*S_current[0], S_out); + + // Update time + time_current += dt; + + if (step_number == max_steps - 1) { + Error("Did not reach output time in max steps."); + } + } } void time_interpolate (const T& /* S_new */, const T& S_old, amrex::Real timestep_fraction, T& data) override diff --git a/Src/Base/AMReX_TimeIntegrator.H b/Src/Base/AMReX_TimeIntegrator.H index 86b48926b82..9aa0849a0b2 100644 --- a/Src/Base/AMReX_TimeIntegrator.H +++ b/Src/Base/AMReX_TimeIntegrator.H @@ -160,7 +160,22 @@ public: void set_time_step (amrex::Real dt) { - integrator_ptr->set_time_step(); + integrator_ptr->set_time_step(dt); + } + + void set_adaptive_step () + { + integrator_ptr->set_adaptive_step(); + } + + void set_fast_time_step (amrex::Real dt) + { + integrator_ptr->set_fast_time_step(dt); + } + + void set_adaptive_fast_step () + { + integrator_ptr->set_adaptive_fast_step(); } void step (T& S_old, T& S_new, amrex::Real time, const amrex::Real dt) diff --git a/Src/Extern/SUNDIALS/AMReX_SundialsIntegrator.H b/Src/Extern/SUNDIALS/AMReX_SundialsIntegrator.H index bfb74ff9084..15fb7dcb77d 100644 --- a/Src/Extern/SUNDIALS/AMReX_SundialsIntegrator.H +++ b/Src/Extern/SUNDIALS/AMReX_SundialsIntegrator.H @@ -20,7 +20,7 @@ namespace amrex { struct SundialsUserData { // ERK or DIRK right-hand side function - // ExMRI or ImMRI slow right-hand side function + // EX-MRI or IM-MRI slow right-hand side function std::function f; // ImEx-RK right-hand side functions @@ -519,10 +519,19 @@ public: N_Vector y_out = wrap_data(S_out); if (use_ark) { + if (!BaseT::use_adaptive_time_step) { + ARKStepSetFixedStep(arkode_mem, BaseT::time_step); + } flag = ARKStepEvolve(arkode_mem, time_out, y_out, &time_ret, ARK_NORMAL); AMREX_ALWAYS_ASSERT(flag >= 0); } else if (use_mri) { + if (!BaseT::use_adaptive_time_step) { + MRIStepSetFixedStep(arkode_mem, BaseT::time_step); + } + if (!BaseT::use_adaptive_fast_time_step) { + ARKStepSetFixedStep(arkode_fast_mem, BaseT::fast_time_step); + } flag = MRIStepEvolve(arkode_mem, time_out, y_out, &time_ret, ARK_NORMAL); AMREX_ALWAYS_ASSERT(flag >= 0); } else { From 9e858023d3108fa3235f86e07881318b3cbb7091 Mon Sep 17 00:00:00 2001 From: "David J. Gardner" Date: Tue, 11 Jun 2024 22:11:42 -0700 Subject: [PATCH 22/36] add post stage actions --- Src/Base/AMReX_IntegratorBase.H | 13 +++++++++- Src/Base/AMReX_RKIntegrator.H | 2 ++ Src/Base/AMReX_TimeIntegrator.H | 20 +++++++++++---- .../SUNDIALS/AMReX_SundialsIntegrator.H | 25 +++++++++++++++++-- 4 files changed, 52 insertions(+), 8 deletions(-) diff --git a/Src/Base/AMReX_IntegratorBase.H b/Src/Base/AMReX_IntegratorBase.H index b187f321616..3b382042225 100644 --- a/Src/Base/AMReX_IntegratorBase.H +++ b/Src/Base/AMReX_IntegratorBase.H @@ -191,9 +191,15 @@ protected: */ std::function pre_rhs_action; + /** + * \brief The post_stage_action function is called by the integrator on + * the computed stage just after it is computed + */ + std::function post_stage_action; + /** * \brief The post_step_action function is called by the integrator on - * computed state just after it is computed + * the computed state just after it is computed */ std::function post_step_action; @@ -257,6 +263,11 @@ public: pre_rhs_action = A; } + void set_post_stage_action (std::function A) + { + post_stage_action = A; + } + void set_post_step_action (std::function A) { post_step_action = A; diff --git a/Src/Base/AMReX_RKIntegrator.H b/Src/Base/AMReX_RKIntegrator.H index b8dd527a76d..251b84da734 100644 --- a/Src/Base/AMReX_RKIntegrator.H +++ b/Src/Base/AMReX_RKIntegrator.H @@ -220,6 +220,8 @@ public: { IntegratorOps::Saxpy(S_new, dt * tableau[i][j], *F_nodes[j]); } + + BaseT::post_stage_action(S_new, stage_time); } // Call the update hook for the stage state value diff --git a/Src/Base/AMReX_TimeIntegrator.H b/Src/Base/AMReX_TimeIntegrator.H index 9aa0849a0b2..d126aacc427 100644 --- a/Src/Base/AMReX_TimeIntegrator.H +++ b/Src/Base/AMReX_TimeIntegrator.H @@ -64,18 +64,23 @@ private: void set_default_functions () { - // By default, do nothing after before calling the RHS - // In general, this is where BCs should be filled - set_pre_rhs_action([](T& /* S_data */, amrex::Real /* S_time */){}); - // By default, do nothing in the RHS set_rhs([](T& /* S_rhs */, const T& /* S_data */, const amrex::Real /* time */){}); set_imex_rhs([](T& /* S_rhs */, const T& /* S_data */, const amrex::Real /* time */){}, [](T& /* S_rhs */, const T& /* S_data */, const amrex::Real /* time */){}); set_fast_rhs([](T& /* S_rhs */, const T& /* S_data */, const amrex::Real /* time */){}); + // In general, the following functions can be used to fill BCs. Which + // function to set will depend on the method type and intended use case + + // By default, do nothing before calling the RHS + set_pre_rhs_action([](T& /* S_data */, amrex::Real /* time */){}); + + // By default, do nothing after a stage + set_post_stage_action([](T& /* S_data */, const amrex::Real /* time */){}); + // By default, do nothing after a step - set_post_step_action([](T& /* S_rhs */, const amrex::Real /* time */){}); + set_post_step_action([](T& /* S_data */, const amrex::Real /* time */){}); } public: @@ -148,6 +153,11 @@ public: integrator_ptr->set_pre_rhs_action(A); } + void set_post_stage_action (std::function A) + { + integrator_ptr->set_post_stage_action(A); + } + void set_post_step_action (std::function A) { integrator_ptr->set_post_step_action(A); diff --git a/Src/Extern/SUNDIALS/AMReX_SundialsIntegrator.H b/Src/Extern/SUNDIALS/AMReX_SundialsIntegrator.H index 15fb7dcb77d..bbbfac49faf 100644 --- a/Src/Extern/SUNDIALS/AMReX_SundialsIntegrator.H +++ b/Src/Extern/SUNDIALS/AMReX_SundialsIntegrator.H @@ -31,6 +31,9 @@ struct SundialsUserData { // MRI fast time scale right-hand side function std::function ff; + // Post stage actions + std::function post_stage; + // Post step actions std::function post_step; }; @@ -56,6 +59,11 @@ namespace SundialsUserFun { return udata->ff(t, y_data, y_rhs, user_data); } + static int post_stage (amrex::Real t, N_Vector y_data, void *user_data) { + SundialsUserData* udata = static_cast(user_data); + return udata->post_stage(t, y_data, user_data); + } + static int post_step (amrex::Real t, N_Vector y_data, void *user_data) { SundialsUserData* udata = static_cast(user_data); return udata->post_step(t, y_data, user_data); @@ -173,7 +181,8 @@ private: ARKStepSetLinearSolver(arkode_mem, LS, nullptr); } - // Set post step function + // Set post stage and step function + ARKStepSetPostprocessStageFn(arkode_mem, SundialsUserFun::post_stage); ARKStepSetPostprocessStepFn(arkode_mem, SundialsUserFun::post_step); } @@ -247,7 +256,8 @@ private: MRIStepSetLinearSolver(arkode_mem, LS, nullptr); } - // Set post step function (only on slow integrator) + // Set post stage and step function + MRIStepSetPostprocessStageFn(arkode_mem, SundialsUserFun::post_stage); MRIStepSetPostprocessStepFn(arkode_mem, SundialsUserFun::post_step); } @@ -448,6 +458,17 @@ public: return 0; }; + udata.post_stage = [&](amrex::Real time, N_Vector y_data, + void * /* user_data */) -> int { + + T S_data; + unpack_vector(y_data, S_data); + + BaseT::post_stage_action(S_data, time); + + return 0; + }; + udata.post_step = [&](amrex::Real time, N_Vector y_data, void * /* user_data */) -> int { From a7fcaa0d583688aabe304a5365412b3eb1c75050 Mon Sep 17 00:00:00 2001 From: "David J. Gardner" Date: Tue, 11 Jun 2024 22:21:07 -0700 Subject: [PATCH 23/36] add fast stage and step actions --- Src/Base/AMReX_IntegratorBase.H | 22 ++++++++++ Src/Base/AMReX_TimeIntegrator.H | 18 ++++++-- .../SUNDIALS/AMReX_SundialsIntegrator.H | 44 +++++++++++++++++-- 3 files changed, 78 insertions(+), 6 deletions(-) diff --git a/Src/Base/AMReX_IntegratorBase.H b/Src/Base/AMReX_IntegratorBase.H index 3b382042225..6bae9e5251a 100644 --- a/Src/Base/AMReX_IntegratorBase.H +++ b/Src/Base/AMReX_IntegratorBase.H @@ -203,6 +203,18 @@ protected: */ std::function post_step_action; + /** + * \brief The post_stage_action function is called by the integrator on + * the computed stage just after it is computed + */ + std::function post_fast_stage_action; + + /** + * \brief The post_step_action function is called by the integrator on + * the computed state just after it is computed + */ + std::function post_fast_step_action; + /** * \brief Flag to enable/disable adaptive time stepping in single rate * methods or at the slow time scale in multirate methods (bool) @@ -273,6 +285,16 @@ public: post_step_action = A; } + void set_post_fast_stage_action (std::function A) + { + post_fast_stage_action = A; + } + + void set_post_fast_step_action (std::function A) + { + post_fast_step_action = A; + } + amrex::Real get_time_step() { return time_step; diff --git a/Src/Base/AMReX_TimeIntegrator.H b/Src/Base/AMReX_TimeIntegrator.H index d126aacc427..8354c3a63d7 100644 --- a/Src/Base/AMReX_TimeIntegrator.H +++ b/Src/Base/AMReX_TimeIntegrator.H @@ -76,11 +76,13 @@ private: // By default, do nothing before calling the RHS set_pre_rhs_action([](T& /* S_data */, amrex::Real /* time */){}); - // By default, do nothing after a stage + // By default, do nothing after a stage or step set_post_stage_action([](T& /* S_data */, const amrex::Real /* time */){}); - - // By default, do nothing after a step set_post_step_action([](T& /* S_data */, const amrex::Real /* time */){}); + + // By default, do nothing after a stage or step + set_post_fast_stage_action([](T& /* S_data */, const amrex::Real /* time */){}); + set_post_fast_step_action([](T& /* S_data */, const amrex::Real /* time */){}); } public: @@ -163,6 +165,16 @@ public: integrator_ptr->set_post_step_action(A); } + void set_post_fast_stage_action (std::function A) + { + integrator_ptr->set_post_fast_stage_action(A); + } + + void set_post_fast_step_action (std::function A) + { + integrator_ptr->set_post_fast_step_action(A); + } + amrex::Real get_time_step () { return integrator_ptr->get_time_step(); diff --git a/Src/Extern/SUNDIALS/AMReX_SundialsIntegrator.H b/Src/Extern/SUNDIALS/AMReX_SundialsIntegrator.H index bbbfac49faf..152db83cbec 100644 --- a/Src/Extern/SUNDIALS/AMReX_SundialsIntegrator.H +++ b/Src/Extern/SUNDIALS/AMReX_SundialsIntegrator.H @@ -31,11 +31,13 @@ struct SundialsUserData { // MRI fast time scale right-hand side function std::function ff; - // Post stage actions + // Post stage and step actions std::function post_stage; - - // Post step actions std::function post_step; + + // Post fast stage and step actions + std::function post_fast_stage; + std::function post_fast_step; }; namespace SundialsUserFun { @@ -68,6 +70,16 @@ namespace SundialsUserFun { SundialsUserData* udata = static_cast(user_data); return udata->post_step(t, y_data, user_data); } + + static int post_fast_stage (amrex::Real t, N_Vector y_data, void *user_data) { + SundialsUserData* udata = static_cast(user_data); + return udata->post_fast_stage(t, y_data, user_data); + } + + static int post_fast_step (amrex::Real t, N_Vector y_data, void *user_data) { + SundialsUserData* udata = static_cast(user_data); + return udata->post_fast_step(t, y_data, user_data); + } } template @@ -217,6 +229,10 @@ private: // Set integrator tolerances ARKStepSStolerances(arkode_fast_mem, fast_reltol, fast_abstol); + // Set post stage and step function + ARKStepSetPostprocessStageFn(arkode_fast_mem, SundialsUserFun::post_fast_stage); + ARKStepSetPostprocessStepFn(arkode_fast_mem, SundialsUserFun::post_fast_step); + // Wrap fast integrator as an inner stepper ARKStepCreateMRIStepInnerStepper(arkode_fast_mem, &fast_stepper); @@ -480,6 +496,28 @@ public: return 0; }; + udata.post_fast_stage = [&](amrex::Real time, N_Vector y_data, + void * /* user_data */) -> int { + + T S_data; + unpack_vector(y_data, S_data); + + BaseT::post_fast_stage_action(S_data, time); + + return 0; + }; + + udata.post_fast_step = [&](amrex::Real time, N_Vector y_data, + void * /* user_data */) -> int { + + T S_data; + unpack_vector(y_data, S_data); + + BaseT::post_fast_step_action(S_data, time); + + return 0; + }; + N_Vector y_data = copy_data(S_data); // ideally just wrap and ignore const if (use_ark) { From 0466cfed14323dde8a0aa2bbbb7eff920f88eba4 Mon Sep 17 00:00:00 2001 From: "David J. Gardner" Date: Tue, 11 Jun 2024 22:50:53 -0700 Subject: [PATCH 24/36] remove unused flag --- Src/Base/AMReX_RKIntegrator.H | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/Src/Base/AMReX_RKIntegrator.H b/Src/Base/AMReX_RKIntegrator.H index 251b84da734..3fb4edeb4e5 100644 --- a/Src/Base/AMReX_RKIntegrator.H +++ b/Src/Base/AMReX_RKIntegrator.H @@ -41,9 +41,6 @@ private: // RK embedded method b vector amrex::Vector extended_weights; - // Flag to enable adaptive stepping - bool use_adaptive_timestep; - // RK stage right-hand sides amrex::Vector > F_nodes; @@ -106,10 +103,8 @@ private: pp.get("type", _tableau_type); tableau_type = static_cast(_tableau_type); - // By default, define no extended weights and no adaptive time stepping + // By default, define no extended weights extended_weights = {}; - use_adaptive_timestep = false; - pp.queryAdd("use_adaptive_timestep", use_adaptive_timestep); if (tableau_type == ButcherTableauTypes::User) { From 09e8bbd22e337a8ce45648660cfa9e05329df77f Mon Sep 17 00:00:00 2001 From: "David J. Gardner" Date: Tue, 11 Jun 2024 23:06:27 -0700 Subject: [PATCH 25/36] move max steps to base class --- Src/Base/AMReX_FEIntegrator.H | 6 ++---- Src/Base/AMReX_IntegratorBase.H | 9 +++++++++ Src/Base/AMReX_RKIntegrator.H | 6 ++---- 3 files changed, 13 insertions(+), 8 deletions(-) diff --git a/Src/Base/AMReX_FEIntegrator.H b/Src/Base/AMReX_FEIntegrator.H index 981532fabe0..ac814ac0034 100644 --- a/Src/Base/AMReX_FEIntegrator.H +++ b/Src/Base/AMReX_FEIntegrator.H @@ -19,8 +19,6 @@ private: amrex::Vector > S_current; amrex::Real time_current; - int max_steps = 500; - void initialize_stages (const T& S_data, const amrex::Real time) { // Create data for stage RHS @@ -77,7 +75,7 @@ public: amrex::Real dt = BaseT::time_step; bool stop = false; - for (int step_number = 0; step_number < max_steps && !stop; ++step_number) + for (int step_number = 0; step_number < BaseT::max_steps && !stop; ++step_number) { // Adjust step size to reach output time if (time_out - time_current < dt) { @@ -94,7 +92,7 @@ public: // Update time time_current += dt; - if (step_number == max_steps - 1) { + if (step_number == BaseT::max_steps - 1) { Error("Did not reach output time in max steps."); } } diff --git a/Src/Base/AMReX_IntegratorBase.H b/Src/Base/AMReX_IntegratorBase.H index 6bae9e5251a..cefe2bf8665 100644 --- a/Src/Base/AMReX_IntegratorBase.H +++ b/Src/Base/AMReX_IntegratorBase.H @@ -243,6 +243,10 @@ protected: */ amrex::Long num_steps; + /** + * \brief Max number of internal steps before an error is returned (Long) + */ + int max_steps = 500; public: IntegratorBase () = default; @@ -322,6 +326,11 @@ public: use_adaptive_fast_time_step = true; } + void set_max_steps(int steps) + { + max_steps = steps; + } + /** * \brief Take a single time step from (time, S_old) to (time + dt, S_new) * with the given step size. diff --git a/Src/Base/AMReX_RKIntegrator.H b/Src/Base/AMReX_RKIntegrator.H index 3fb4edeb4e5..b52a5d9609f 100644 --- a/Src/Base/AMReX_RKIntegrator.H +++ b/Src/Base/AMReX_RKIntegrator.H @@ -48,8 +48,6 @@ private: amrex::Vector > S_current; amrex::Real time_current; - int max_steps = 500; - void initialize_preset_tableau () { switch (tableau_type) @@ -250,7 +248,7 @@ public: amrex::Real dt = BaseT::time_step; bool stop = false; - for (int step_number = 0; step_number < max_steps && !stop; ++step_number) + for (int step_number = 0; step_number < BaseT::max_steps && !stop; ++step_number) { // Adjust step size to reach output time if (time_out - time_current < dt) { @@ -267,7 +265,7 @@ public: // Update time time_current += dt; - if (step_number == max_steps - 1) { + if (step_number == BaseT::max_steps - 1) { Error("Did not reach output time in max steps."); } } From 61fcc1d8492f9d94c1406bbb91a0060f536d8e70 Mon Sep 17 00:00:00 2001 From: "David J. Gardner" Date: Tue, 11 Jun 2024 23:07:03 -0700 Subject: [PATCH 26/36] initialize adaptive flags, step counter --- Src/Base/AMReX_IntegratorBase.H | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/Src/Base/AMReX_IntegratorBase.H b/Src/Base/AMReX_IntegratorBase.H index cefe2bf8665..3626a52adea 100644 --- a/Src/Base/AMReX_IntegratorBase.H +++ b/Src/Base/AMReX_IntegratorBase.H @@ -219,7 +219,7 @@ protected: * \brief Flag to enable/disable adaptive time stepping in single rate * methods or at the slow time scale in multirate methods (bool) */ - bool use_adaptive_time_step; + bool use_adaptive_time_step = false; /** * \brief Current integrator time step size (Real) @@ -230,7 +230,7 @@ protected: * \brief Flag to enable/disable adaptive time stepping at the fast time * scale in multirate methods (bool) */ - bool use_adaptive_fast_time_step; + bool use_adaptive_fast_time_step = false; /** * \brief Current integrator fast time scale time step size with multirate @@ -241,7 +241,7 @@ protected: /** * \brief Number of integrator time steps (Long) */ - amrex::Long num_steps; + amrex::Long num_steps = 0; /** * \brief Max number of internal steps before an error is returned (Long) From 5e7be65068d3249b6ab94f23147b7c65d3d2ecda Mon Sep 17 00:00:00 2001 From: "David J. Gardner" Date: Mon, 17 Jun 2024 09:28:57 -0700 Subject: [PATCH 27/36] formatting --- Src/Base/AMReX_IntegratorBase.H | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/Src/Base/AMReX_IntegratorBase.H b/Src/Base/AMReX_IntegratorBase.H index 3626a52adea..e4f82be20c3 100644 --- a/Src/Base/AMReX_IntegratorBase.H +++ b/Src/Base/AMReX_IntegratorBase.H @@ -299,34 +299,34 @@ public: post_fast_step_action = A; } - amrex::Real get_time_step() + amrex::Real get_time_step () { return time_step; } - void set_time_step(amrex::Real dt) + void set_time_step (amrex::Real dt) { time_step = dt; use_adaptive_time_step = false; } - void set_adaptive_step() + void set_adaptive_step () { use_adaptive_time_step = true; } - void set_fast_time_step(amrex::Real dt) + void set_fast_time_step (amrex::Real dt) { fast_time_step = dt; use_adaptive_fast_time_step = false; } - void set_adaptive_fast_step() + void set_adaptive_fast_step () { use_adaptive_fast_time_step = true; } - void set_max_steps(int steps) + void set_max_steps (int steps) { max_steps = steps; } From 65d80ea51c8219361cbf5d910fc0df73737115bf Mon Sep 17 00:00:00 2001 From: "David J. Gardner" Date: Mon, 17 Jun 2024 09:30:54 -0700 Subject: [PATCH 28/36] change step back to advance --- Src/Base/AMReX_FEIntegrator.H | 4 ++-- Src/Base/AMReX_IntegratorBase.H | 4 ++-- Src/Base/AMReX_RKIntegrator.H | 4 ++-- Src/Base/AMReX_TimeIntegrator.H | 4 ++-- Src/Extern/SUNDIALS/AMReX_SundialsIntegrator.H | 2 +- 5 files changed, 9 insertions(+), 9 deletions(-) diff --git a/Src/Base/AMReX_FEIntegrator.H b/Src/Base/AMReX_FEIntegrator.H index ac814ac0034..c1a26626b3c 100644 --- a/Src/Base/AMReX_FEIntegrator.H +++ b/Src/Base/AMReX_FEIntegrator.H @@ -47,7 +47,7 @@ public: initialize_stages(S_data, time); } - amrex::Real step (T& S_old, T& S_new, amrex::Real time, const amrex::Real dt) override + amrex::Real advance (T& S_old, T& S_new, amrex::Real time, const amrex::Real dt) override { // Assume before step() that S_old is valid data at the current time ("time" argument) // So we initialize S_new by copying the old state. @@ -84,7 +84,7 @@ public: } // Call the time integrator step - step(*S_current[0], S_out, time_current, dt); + advance(*S_current[0], S_out, time_current, dt); // Update current state S_current = S_out IntegratorOps::Copy(*S_current[0], S_out); diff --git a/Src/Base/AMReX_IntegratorBase.H b/Src/Base/AMReX_IntegratorBase.H index e4f82be20c3..711594cd666 100644 --- a/Src/Base/AMReX_IntegratorBase.H +++ b/Src/Base/AMReX_IntegratorBase.H @@ -335,8 +335,8 @@ public: * \brief Take a single time step from (time, S_old) to (time + dt, S_new) * with the given step size. */ - virtual amrex::Real step (T& S_old, T& S_new, amrex::Real time, - amrex::Real dt) = 0; + virtual amrex::Real advance (T& S_old, T& S_new, amrex::Real time, + amrex::Real dt) = 0; /** * \brief Evolve the current (internal) integrator state to time_out diff --git a/Src/Base/AMReX_RKIntegrator.H b/Src/Base/AMReX_RKIntegrator.H index b52a5d9609f..54c007ba70c 100644 --- a/Src/Base/AMReX_RKIntegrator.H +++ b/Src/Base/AMReX_RKIntegrator.H @@ -189,7 +189,7 @@ public: virtual ~RKIntegrator () {} - amrex::Real step (T& S_old, T& S_new, amrex::Real time, const amrex::Real dt) override + amrex::Real advance (T& S_old, T& S_new, amrex::Real time, const amrex::Real dt) override { // Assume before step() that S_old is valid data at the current time ("time" argument) // And that if data is a MultiFab, both S_old and S_new contain ghost cells for evaluating a stencil based RHS @@ -257,7 +257,7 @@ public: } // Call the time integrator step - step(*S_current[0], S_out, time_current, dt); + advance(*S_current[0], S_out, time_current, dt); // Update current state S_current = S_out IntegratorOps::Copy(*S_current[0], S_out); diff --git a/Src/Base/AMReX_TimeIntegrator.H b/Src/Base/AMReX_TimeIntegrator.H index 8354c3a63d7..6cf671eaadd 100644 --- a/Src/Base/AMReX_TimeIntegrator.H +++ b/Src/Base/AMReX_TimeIntegrator.H @@ -200,9 +200,9 @@ public: integrator_ptr->set_adaptive_fast_step(); } - void step (T& S_old, T& S_new, amrex::Real time, const amrex::Real dt) + void advance (T& S_old, T& S_new, amrex::Real time, const amrex::Real dt) { - integrator_ptr->step(S_old, S_new, time, dt); + integrator_ptr->advance(S_old, S_new, time, dt); } // TODO: Add to FE and RK integrators diff --git a/Src/Extern/SUNDIALS/AMReX_SundialsIntegrator.H b/Src/Extern/SUNDIALS/AMReX_SundialsIntegrator.H index 152db83cbec..8dff9a2299a 100644 --- a/Src/Extern/SUNDIALS/AMReX_SundialsIntegrator.H +++ b/Src/Extern/SUNDIALS/AMReX_SundialsIntegrator.H @@ -541,7 +541,7 @@ public: SUNContext_Free(&sunctx); } - amrex::Real step (T& S_old, T& S_new, amrex::Real time, const amrex::Real dt) override + amrex::Real advance (T& S_old, T& S_new, amrex::Real time, const amrex::Real dt) override { amrex::Real tout = time + dt; amrex::Real tret; From 389e08213912b597a03f5a3606d5cb64466a21f6 Mon Sep 17 00:00:00 2001 From: "David J. Gardner" Date: Mon, 17 Jun 2024 09:42:22 -0700 Subject: [PATCH 29/36] add set_post_update wrapper --- Src/Base/AMReX_IntegratorBase.H | 6 ++++++ Src/Base/AMReX_TimeIntegrator.H | 5 +++++ 2 files changed, 11 insertions(+) diff --git a/Src/Base/AMReX_IntegratorBase.H b/Src/Base/AMReX_IntegratorBase.H index 711594cd666..df1347d82cf 100644 --- a/Src/Base/AMReX_IntegratorBase.H +++ b/Src/Base/AMReX_IntegratorBase.H @@ -299,6 +299,12 @@ public: post_fast_step_action = A; } + void set_post_update (std::function A) + { + set_post_stage_action(A); + set_post_step_action(A); + } + amrex::Real get_time_step () { return time_step; diff --git a/Src/Base/AMReX_TimeIntegrator.H b/Src/Base/AMReX_TimeIntegrator.H index 6cf671eaadd..e56967f359a 100644 --- a/Src/Base/AMReX_TimeIntegrator.H +++ b/Src/Base/AMReX_TimeIntegrator.H @@ -175,6 +175,11 @@ public: integrator_ptr->set_post_fast_step_action(A); } + void set_post_update (std::function A) + { + integrator_ptr->set_post_update(A); + } + amrex::Real get_time_step () { return integrator_ptr->get_time_step(); From 0f7b7588ecf8c0dae3541515618c3e4fe92712d0 Mon Sep 17 00:00:00 2001 From: "David J. Gardner" Date: Tue, 18 Jun 2024 13:52:17 -0700 Subject: [PATCH 30/36] remove initialize from base class --- Src/Base/AMReX_FEIntegrator.H | 2 +- Src/Base/AMReX_IntegratorBase.H | 2 -- Src/Base/AMReX_RKIntegrator.H | 2 +- Src/Extern/SUNDIALS/AMReX_SundialsIntegrator.H | 2 +- 4 files changed, 3 insertions(+), 5 deletions(-) diff --git a/Src/Base/AMReX_FEIntegrator.H b/Src/Base/AMReX_FEIntegrator.H index c1a26626b3c..f8a002ef534 100644 --- a/Src/Base/AMReX_FEIntegrator.H +++ b/Src/Base/AMReX_FEIntegrator.H @@ -42,7 +42,7 @@ public: virtual ~FEIntegrator () {} - void initialize (const T& S_data, const amrex::Real time = 0.0) override + void initialize (const T& S_data, const amrex::Real time = 0.0) { initialize_stages(S_data, time); } diff --git a/Src/Base/AMReX_IntegratorBase.H b/Src/Base/AMReX_IntegratorBase.H index df1347d82cf..752fab47d22 100644 --- a/Src/Base/AMReX_IntegratorBase.H +++ b/Src/Base/AMReX_IntegratorBase.H @@ -255,8 +255,6 @@ public: virtual ~IntegratorBase () = default; - virtual void initialize (const T& S_data, const amrex::Real time = 0.0) = 0; - void set_rhs (std::function F) { Rhs = F; diff --git a/Src/Base/AMReX_RKIntegrator.H b/Src/Base/AMReX_RKIntegrator.H index 54c007ba70c..88200d6f784 100644 --- a/Src/Base/AMReX_RKIntegrator.H +++ b/Src/Base/AMReX_RKIntegrator.H @@ -181,7 +181,7 @@ public: initialize(S_data, time); } - void initialize (const T& S_data, const amrex::Real time = 0.0) override + void initialize (const T& S_data, const amrex::Real time = 0.0) { initialize_parameters(); initialize_stages(S_data, time); diff --git a/Src/Extern/SUNDIALS/AMReX_SundialsIntegrator.H b/Src/Extern/SUNDIALS/AMReX_SundialsIntegrator.H index 8dff9a2299a..ecbc8ca68fc 100644 --- a/Src/Extern/SUNDIALS/AMReX_SundialsIntegrator.H +++ b/Src/Extern/SUNDIALS/AMReX_SundialsIntegrator.H @@ -399,7 +399,7 @@ public: initialize(S_data, time); } - void initialize (const T& S_data, const amrex::Real time = 0.0) override + void initialize (const T& S_data, const amrex::Real time = 0.0) { initialize_parameters(); MPI_Comm mpi_comm = ParallelContext::CommunicatorSub(); From 1d480ae67918bd75b2f07c34ccb87ac2fb87bb1e Mon Sep 17 00:00:00 2001 From: "David J. Gardner" Date: Tue, 18 Jun 2024 14:03:30 -0700 Subject: [PATCH 31/36] fix integrate --- Src/Base/AMReX_IntegratorBase.H | 5 ++++ Src/Base/AMReX_RKIntegrator.H | 47 ++++++++++++++++++--------------- Src/Base/AMReX_TimeIntegrator.H | 5 +--- 3 files changed, 31 insertions(+), 26 deletions(-) diff --git a/Src/Base/AMReX_IntegratorBase.H b/Src/Base/AMReX_IntegratorBase.H index 752fab47d22..54caca41059 100644 --- a/Src/Base/AMReX_IntegratorBase.H +++ b/Src/Base/AMReX_IntegratorBase.H @@ -226,6 +226,11 @@ protected: */ amrex::Real time_step; + /** + * \brief Step size of the last completed step (Real) + */ + amrex::Real previous_time_step; + /** * \brief Flag to enable/disable adaptive time stepping at the fast time * scale in multirate methods (bool) diff --git a/Src/Base/AMReX_RKIntegrator.H b/Src/Base/AMReX_RKIntegrator.H index 88200d6f784..f72890c808e 100644 --- a/Src/Base/AMReX_RKIntegrator.H +++ b/Src/Base/AMReX_RKIntegrator.H @@ -237,9 +237,12 @@ public: BaseT::post_step_action(S_new, time + dt); // If we are working with an extended Butcher tableau, we can estimate the error here, - // and then calculate an adaptive timestep. + // and then calculate an adaptive time step. - // Return timestep + // Save last completed step size for time_interpolate + BaseT::previous_time_step = dt; + + // Return time step return dt; } @@ -265,6 +268,9 @@ public: // Update time time_current += dt; + // Save last completed step size for time_interpolate + BaseT::previous_time_step = dt; + if (step_number == BaseT::max_steps - 1) { Error("Did not reach output time in max steps."); } @@ -281,33 +287,30 @@ public: IntegratorOps::Saxpy(data, timestep_fraction, S_new); */ - - // Update to use the last time step taken - // currently we only do this for 4th order RK - // AMREX_ASSERT(number_nodes == 4); + AMREX_ASSERT(number_nodes == 4); - // // fill data using MC Equation 39 at time + timestep_fraction * dt - // amrex::Real c = 0; + // fill data using MC Equation 39 at time + timestep_fraction * dt + amrex::Real c = 0; - // // data = S_old - // IntegratorOps::Copy(data, S_old); + // data = S_old + IntegratorOps::Copy(data, S_old); - // // data += (chi - 3/2 * chi^2 + 2/3 * chi^3) * k1 - // c = timestep_fraction - 1.5 * std::pow(timestep_fraction, 2) + 2./3. * std::pow(timestep_fraction, 3); - // IntegratorOps::Saxpy(data, c*BaseT::timestep, *F_nodes[0]); + // data += (chi - 3/2 * chi^2 + 2/3 * chi^3) * k1 + c = timestep_fraction - 1.5 * std::pow(timestep_fraction, 2) + 2./3. * std::pow(timestep_fraction, 3); + IntegratorOps::Saxpy(data, c*BaseT::previous_time_step, *F_nodes[0]); - // // data += (chi^2 - 2/3 * chi^3) * k2 - // c = std::pow(timestep_fraction, 2) - 2./3. * std::pow(timestep_fraction, 3); - // IntegratorOps::Saxpy(data, c*BaseT::timestep, *F_nodes[1]); + // data += (chi^2 - 2/3 * chi^3) * k2 + c = std::pow(timestep_fraction, 2) - 2./3. * std::pow(timestep_fraction, 3); + IntegratorOps::Saxpy(data, c*BaseT::previous_time_step, *F_nodes[1]); - // // data += (chi^2 - 2/3 * chi^3) * k3 - // c = std::pow(timestep_fraction, 2) - 2./3. * std::pow(timestep_fraction, 3); - // IntegratorOps::Saxpy(data, c*BaseT::timestep, *F_nodes[2]); + // data += (chi^2 - 2/3 * chi^3) * k3 + c = std::pow(timestep_fraction, 2) - 2./3. * std::pow(timestep_fraction, 3); + IntegratorOps::Saxpy(data, c*BaseT::previous_time_step, *F_nodes[2]); - // // data += (-1/2 * chi^2 + 2/3 * chi^3) * k4 - // c = -0.5 * std::pow(timestep_fraction, 2) + 2./3. * std::pow(timestep_fraction, 3); - // IntegratorOps::Saxpy(data, c*BaseT::timestep, *F_nodes[3]); + // data += (-1/2 * chi^2 + 2/3 * chi^3) * k4 + c = -0.5 * std::pow(timestep_fraction, 2) + 2./3. * std::pow(timestep_fraction, 3); + IntegratorOps::Saxpy(data, c*BaseT::previous_time_step, *F_nodes[3]); } diff --git a/Src/Base/AMReX_TimeIntegrator.H b/Src/Base/AMReX_TimeIntegrator.H index e56967f359a..cbd1984b81e 100644 --- a/Src/Base/AMReX_TimeIntegrator.H +++ b/Src/Base/AMReX_TimeIntegrator.H @@ -210,13 +210,11 @@ public: integrator_ptr->advance(S_old, S_new, time, dt); } - // TODO: Add to FE and RK integrators void evolve (T& S_out, const amrex::Real time_out) { integrator_ptr->evolve(S_out, time_out); } - // TODO: Change to advance void integrate (T& S_old, T& S_new, amrex::Real start_time, const amrex::Real start_timestep, const amrex::Real end_time, const int start_step, const int max_steps) { @@ -235,14 +233,13 @@ public: } // Call the time integrator step - integrator_ptr->step(S_old, S_new, m_time, m_timestep); + integrator_ptr->advance(S_old, S_new, m_time, m_timestep); // Update our time variable m_time += m_timestep; } } - // TODO: Update to use prev_time_step void time_interpolate (const T& S_new, const T& S_old, amrex::Real timestep_fraction, T& data) { integrator_ptr->time_interpolate(S_new, S_old, timestep_fraction, data); From 93dcbc6122c7bae41d4814a4853457bf144d0629 Mon Sep 17 00:00:00 2001 From: "David J. Gardner" Date: Tue, 18 Jun 2024 14:08:43 -0700 Subject: [PATCH 32/36] fix comment --- Src/Base/AMReX_TimeIntegrator.H | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Src/Base/AMReX_TimeIntegrator.H b/Src/Base/AMReX_TimeIntegrator.H index cbd1984b81e..1165aff96fe 100644 --- a/Src/Base/AMReX_TimeIntegrator.H +++ b/Src/Base/AMReX_TimeIntegrator.H @@ -232,7 +232,7 @@ public: std::swap(S_old, S_new); } - // Call the time integrator step + // Call the time integrator advance integrator_ptr->advance(S_old, S_new, m_time, m_timestep); // Update our time variable From d1d85d5e3f274738a26212bf5859ce43bcc8060b Mon Sep 17 00:00:00 2001 From: "David J. Gardner" Date: Wed, 19 Jun 2024 09:12:10 -0700 Subject: [PATCH 33/36] use local context in vector constructors --- Src/Extern/SUNDIALS/AMReX_SundialsIntegrator.H | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/Src/Extern/SUNDIALS/AMReX_SundialsIntegrator.H b/Src/Extern/SUNDIALS/AMReX_SundialsIntegrator.H index ecbc8ca68fc..110ba8ef317 100644 --- a/Src/Extern/SUNDIALS/AMReX_SundialsIntegrator.H +++ b/Src/Extern/SUNDIALS/AMReX_SundialsIntegrator.H @@ -107,7 +107,10 @@ private: // structure for interfacing with user-supplied functions SundialsUserData udata; - // SUNDIALS context -- should use context created by amrex:sundials::Initialize + // SUNDIALS context + // + // We should probably use context created by amrex:sundials::Initialize but + // that context is not MPI-aware SUNContext sunctx = nullptr; // Single rate or slow time scale @@ -309,7 +312,7 @@ private: for (int i = 0; i < NV_len; ++i) { NV_array[i] = amrex::sundials::N_VMake_MultiFab(get_length(i), - &S_data[i]); // correct context + &S_data[i], sunctx); } N_Vector y_data = N_VNew_ManyVector(NV_len, NV_array, sunctx); @@ -335,7 +338,8 @@ private: S_data[i].boxArray(), S_data[i].DistributionMap(), S_data[i].nComp(), - S_data[i].nGrow()); // correct context + S_data[i].nGrow(), + sunctx); MultiFab::Copy(*amrex::sundials::getMFptr(NV_array[i]), S_data[i], @@ -369,7 +373,7 @@ private: N_Vector wrap_data (amrex::MultiFab& S_data) { return amrex::sundials::N_VMake_MultiFab(S_data.nComp() * S_data.boxArray().numPts(), - &S_data); // correct context + &S_data, sunctx); }; // Utility to wrap a MultiFab as a SUNDIALS Vector @@ -379,7 +383,8 @@ private: S_data.boxArray(), S_data.DistributionMap(), S_data.nComp(), - S_data.nGrow()); // correct context + S_data.nGrow(), + sunctx); MultiFab::Copy(*amrex::sundials::getMFptr(y_data), S_data, From 8c38539745671d4fe32fe9369ba266ebed95a325 Mon Sep 17 00:00:00 2001 From: "David J. Gardner" Date: Wed, 19 Jun 2024 09:12:30 -0700 Subject: [PATCH 34/36] correct error message --- Src/Extern/SUNDIALS/AMReX_SundialsIntegrator.H | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Src/Extern/SUNDIALS/AMReX_SundialsIntegrator.H b/Src/Extern/SUNDIALS/AMReX_SundialsIntegrator.H index 110ba8ef317..cce177e6f58 100644 --- a/Src/Extern/SUNDIALS/AMReX_SundialsIntegrator.H +++ b/Src/Extern/SUNDIALS/AMReX_SundialsIntegrator.H @@ -145,7 +145,7 @@ private: use_mri = true; } else { - std::string msg("Unknown strategy: "); + std::string msg("Unknown method type: "); msg += type; amrex::Error(msg.c_str()); } From 2db0d851f9089f73c325f2b21e01ab0cab8a7982 Mon Sep 17 00:00:00 2001 From: "David J. Gardner" Date: Wed, 19 Jun 2024 09:39:08 -0700 Subject: [PATCH 35/36] update to C++ sundials context wrapper --- .../SUNDIALS/AMReX_SundialsIntegrator.H | 21 +++++++++++-------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/Src/Extern/SUNDIALS/AMReX_SundialsIntegrator.H b/Src/Extern/SUNDIALS/AMReX_SundialsIntegrator.H index cce177e6f58..8976e5ad38a 100644 --- a/Src/Extern/SUNDIALS/AMReX_SundialsIntegrator.H +++ b/Src/Extern/SUNDIALS/AMReX_SundialsIntegrator.H @@ -111,7 +111,7 @@ private: // // We should probably use context created by amrex:sundials::Initialize but // that context is not MPI-aware - SUNContext sunctx = nullptr; + ::sundials::Context sunctx; // Single rate or slow time scale void *arkode_mem = nullptr; @@ -312,7 +312,7 @@ private: for (int i = 0; i < NV_len; ++i) { NV_array[i] = amrex::sundials::N_VMake_MultiFab(get_length(i), - &S_data[i], sunctx); + &S_data[i], &sunctx); } N_Vector y_data = N_VNew_ManyVector(NV_len, NV_array, sunctx); @@ -339,7 +339,7 @@ private: S_data[i].DistributionMap(), S_data[i].nComp(), S_data[i].nGrow(), - sunctx); + &sunctx); MultiFab::Copy(*amrex::sundials::getMFptr(NV_array[i]), S_data[i], @@ -373,7 +373,7 @@ private: N_Vector wrap_data (amrex::MultiFab& S_data) { return amrex::sundials::N_VMake_MultiFab(S_data.nComp() * S_data.boxArray().numPts(), - &S_data, sunctx); + &S_data, &sunctx); }; // Utility to wrap a MultiFab as a SUNDIALS Vector @@ -384,7 +384,7 @@ private: S_data.DistributionMap(), S_data.nComp(), S_data.nGrow(), - sunctx); + &sunctx); MultiFab::Copy(*amrex::sundials::getMFptr(y_data), S_data, @@ -409,12 +409,16 @@ public: initialize_parameters(); MPI_Comm mpi_comm = ParallelContext::CommunicatorSub(); #if defined(SUNDIALS_VERSION_MAJOR) && (SUNDIALS_VERSION_MAJOR < 7) - SUNContext_Create(&mpi_comm, &sunctx); +# ifdef AMREX_USE_MPI + sunctx = ::sundials::Context(&mpi_comm); +# else + sunctx = ::sundials::Context(nullptr); +# endif #else # ifdef AMREX_USE_MPI - SUNContext_Create(mpi_comm, &sunctx); + sunctx = ::sundials::Context(mpi_comm); # else - SUNContext_Create(SUN_COMM_NULL, &sunctx); + sunctx = ::sundials::Context(SUN_COMM_NULL); # endif #endif @@ -543,7 +547,6 @@ public: MRIStepInnerStepper_Free(&fast_stepper); MRIStepFree(&arkode_fast_mem); ARKStepFree(&arkode_mem); - SUNContext_Free(&sunctx); } amrex::Real advance (T& S_old, T& S_new, amrex::Real time, const amrex::Real dt) override From 43912e8b1e81f81ff4f6cc9006f0f1d276832a81 Mon Sep 17 00:00:00 2001 From: "David J. Gardner" Date: Wed, 19 Jun 2024 10:02:10 -0700 Subject: [PATCH 36/36] move tolerances to base class --- Src/Base/AMReX_IntegratorBase.H | 35 +++++++++++++++++++ Src/Base/AMReX_TimeIntegrator.H | 15 ++++++++ .../SUNDIALS/AMReX_SundialsIntegrator.H | 10 ++---- 3 files changed, 53 insertions(+), 7 deletions(-) diff --git a/Src/Base/AMReX_IntegratorBase.H b/Src/Base/AMReX_IntegratorBase.H index 54caca41059..d9af8053d7e 100644 --- a/Src/Base/AMReX_IntegratorBase.H +++ b/Src/Base/AMReX_IntegratorBase.H @@ -253,6 +253,29 @@ protected: */ int max_steps = 500; + /** + * \brief Relative tolerance for adaptive time stepping (Real) + */ + amrex::Real rel_tol = 1.0e-4; + + /** + * \brief Absolute tolerance for adaptive time stepping (Real) + */ + amrex::Real abs_tol = 1.0e-9; + + /** + * \brief Relative tolerance for adaptive time stepping at the fast time + * scale (Real) + */ + amrex::Real fast_rel_tol = 1.0e-4; + + /** + * \brief Absolute tolerance for adaptive time stepping at the fast time + * scale (Real) + */ + amrex::Real fast_abs_tol = 1.0e-9; + + public: IntegratorBase () = default; @@ -340,6 +363,18 @@ public: max_steps = steps; } + void set_tolerances (amrex::Real rtol, amrex::Real atol) + { + rel_tol = rtol; + abs_tol = atol; + } + + void set_fast_tolerances (amrex::Real rtol, amrex::Real atol) + { + fast_rel_tol = rtol; + fast_abs_tol = atol; + } + /** * \brief Take a single time step from (time, S_old) to (time + dt, S_new) * with the given step size. diff --git a/Src/Base/AMReX_TimeIntegrator.H b/Src/Base/AMReX_TimeIntegrator.H index 1165aff96fe..10443361533 100644 --- a/Src/Base/AMReX_TimeIntegrator.H +++ b/Src/Base/AMReX_TimeIntegrator.H @@ -205,6 +205,21 @@ public: integrator_ptr->set_adaptive_fast_step(); } + void set_max_steps (int steps) + { + integrator_ptr->set_max_steps(steps); + } + + void set_tolerances (amrex::Real rtol, amrex::Real atol) + { + integrator_ptr->set_tolerances(rtol, atol); + } + + void set_fast_tolerances (amrex::Real rtol, amrex::Real atol) + { + integrator_ptr->set_fast_tolerances(rtol, atol); + } + void advance (T& S_old, T& S_new, amrex::Real time, const amrex::Real dt) { integrator_ptr->advance(S_old, S_new, time, dt); diff --git a/Src/Extern/SUNDIALS/AMReX_SundialsIntegrator.H b/Src/Extern/SUNDIALS/AMReX_SundialsIntegrator.H index 8976e5ad38a..30ff30a499b 100644 --- a/Src/Extern/SUNDIALS/AMReX_SundialsIntegrator.H +++ b/Src/Extern/SUNDIALS/AMReX_SundialsIntegrator.H @@ -116,15 +116,11 @@ private: // Single rate or slow time scale void *arkode_mem = nullptr; SUNLinearSolver LS = nullptr; - Real reltol = 1.0e-4; - Real abstol = 1.0e-9; // Fast time scale void *arkode_fast_mem = nullptr; MRIStepInnerStepper fast_stepper = nullptr; SUNLinearSolver fast_LS = nullptr; - Real fast_reltol = 1.0e-4; - Real fast_abstol = 1.0e-9; void initialize_parameters () { @@ -188,7 +184,7 @@ private: ARKStepSetUserData(arkode_mem, &udata); // Set integrator tolerances - ARKStepSStolerances(arkode_mem, reltol, abstol); + ARKStepSStolerances(arkode_mem, BaseT::rel_tol, BaseT::abs_tol); // Create and attach linear solver for implicit methods if (type == "DIRK" || type == "IMEX-RK") { @@ -230,7 +226,7 @@ private: ARKStepSetUserData(arkode_fast_mem, &udata); // Set integrator tolerances - ARKStepSStolerances(arkode_fast_mem, fast_reltol, fast_abstol); + ARKStepSStolerances(arkode_fast_mem, BaseT::fast_rel_tol, BaseT::fast_abs_tol); // Set post stage and step function ARKStepSetPostprocessStageFn(arkode_fast_mem, SundialsUserFun::post_fast_stage); @@ -267,7 +263,7 @@ private: MRIStepSetUserData(arkode_mem, &udata); // Set integrator tolerances - MRIStepSStolerances(arkode_mem, reltol, abstol); + MRIStepSStolerances(arkode_mem, BaseT::rel_tol, BaseT::abs_tol); // Create and attach linear solver if (type == "IM-MRI" || type == "IMEX-MRI") {