Skip to content

Commit

Permalink
Refactor buffers based on Marshall's suggestions
Browse files Browse the repository at this point in the history
  • Loading branch information
ashao committed Jan 29, 2025
1 parent c1ab465 commit 3ff9dcb
Show file tree
Hide file tree
Showing 2 changed files with 130 additions and 94 deletions.
196 changes: 116 additions & 80 deletions src/framework/MOM_diag_buffers.F90
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,32 @@
!! diagnostics which need to store intermediate or partial states of state variables
module MOM_diag_buffers

use iso_fortran_env, only : stdout => output_unit, stderr => error_unit
use MOM_io, only : stdout, stderr

! This file is part of MOM6. See LICENSE.md for the license.

implicit none ; private

public :: diag_buffer_unit_tests_2d, diag_buffer_unit_tests_3d

type, abstract :: buffer_base
end type buffer_base

type, extends(buffer_base) :: buffer_2d
real, dimension(:,:), allocatable :: field
end type buffer_2d

type, extends(buffer_base) :: buffer_3d
real, dimension(:,:,:), allocatable :: field
end type buffer_3d

!> The base class for the diagnostic buffers in this module
type, abstract :: diag_buffer_base ; private
integer :: is !< The start slot of the array i-direction
integer :: js !< The start slot of the array j-direction
integer :: ie !< The end slot of the array i-direction
integer :: je !< The end slot of the array j-direction
real :: fill_value = 0. !< Set the fill value to use when growing the buffer
real :: fill_value = 0. !< Set the fill value to use when growing the buffer [arbitrary]

integer, allocatable, dimension(:) :: ids !< List of diagnostic ids whose slot corresponds to the row in the buffer
integer :: length = 0 !< The number of slots in the buffer
Expand All @@ -33,8 +44,8 @@ module MOM_diag_buffers
end type diag_buffer_base

!> Dynamically growing buffer for 2D arrays.
type, extends(diag_buffer_base), public :: diag_buffer_2d
real, public, allocatable, dimension(:,:,:) :: buffer !< The actual buffer to store data [arbitrary]
type, extends(diag_buffer_base), public :: diag_buffer_2d; private
type(buffer_2d), public, dimension(:), allocatable :: buffer !< The actual 2D buffer which will dynamically grow

contains

Expand All @@ -44,9 +55,9 @@ module MOM_diag_buffers

!> Dynamically growing buffer for 3D arrays.
type, extends(diag_buffer_base), public :: diag_buffer_3d ; private
type(buffer_3d), public, dimension(:), allocatable :: buffer !< The actual 2D buffer which will dynamically grow
integer :: ks !< The start slot in the k-dimension
integer :: ke !< The last slot in the k-dimension
real, public, allocatable, dimension(:,:,:,:) :: buffer !< The actual buffer to store data [arbitrary]

contains

Expand Down Expand Up @@ -152,21 +163,29 @@ subroutine set_vertical_extent(this, ks, ke)
this%ks = ks; this%ke = ke
end subroutine set_vertical_extent

!> Grow a buffer for 2D arrays
!> Grow a 2d diagnostic buffer
subroutine grow_2d(this)
class(diag_buffer_2d), intent(inout) :: this !< This 2d buffer
class(diag_buffer_2d), intent(inout) :: this

integer :: n
real, allocatable, dimension(:,:,:) :: temp ! Temporary array to hold the contents of the buffer
! prior to allocating new memory [arbitrary]
integer :: i, n
integer :: is, ie, js, je
type(buffer_2d), dimension(:), allocatable :: new_buffer

! Grow the ID array
call this%grow_ids()

is = this%is; ie=this%ie; js=this%js; je=this%je
n = this%length
allocate(temp(n+1, this%is:this%ie, this%js:this%je), source=this%fill_value)
if (n>0) temp(1:n,:,:) = this%buffer(:,:,:)
call move_alloc(temp, this%buffer)
this%length = this%length + 1

