Skip to content

Commit

Permalink
Additional adjustments for stellar
Browse files Browse the repository at this point in the history
  • Loading branch information
ashao committed May 19, 2024
1 parent 26f74a0 commit 6439894
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 52 deletions.
108 changes: 57 additions & 51 deletions src/parameterizations/lateral/MOM_CNN_GZ21.F90
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ module MOM_CNN_GZ21
integer :: jedw !< The upper j-memory limit for the wide halo arrays.
integer :: CNN_halo_size !< Halo size at each side of subdomains

character(len=200) :: CNN_VS !< default = "none". Vertical profile of CNN momentum forcing
character(len=200) :: CNN_VS !< default = "none". Vertical profile of CNN momentum forcing

type(diag_ctrl), pointer :: diag => NULL() !< A type that regulates diagnostics output
!>@{ Diagnostic handles
Expand Down Expand Up @@ -120,7 +120,7 @@ subroutine CNN_init(Time,G,GV,US,param_file,diag, dbcomms_CS, CS)
CS%id_Systd = register_diag_field('ocean_model', 'Systd', diag%axesTL, Time, &
'Meridional Acceleration from CNN model standard deviation part', &
'm s-2', conversion=US%L_T2_to_m_s2)

call get_param(param_file, mdl, "CNN_VS", CS%CNN_VS, &
"Vertical profile of momentum forcing from CNN :\n" // &
" 'none': infer CNN at each layer\n"// &
Expand Down Expand Up @@ -150,11 +150,10 @@ subroutine CNN_init(Time,G,GV,US,param_file,diag, dbcomms_CS, CS)
end subroutine CNN_init

!> Manage input and output of CNN model
subroutine CNN_inference(u, v, h, diffu, diffv, G, GV, VarMix, FP_CS, SS_CS, CNN, python_bridge_lib,python_data_collect)
subroutine CNN_inference(u, v, h, diffu, diffv, G, GV, VarMix, SS_CS, CNN, python_bridge_lib,python_data_collect)
type(ocean_grid_type), intent(in) :: G !< The ocean's grid structure.
type(verticalGrid_type), intent(in) :: GV !< The ocean's vertical grid structure.
type(VarMix_CS), intent(in) :: VarMix !< Variable mixing control struct
type(python_interface), intent(in) :: FP_CS !< Forpy Python interface object
type(smartsim_python_interface),intent(in) :: SS_CS !< SmartSim Python interface object
type(CNN_CS), intent(in) :: CNN !< Control structure for CNN
character(len=*), intent(in) :: python_bridge_lib !< The library used for language bridging
Expand Down Expand Up @@ -257,24 +256,24 @@ subroutine CNN_inference(u, v, h, diffu, diffv, G, GV, VarMix, FP_CS, SS_CS, CNN
enddo ; enddo

! do k=1,nztemp
! do j=js,je ; do i=is,ie
! do j=js,je ; do i=is,ie
! WH_u(i,j,k) = 0.01*(PE_here())
! WH_v(i,j,k) = 0.01*(15.-PE_here())
! enddo ; enddo
! enddo ; enddo
! enddo
! if (is_root_pe()) then
! open(10,file='WH_u0')
! do j=1,size(WH_u,2)
! do j=1,size(WH_u,2)
! write(10,100) (WH_u(i,j,1),i=1,size(WH_u,1))
! enddo
! enddo
! close(10)
! endif

! Update the wide halos of WH_u WH_v WH_m
call create_group_pass(pass_uvm,WH_u,CNN%CNN_Domain)
call create_group_pass(pass_uvm,WH_v,CNN%CNN_Domain)
call create_group_pass(pass_uvm,WH_m,CNN%CNN_Domain)
call do_group_pass(pass_uvm,CNN%CNN_Domain)
call create_group_pass(pass_uvm, WH_u, CNN%CNN_Domain)
call create_group_pass(pass_uvm, WH_v, CNN%CNN_Domain)
call create_group_pass(pass_uvm, WH_m, CNN%CNN_Domain)
call do_group_pass(pass_uvm, CNN%CNN_Domain)

! Combine arrays for CNN input
if (python_data_collect) then
Expand Down Expand Up @@ -326,44 +325,44 @@ subroutine CNN_inference(u, v, h, diffu, diffv, G, GV, VarMix, FP_CS, SS_CS, CNN
! write(*,*) "index_global=",index_global

! open(10,file='g3d_u')
! do j=1,size(g3d_u,2)
! do j=1,size(g3d_u,2)
! write(10,100) (g3d_u(i,j,1),i=1,size(g3d_u,1))
! enddo
! enddo
! close(10)

! open(10,file='g3d_v')
! do j=1,size(g3d_v,2)
! do j=1,size(g3d_v,2)
! write(10,100) (g3d_v(i,j,1),i=1,size(g3d_v,1))
! enddo
! enddo
! close(10)

! open(10,file='g3d_m')
! do j=1,size(g3d_m,2)
! do j=1,size(g3d_m,2)
! write(10,100) (g3d_m(i,j),i=1,size(g3d_m,1))
! enddo
! enddo
! close(10)

! open(10,file='WH_uv_glob')
! do j=1,size(WH_uv_glob,3)
! do j=1,size(WH_uv_glob,3)
! write(10,100) (WH_uv_glob(1,i,j,1),i=1,size(WH_uv_glob,2))
! enddo
! enddo
! close(10)

! open(10,file='WH_m_glob')
! do j=1,size(WH_m_glob,2)
! do j=1,size(WH_m_glob,2)
! write(10,100) (WH_m_glob(i,j),i=1,size(WH_m_glob,1))
! enddo
! enddo
! close(10)
! endif

else
else
! not collect data to root PE
WH_uv = 0.0
do k=1,nztemp
do j=jsdw,jedw ; do i=isdw,iedw
do j=jsdw,jedw ; do i=isdw,iedw
WH_uv(1,i,j,k) = WH_u(i,j,k)
WH_uv(2,i,j,k) = WH_v(i,j,k)
enddo ; enddo
enddo ; enddo
enddo
index_global(1) = G%isc+G%idg_offset
index_global(2) = G%iec+G%idg_offset
Expand All @@ -372,21 +371,21 @@ subroutine CNN_inference(u, v, h, diffu, diffv, G, GV, VarMix, FP_CS, SS_CS, CNN

! if (is_root_pe()) then
! open(10,file='WH_u')
! do j=1,size(WH_u,2)
! do j=1,size(WH_u,2)
! write(10,100) (WH_u(i,j,1),i=1,size(WH_u,1))
! enddo
! enddo
! close(10)

! open(10,file='WH_uv')
! do j=1,size(WH_uv,3)
! do j=1,size(WH_uv,3)
! write(10,100) (WH_uv(1,i,j,1),i=1,size(WH_uv,2))
! enddo
! enddo
! close(10)

! open(10,file='WH_m')
! do j=1,size(WH_m,2)
! do j=1,size(WH_m,2)
! write(10,100) (WH_m(i,j),i=1,size(WH_m,1))
! enddo
! enddo
! close(10)
! endif
endif
Expand All @@ -397,35 +396,42 @@ subroutine CNN_inference(u, v, h, diffu, diffv, G, GV, VarMix, FP_CS, SS_CS, CNN
call cpu_clock_begin(CNN%id_cnn_inference)
select case (lowercase(python_bridge_lib))
case("forpy")
! if (python_data_collect) then
! if (is_root_pe()) call forpy_run_python(WH_uv_glob, Sxy_glob, WH_m_glob, index_global, FP_CS)
! else
! call forpy_run_python(WH_uv, Sxy, WH_m, index_global, FP_CS)
! endif
case("smartsim")
if (python_data_collect) then
! if (is_root_pe()) call forpy_run_python(WH_uv_glob, Sxy_glob, WH_m_glob, index_global, FP_CS)
call sync_PEs()
do l=1,size(Sxy,1)
call mpp_scatter(isg, ieg, jsg, jeg, nztemp, pelist, &
Sxy(l,is:ie,js:je,:), Sxy_glob(l,:,:,:), is_root_pe())
enddo
if (is_root_pe()) then
deallocate(g3d_u)
deallocate(g3d_v)
deallocate(g3d_m)
deallocate(WH_uv_glob)
deallocate(WH_m_glob)
deallocate(Sxy_glob)
endif
if (is_root_pe()) call smartsim_run_python(WH_uv_glob, Sxy_glob, nztemp, SS_CS, CNN%CNN_halo_size)
else
call forpy_run_python(WH_uv, Sxy, WH_m, index_global, FP_CS)
call smartsim_run_python(WH_uv, Sxy, nztemp, SS_CS, CNN%CNN_halo_size)
endif
case("smartsim")
call smartsim_run_python(WH_uv, Sxy, SS_CS, CNN%CNN_halo_size)
end select

if (python_data_collect) then
do l=1,size(Sxy,1)
call mpp_scatter(isg, ieg, jsg, jeg, nztemp, pelist, &
Sxy(l,is:ie,js:je,:), Sxy_glob(l,:,:,:), is_root_pe())
enddo
if (is_root_pe()) then
deallocate(g3d_u)
deallocate(g3d_v)
deallocate(g3d_m)
deallocate(WH_uv_glob)
deallocate(WH_m_glob)
deallocate(Sxy_glob)
endif
endif

call cpu_clock_end(CNN%id_cnn_inference)

!Extract data from CNN output
call cpu_clock_begin(CNN%id_cnn_post)
call cpu_clock_begin(CNN%id_cnn_post1)
Sx=0.0; Sy=0.0; Sxmean=0.0; Symean=0.0; Sxstd=0.0; Systd=0.0;
do k=1,nz
do j=js,je ; do i=is,ie
do j=js,je ; do i=is,ie
if (CNN%CNN_VS /= 'none') then
Sx(i,j,k) = Sxy(1,i,j,1)*vs(i,j,k)
Sy(i,j,k) = Sxy(2,i,j,1)*vs(i,j,k)
Expand All @@ -441,7 +447,7 @@ subroutine CNN_inference(u, v, h, diffu, diffv, G, GV, VarMix, FP_CS, SS_CS, CNN
Sxstd(i,j,k) = Sxy(5,i,j,k)
Systd(i,j,k) = Sxy(6,i,j,k)
endif
enddo ; enddo
enddo ; enddo
enddo

! Delocate variables
Expand Down
2 changes: 1 addition & 1 deletion src/parameterizations/lateral/MOM_hor_visc.F90
Original file line number Diff line number Diff line change
Expand Up @@ -1688,7 +1688,7 @@ subroutine horizontal_viscosity(u, v, h, diffu, diffv, MEKE, VarMix, G, GV, US,
if (CS%id_diffv_visc_rem > 0) call post_product_v(CS%id_diffv_visc_rem, diffv, ADp%visc_rem_v, G, nz, CS%diag)
endif

if (CS%use_hor_visc_python) call CNN_inference(u, v, h, diffu, diffv, G, GV, VarMix, CS%python, CS%smartsim_python, &
if (CS%use_hor_visc_python) call CNN_inference(u, v, h, diffu, diffv, G, GV, VarMix, CS%smartsim_python, &
CS%CNN, CS%python_bridge_lib, CS%python_data_collect) !Cheng

end subroutine horizontal_viscosity
Expand Down

0 comments on commit 6439894

Please sign in to comment.