From 6439894affa437c59153324b38999c7f86720c19 Mon Sep 17 00:00:00 2001 From: Andrew Shao Date: Sun, 19 May 2024 09:41:19 -0400 Subject: [PATCH] Additional adjustments for stellar --- .../lateral/MOM_CNN_GZ21.F90 | 108 +++++++++--------- .../lateral/MOM_hor_visc.F90 | 2 +- 2 files changed, 58 insertions(+), 52 deletions(-) diff --git a/src/parameterizations/lateral/MOM_CNN_GZ21.F90 b/src/parameterizations/lateral/MOM_CNN_GZ21.F90 index b1cbeb8263..c45d30b36e 100644 --- a/src/parameterizations/lateral/MOM_CNN_GZ21.F90 +++ b/src/parameterizations/lateral/MOM_CNN_GZ21.F90 @@ -64,7 +64,7 @@ module MOM_CNN_GZ21 integer :: jedw !< The upper j-memory limit for the wide halo arrays. integer :: CNN_halo_size !< Halo size at each side of subdomains - character(len=200) :: CNN_VS !< default = "none". Vertical profile of CNN momentum forcing + character(len=200) :: CNN_VS !< default = "none". Vertical profile of CNN momentum forcing type(diag_ctrl), pointer :: diag => NULL() !< A type that regulates diagnostics output !>@{ Diagnostic handles @@ -120,7 +120,7 @@ subroutine CNN_init(Time,G,GV,US,param_file,diag, dbcomms_CS, CS) CS%id_Systd = register_diag_field('ocean_model', 'Systd', diag%axesTL, Time, & 'Meridional Acceleration from CNN model standard deviation part', & 'm s-2', conversion=US%L_T2_to_m_s2) - + call get_param(param_file, mdl, "CNN_VS", CS%CNN_VS, & "Vertical profile of momentum forcing from CNN :\n" // & " 'none': infer CNN at each layer\n"// & @@ -150,11 +150,10 @@ subroutine CNN_init(Time,G,GV,US,param_file,diag, dbcomms_CS, CS) end subroutine CNN_init !> Manage input and output of CNN model -subroutine CNN_inference(u, v, h, diffu, diffv, G, GV, VarMix, FP_CS, SS_CS, CNN, python_bridge_lib,python_data_collect) +subroutine CNN_inference(u, v, h, diffu, diffv, G, GV, VarMix, SS_CS, CNN, python_bridge_lib,python_data_collect) type(ocean_grid_type), intent(in) :: G !< The ocean's grid structure. type(verticalGrid_type), intent(in) :: GV !< The ocean's vertical grid structure. type(VarMix_CS), intent(in) :: VarMix !< Variable mixing control struct - type(python_interface), intent(in) :: FP_CS !< Forpy Python interface object type(smartsim_python_interface),intent(in) :: SS_CS !< SmartSim Python interface object type(CNN_CS), intent(in) :: CNN !< Control structure for CNN character(len=*), intent(in) :: python_bridge_lib !< The library used for language bridging @@ -257,24 +256,24 @@ subroutine CNN_inference(u, v, h, diffu, diffv, G, GV, VarMix, FP_CS, SS_CS, CNN enddo ; enddo ! do k=1,nztemp - ! do j=js,je ; do i=is,ie + ! do j=js,je ; do i=is,ie ! WH_u(i,j,k) = 0.01*(PE_here()) ! WH_v(i,j,k) = 0.01*(15.-PE_here()) - ! enddo ; enddo + ! enddo ; enddo ! enddo ! if (is_root_pe()) then ! open(10,file='WH_u0') - ! do j=1,size(WH_u,2) + ! do j=1,size(WH_u,2) ! write(10,100) (WH_u(i,j,1),i=1,size(WH_u,1)) - ! enddo + ! enddo ! close(10) ! endif ! Update the wide halos of WH_u WH_v WH_m - call create_group_pass(pass_uvm,WH_u,CNN%CNN_Domain) - call create_group_pass(pass_uvm,WH_v,CNN%CNN_Domain) - call create_group_pass(pass_uvm,WH_m,CNN%CNN_Domain) - call do_group_pass(pass_uvm,CNN%CNN_Domain) + call create_group_pass(pass_uvm, WH_u, CNN%CNN_Domain) + call create_group_pass(pass_uvm, WH_v, CNN%CNN_Domain) + call create_group_pass(pass_uvm, WH_m, CNN%CNN_Domain) + call do_group_pass(pass_uvm, CNN%CNN_Domain) ! Combine arrays for CNN input if (python_data_collect) then @@ -326,44 +325,44 @@ subroutine CNN_inference(u, v, h, diffu, diffv, G, GV, VarMix, FP_CS, SS_CS, CNN ! write(*,*) "index_global=",index_global ! open(10,file='g3d_u') - ! do j=1,size(g3d_u,2) + ! do j=1,size(g3d_u,2) ! write(10,100) (g3d_u(i,j,1),i=1,size(g3d_u,1)) - ! enddo + ! enddo ! close(10) - + ! open(10,file='g3d_v') - ! do j=1,size(g3d_v,2) + ! do j=1,size(g3d_v,2) ! write(10,100) (g3d_v(i,j,1),i=1,size(g3d_v,1)) - ! enddo + ! enddo ! close(10) ! open(10,file='g3d_m') - ! do j=1,size(g3d_m,2) + ! do j=1,size(g3d_m,2) ! write(10,100) (g3d_m(i,j),i=1,size(g3d_m,1)) - ! enddo + ! enddo ! close(10) ! open(10,file='WH_uv_glob') - ! do j=1,size(WH_uv_glob,3) + ! do j=1,size(WH_uv_glob,3) ! write(10,100) (WH_uv_glob(1,i,j,1),i=1,size(WH_uv_glob,2)) - ! enddo + ! enddo ! close(10) ! open(10,file='WH_m_glob') - ! do j=1,size(WH_m_glob,2) + ! do j=1,size(WH_m_glob,2) ! write(10,100) (WH_m_glob(i,j),i=1,size(WH_m_glob,1)) - ! enddo + ! enddo ! close(10) ! endif - else + else ! not collect data to root PE WH_uv = 0.0 do k=1,nztemp - do j=jsdw,jedw ; do i=isdw,iedw + do j=jsdw,jedw ; do i=isdw,iedw WH_uv(1,i,j,k) = WH_u(i,j,k) WH_uv(2,i,j,k) = WH_v(i,j,k) - enddo ; enddo + enddo ; enddo enddo index_global(1) = G%isc+G%idg_offset index_global(2) = G%iec+G%idg_offset @@ -372,21 +371,21 @@ subroutine CNN_inference(u, v, h, diffu, diffv, G, GV, VarMix, FP_CS, SS_CS, CNN ! if (is_root_pe()) then ! open(10,file='WH_u') - ! do j=1,size(WH_u,2) + ! do j=1,size(WH_u,2) ! write(10,100) (WH_u(i,j,1),i=1,size(WH_u,1)) - ! enddo + ! enddo ! close(10) ! open(10,file='WH_uv') - ! do j=1,size(WH_uv,3) + ! do j=1,size(WH_uv,3) ! write(10,100) (WH_uv(1,i,j,1),i=1,size(WH_uv,2)) - ! enddo + ! enddo ! close(10) ! open(10,file='WH_m') - ! do j=1,size(WH_m,2) + ! do j=1,size(WH_m,2) ! write(10,100) (WH_m(i,j),i=1,size(WH_m,1)) - ! enddo + ! enddo ! close(10) ! endif endif @@ -397,27 +396,34 @@ subroutine CNN_inference(u, v, h, diffu, diffv, G, GV, VarMix, FP_CS, SS_CS, CNN call cpu_clock_begin(CNN%id_cnn_inference) select case (lowercase(python_bridge_lib)) case("forpy") + ! if (python_data_collect) then + ! if (is_root_pe()) call forpy_run_python(WH_uv_glob, Sxy_glob, WH_m_glob, index_global, FP_CS) + ! else + ! call forpy_run_python(WH_uv, Sxy, WH_m, index_global, FP_CS) + ! endif + case("smartsim") if (python_data_collect) then - ! if (is_root_pe()) call forpy_run_python(WH_uv_glob, Sxy_glob, WH_m_glob, index_global, FP_CS) - call sync_PEs() - do l=1,size(Sxy,1) - call mpp_scatter(isg, ieg, jsg, jeg, nztemp, pelist, & - Sxy(l,is:ie,js:je,:), Sxy_glob(l,:,:,:), is_root_pe()) - enddo - if (is_root_pe()) then - deallocate(g3d_u) - deallocate(g3d_v) - deallocate(g3d_m) - deallocate(WH_uv_glob) - deallocate(WH_m_glob) - deallocate(Sxy_glob) - endif + if (is_root_pe()) call smartsim_run_python(WH_uv_glob, Sxy_glob, nztemp, SS_CS, CNN%CNN_halo_size) else - call forpy_run_python(WH_uv, Sxy, WH_m, index_global, FP_CS) + call smartsim_run_python(WH_uv, Sxy, nztemp, SS_CS, CNN%CNN_halo_size) endif - case("smartsim") - call smartsim_run_python(WH_uv, Sxy, SS_CS, CNN%CNN_halo_size) end select + + if (python_data_collect) then + do l=1,size(Sxy,1) + call mpp_scatter(isg, ieg, jsg, jeg, nztemp, pelist, & + Sxy(l,is:ie,js:je,:), Sxy_glob(l,:,:,:), is_root_pe()) + enddo + if (is_root_pe()) then + deallocate(g3d_u) + deallocate(g3d_v) + deallocate(g3d_m) + deallocate(WH_uv_glob) + deallocate(WH_m_glob) + deallocate(Sxy_glob) + endif + endif + call cpu_clock_end(CNN%id_cnn_inference) !Extract data from CNN output @@ -425,7 +431,7 @@ subroutine CNN_inference(u, v, h, diffu, diffv, G, GV, VarMix, FP_CS, SS_CS, CNN call cpu_clock_begin(CNN%id_cnn_post1) Sx=0.0; Sy=0.0; Sxmean=0.0; Symean=0.0; Sxstd=0.0; Systd=0.0; do k=1,nz - do j=js,je ; do i=is,ie + do j=js,je ; do i=is,ie if (CNN%CNN_VS /= 'none') then Sx(i,j,k) = Sxy(1,i,j,1)*vs(i,j,k) Sy(i,j,k) = Sxy(2,i,j,1)*vs(i,j,k) @@ -441,7 +447,7 @@ subroutine CNN_inference(u, v, h, diffu, diffv, G, GV, VarMix, FP_CS, SS_CS, CNN Sxstd(i,j,k) = Sxy(5,i,j,k) Systd(i,j,k) = Sxy(6,i,j,k) endif - enddo ; enddo + enddo ; enddo enddo ! Delocate variables diff --git a/src/parameterizations/lateral/MOM_hor_visc.F90 b/src/parameterizations/lateral/MOM_hor_visc.F90 index e1e32f90bc..25d15b984b 100644 --- a/src/parameterizations/lateral/MOM_hor_visc.F90 +++ b/src/parameterizations/lateral/MOM_hor_visc.F90 @@ -1688,7 +1688,7 @@ subroutine horizontal_viscosity(u, v, h, diffu, diffv, MEKE, VarMix, G, GV, US, if (CS%id_diffv_visc_rem > 0) call post_product_v(CS%id_diffv_visc_rem, diffv, ADp%visc_rem_v, G, nz, CS%diag) endif - if (CS%use_hor_visc_python) call CNN_inference(u, v, h, diffu, diffv, G, GV, VarMix, CS%python, CS%smartsim_python, & + if (CS%use_hor_visc_python) call CNN_inference(u, v, h, diffu, diffv, G, GV, VarMix, CS%smartsim_python, & CS%CNN, CS%python_bridge_lib, CS%python_data_collect) !Cheng end subroutine horizontal_viscosity