allocate(new_buffer(n+1))
do i=1,n
allocate(new_buffer(i)%field(is:ie,js:je))
new_buffer(i)%field(:,:) = this%buffer(i)%field(:,:)
enddo
allocate(new_buffer(n+1)%field(is:ie,js:je), source=this%fill_value)
call move_alloc(new_buffer, this%buffer)
this%length = n+1

end subroutine grow_2d

!> Store a 2D array into this buffer
Expand All @@ -178,24 +197,32 @@ subroutine store_2d(this, data, id)
integer :: slot

slot = this%check_capacity_by_id(id)
this%buffer(slot,:,:) = data(:,:)
this%buffer(slot)%field(:,:) = data(:,:)
end subroutine store_2d

!> Grow a buffer for 3d arrays
!> Grow a 2d diagnostic buffer
subroutine grow_3d(this)
class(diag_buffer_3d), intent(inout) :: this !< This 3d buffer
class(diag_buffer_3d), intent(inout) :: this

integer :: n
real, allocatable, dimension(:,:,:,:) :: temp ! Temporary array to hold the contents of the buffer
! prior to allocating new memory [arbitrary]
integer :: i, n
integer :: is, ie, js, je, ks, ke
type(buffer_3d), dimension(:), allocatable :: new_buffer

! Grow the ID array
call this%grow_ids()

is = this%is; ie=this%ie; js=this%js; je=this%je; ks=this%ks; ke=this%ke
n = this%length
allocate(temp(n+1, this%is:this%ie, this%js:this%je, this%ks:this%ke), source=this%fill_value)
if (n>0) temp(1:n,:,:,:) = this%buffer(:,:,:,:)
call move_alloc(temp, this%buffer)
this%length = this%length + 1

allocate(new_buffer(n+1))
do i=1,n
allocate(new_buffer(i)%field(is:ie,js:je,ks:ke))
new_buffer(i)%field(:,:,:) = this%buffer(i)%field(:,:,:)
enddo
allocate(new_buffer(n+1)%field(is:ie,js:je,ks:ke), source=this%fill_value)
call move_alloc(new_buffer, this%buffer)
this%length = n+1

end subroutine grow_3d

!> Store a 3d array into this buffer
Expand All @@ -208,7 +235,7 @@ subroutine store_3d(this, data, id)

! Find the first slot in the ids array that is 0, i.e. this is a portion of the buffer that can be reused
slot = this%check_capacity_by_id(id)
this%buffer(slot,:,:,:) = data(:,:,:)
this%buffer(slot)%field(:,:,:) = data(:,:,:)
end subroutine store_3d

!> Unit tests for the 2d version of the diag buffer
Expand All @@ -227,55 +254,57 @@ function diag_buffer_unit_tests_2d(verbose) result(fail)

!> Ensure properties of a newly initialized buffer
function new_buffer_2d() result(local_fail)
type(diag_buffer_2d) :: buffer_2d
type(diag_buffer_2d) :: buffer
logical :: local_fail !< True if any of the unit tests fail
local_fail = .false.
local_fail = local_fail .or. allocated(buffer_2d%buffer)
local_fail = local_fail .or. allocated(buffer_2d%ids)
local_fail = local_fail .or. buffer_2d%length /= 0
local_fail = local_fail .or. allocated(buffer%buffer)
if (verbose) write(stdout,*) "new_buffer_2d: ", local_fail
local_fail = local_fail .or. allocated(buffer%ids)
if (verbose) write(stdout,*) "new_buffer_2d: ", local_fail
local_fail = local_fail .or. buffer%length /= 0
if (verbose) write(stdout,*) "new_buffer_2d: ", local_fail
end function new_buffer_2d

!> Test the growing of a buffer
function grow_buffer_2d() result(local_fail)
type(diag_buffer_2d) :: buffer_2d
type(diag_buffer_2d) :: buffer
logical :: local_fail !< True if any of the unit tests fail
integer, parameter :: is=1, ie=2, js=3, je=6
integer :: i

local_fail = .false.

buffer_2d = diag_buffer_2d(is=is, ie=ie, js=js, je=je)
call buffer%set_horizontal_extents(is=is, ie=ie, js=js, je=je)
! Grow the buffer 3 times
do i=1,3
call buffer_2d%grow()
local_fail = local_fail .or. (buffer_2d%length /= i)
local_fail = local_fail .or. (size(buffer_2d%buffer, 1) /= i)
local_fail = local_fail .or. (lbound(buffer_2d%buffer, 2) /= is)
local_fail = local_fail .or. (ubound(buffer_2d%buffer, 2) /= ie)
local_fail = local_fail .or. (lbound(buffer_2d%buffer, 3) /= js)
local_fail = local_fail .or. (ubound(buffer_2d%buffer, 3) /= je)
call buffer%grow()
local_fail = local_fail .or. (buffer%length /= i)
local_fail = local_fail .or. (lbound(buffer%buffer(i)%field, 1) /= is)
local_fail = local_fail .or. (ubound(buffer%buffer(i)%field, 1) /= ie)
local_fail = local_fail .or. (lbound(buffer%buffer(i)%field, 2) /= js)
local_fail = local_fail .or. (ubound(buffer%buffer(i)%field, 2) /= je)
enddo
if (verbose) write(stdout,*) "grow_buffer_2d: ", local_fail
end function grow_buffer_2d

!> Test storing a buffer based on a unique id
function store_buffer_2d() result(local_fail)
type(diag_buffer_2d) :: buffer_2d
type(diag_buffer_2d) :: buffer
logical :: local_fail !< True if any of the unit tests fail

integer, parameter :: is=1, ie=2, js=3, je=6, nlen=3
integer :: i
integer :: i, slot
real, allocatable, dimension(:,:,:) :: test_2d

allocate(test_2d(nlen, is:ie, js:je))
call random_number(test_2d)
buffer_2d = diag_buffer_2d(is=is, ie=ie, js=js, je=je)
buffer = diag_buffer_2d(is=is, ie=ie, js=js, je=je)

do i=1,nlen
call buffer_2d%store(test_2d(i,:,:), i*3)
call buffer%store(test_2d(i,:,:), i*3)
slot = buffer%find_buffer_slot(i*3)
local_fail = local_fail .or. ANY(buffer%buffer(slot)%field(:,:) /= test_2d(i,:,:))
enddo
local_fail = ANY(buffer_2d%buffer /= test_2d)

if (verbose) write(stdout,*) "store_buffer_2d: ", local_fail
end function store_buffer_2d
Expand All @@ -284,7 +313,7 @@ end function store_buffer_2d
!! loop through again, but use the slots of the buffer in the following
!! order: 2, 1, 3
function reuse_buffer_2d() result(local_fail)
type(diag_buffer_2d) :: buffer_2d
type(diag_buffer_2d) :: buffer
logical :: local_fail !< True if any of the unit tests fail

integer, parameter :: is=1, ie=2, js=3, je=6, nlen=3
Expand All @@ -296,24 +325,26 @@ function reuse_buffer_2d() result(local_fail)
call random_number(test_2d_first)
call random_number(test_2d_second)

buffer_2d = diag_buffer_2d(is=is, ie=ie, js=js, je=je)
call buffer%set_horizontal_extents(is=is, ie=ie, js=js, je=je)

do i=1,nlen
call buffer_2d%store(test_2d_first(i,:,:), id=i*3)
call buffer%store(test_2d_first(i,:,:), id=i*3)
enddo

do i=1,nlen
new_i = reorder(i)
! id and new_id are multiplied by primes to make sure they are unique
id = reorder(i)*3
new_id = i*7
call buffer_2d%mark_available(id=reorder(i)*3)
call buffer_2d%store(test_2d_second(i,:,:), id=new_id)
local_fail = local_fail .or. buffer_2d%find_buffer_slot(new_id) /= new_i
call buffer%mark_available(id=reorder(i)*3)
call buffer%store(test_2d_second(i,:,:), id=new_id)
local_fail = local_fail .or. buffer%find_buffer_slot(new_id) /= new_i
test_2d_first(new_i,:,:) = test_2d_second(i,:,:)
enddo
local_fail = local_fail .or. any(buffer_2d%ids /= [14, 7, 21])
local_fail = local_fail .or. any(buffer_2d%buffer /= test_2d_first)
local_fail = local_fail .or. any(buffer%ids /= [14, 7, 21])
do i=1,nlen
local_fail = local_fail .or. any(buffer%buffer(i)%field(:,:) /= test_2d_first(i,:,:))
enddo
if (verbose) write(stdout,*) "reuse_buffer_2d: ", local_fail
end function reuse_buffer_2d

Expand All @@ -335,56 +366,59 @@ function diag_buffer_unit_tests_3d(verbose) result(fail)

!> Ensure properties of a newly initialized buffer
function new_buffer_3d() result(local_fail)
type(diag_buffer_3d) :: buffer_3d
type(diag_buffer_3d) :: buffer
logical :: local_fail !< True if any of the unit tests fail
local_fail = .false.
local_fail = local_fail .or. allocated(buffer_3d%buffer)
local_fail = local_fail .or. allocated(buffer_3d%ids)
local_fail = local_fail .or. buffer_3d%length /= 0
local_fail = local_fail .or. allocated(buffer%buffer)
local_fail = local_fail .or. allocated(buffer%ids)
local_fail = local_fail .or. buffer%length /= 0
if (verbose) write(stdout,*) "new_buffer_3d: ", local_fail
end function new_buffer_3d

!> Test the growing of a buffer
function grow_buffer_3d() result(local_fail)
type(diag_buffer_3d) :: buffer_3d
type(diag_buffer_3d) :: buffer
logical :: local_fail !< True if any of the unit tests fail
integer, parameter :: is=1, ie=2, js=3, je=6, ks=1, ke=10
integer :: i

local_fail = .false.

buffer_3d = diag_buffer_3d(is=is, ie=ie, js=js, je=je, ks=ks, ke=ke)
call buffer%set_horizontal_extents(is=is, ie=ie, js=js, je=je)
call buffer%set_vertical_extent(ks=ks, ke=ke)
! Grow the buffer 3 times
do i=1,3
call buffer_3d%grow()
local_fail = local_fail .or. (buffer_3d%length /= i)
local_fail = local_fail .or. (size(buffer_3d%buffer, 1) /= i)
local_fail = local_fail .or. (lbound(buffer_3d%buffer, 2) /= is)
local_fail = local_fail .or. (ubound(buffer_3d%buffer, 2) /= ie)
local_fail = local_fail .or. (lbound(buffer_3d%buffer, 3) /= js)
local_fail = local_fail .or. (ubound(buffer_3d%buffer, 3) /= je)
local_fail = local_fail .or. (lbound(buffer_3d%buffer, 4) /= ks)
local_fail = local_fail .or. (ubound(buffer_3d%buffer, 4) /= ke)
call buffer%grow()
local_fail = local_fail .or. (buffer%length /= i)
local_fail = local_fail .or. (lbound(buffer%buffer(i)%field, 1) /= is)
local_fail = local_fail .or. (ubound(buffer%buffer(i)%field, 1) /= ie)
local_fail = local_fail .or. (lbound(buffer%buffer(i)%field, 2) /= js)
local_fail = local_fail .or. (ubound(buffer%buffer(i)%field, 2) /= je)
local_fail = local_fail .or. (lbound(buffer%buffer(i)%field, 3) /= ks)
local_fail = local_fail .or. (ubound(buffer%buffer(i)%field, 3) /= ke)
if (verbose) write(stdout,*) "grow_buffer_3d: ", local_fail
enddo
if (verbose) write(stdout,*) "grow_buffer_3d: ", local_fail
end function grow_buffer_3d

!> Test storing a buffer based on a unique id
function store_buffer_3d() result(local_fail)
type(diag_buffer_3d) :: buffer_3d
type(diag_buffer_3d) :: buffer
logical :: local_fail !< True if any of the unit tests fail

integer, parameter :: is=1, ie=2, js=3, je=6, ks=1, ke=10, nlen=3
integer :: i
real, dimension(nlen, is:ie, js:je, ks:ke) :: test_3d
integer :: i, slot
real, dimension(nlen,is:ie,js:je,ks:ke) :: test_3d

local_fail = .false.
call random_number(test_3d)
buffer_3d = diag_buffer_3d(is=is, ie=ie, js=js, je=je, ks=1, ke=10)
buffer = diag_buffer_3d(is=is, ie=ie, js=js, je=je, ks=1, ke=10)

do i=1,nlen
call buffer_3d%store(test_3d(i,:,:,:), i*3)
call buffer%store(test_3d(i,:,:,:), i*3)
slot = buffer%find_buffer_slot(i*3)
local_fail = local_fail .or. ANY(buffer%buffer(slot)%field(:,:,:) /= test_3d(i,:,:,:))
enddo
local_fail = ANY(buffer_3d%buffer /= test_3d)

if (verbose) write(stdout,*) "store_buffer_3d: ", local_fail
end function store_buffer_3d
Expand All @@ -393,7 +427,7 @@ end function store_buffer_3d
!! loop through again, but use the slots of the buffer in the following
!! order: 2, 1, 3
function reuse_buffer_3d() result(local_fail)
type(diag_buffer_3d) :: buffer_3d
type(diag_buffer_3d) :: buffer
logical :: local_fail !< True if any of the unit tests fail

integer, parameter :: is=1, ie=2, js=3, je=6, ks=1, ke=10, nlen=3
Expand All @@ -405,24 +439,26 @@ function reuse_buffer_3d() result(local_fail)
call random_number(test_3d_first)
call random_number(test_3d_second)

buffer_3d = diag_buffer_3d(is=is, ie=ie, js=js, je=je, ks=ks, ke=ke)
buffer = diag_buffer_3d(is=is, ie=ie, js=js, je=je, ks=ks, ke=ke)

do i=1,nlen
call buffer_3d%store(test_3d_first(i,:,:,:), id=i*3)
call buffer%store(test_3d_first(i,:,:,:), id=i*3)
enddo

do i=1,nlen
new_i = reorder(i)
! id and new_id are multiplied by primes to make sure they are unique
id = reorder(i)*3
new_id = i*7
call buffer_3d%mark_available(id=reorder(i)*3)
call buffer_3d%store(test_3d_second(i,:,:,:), id=new_id)
local_fail = local_fail .or. buffer_3d%find_buffer_slot(new_id) /= new_i
call buffer%mark_available(id=reorder(i)*3)
call buffer%store(test_3d_second(i,:,:,:), id=new_id)
local_fail = local_fail .or. buffer%find_buffer_slot(new_id) /= new_i
test_3d_first(new_i,:,:,:) = test_3d_second(i,:,:,:)
enddo
local_fail = local_fail .or. any(buffer_3d%ids /= [14, 7, 21])
local_fail = local_fail .or. any(buffer_3d%buffer /= test_3d_first)
local_fail = local_fail .or. any(buffer%ids /= [14, 7, 21])
do i=1,nlen
local_fail = local_fail .or. any(buffer%buffer(i)%field(:,:,:) /= test_3d_first(i,:,:,:))
enddo
if (verbose) write(stdout,*) "reuse_buffer_3d: ", local_fail
end function reuse_buffer_3d

Expand Down
Loading

0 comments on commit 3ff9dcb

Please sign in to comment.