Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[IO] Adding functions to estimate max byte size of all fields in a pool #1217

Draft
wants to merge 19 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 22 additions & 22 deletions src/framework/mpas_io_streams.F
Original file line number Diff line number Diff line change
Expand Up @@ -3335,14 +3335,14 @@ subroutine MPAS_writeStream(stream, frame, ierr)

if (field_cursor % field_type == FIELD_0D_INT) then

!call mpas_log_write('Writing out field '//trim(field_cursor % int0dField % fieldName))
!call mpas_log_write(' > is the field decomposed? $l', logicArgs=(/field_cursor % isDecomposed/))
!call mpas_log_write(' > outer dimension size $i', intArgs=(/field_cursor % totalDimSize/))
! call mpas_log_write('Writing out field '//trim(field_cursor % int0dField % fieldName))
! call mpas_log_write(' > is the field decomposed? $l', logicArgs=(/field_cursor % isDecomposed/))
! call mpas_log_write(' > outer dimension size $i', intArgs=(/field_cursor % totalDimSize/))

!call mpas_log_write('Copying field from first block')
!call mpas_log_write('Copying field from first block')
int0d_temp = field_cursor % int0dField % scalar

!call mpas_log_write('MGD calling MPAS_io_put_var now...')
!call mpas_log_write('MGD calling MPAS_io_put_var now...')
call MPAS_io_put_var(stream % fileHandle, field_cursor % int0dField % fieldName, int0d_temp, io_err)
call MPAS_io_err_mesg(stream % fileHandle % ioContext, io_err, .false.)
if (io_err /= MPAS_IO_NOERR .and. present(ierr)) ierr = MPAS_IO_ERR
Expand Down Expand Up @@ -3376,7 +3376,7 @@ subroutine MPAS_writeStream(stream, frame, ierr)
end if

if (field_cursor % int1dField % isVarArray) then
! I suspect we will never hit this code, as it doesn't make sense, really
! I suspect we will never hit this code, as it doesn't make sense, really
int0d_temp = field_1dint_ptr % array(j)
else
int1d_temp(i:i+ownedSize-1) = field_1dint_ptr % array(1:ownedSize)
Expand Down Expand Up @@ -3543,14 +3543,14 @@ subroutine MPAS_writeStream(stream, frame, ierr)

else if (field_cursor % field_type == FIELD_0D_REAL) then

!call mpas_log_write('Writing out field '//trim(field_cursor % real0dField % fieldName))
!call mpas_log_write(' > is the field decomposed? $l', logicArgs=(/field_cursor % isDecomposed/))
!call mpas_log_write(' > outer dimension size $i', intArgs=(/field_cursor % totalDimSize/))
!call mpas_log_write('Writing out field '//trim(field_cursor % real0dField % fieldName))
!call mpas_log_write(' > is the field decomposed? $l', logicArgs=(/field_cursor % isDecomposed/))
!call mpas_log_write(' > outer dimension size $i', intArgs=(/field_cursor % totalDimSize/))

!call mpas_log_write('Copying field from first block')
!call mpas_log_write('Copying field from first block')
real0d_temp = field_cursor % real0dField % scalar

!call mpas_log_write('MGD calling MPAS_io_put_var now...')
!call mpas_log_write('MGD calling MPAS_io_put_var now...')
call MPAS_io_put_var(stream % fileHandle, field_cursor % real0dField % fieldName, real0d_temp, io_err)
call MPAS_io_err_mesg(stream % fileHandle % ioContext, io_err, .false.)
if (io_err /= MPAS_IO_NOERR .and. present(ierr)) ierr = MPAS_IO_ERR
Expand Down Expand Up @@ -3584,7 +3584,7 @@ subroutine MPAS_writeStream(stream, frame, ierr)
end if

if (field_cursor % real1dField % isVarArray) then
! I suspect we will never hit this code, as it doesn't make sense, really
! I suspect we will never hit this code, as it doesn't make sense, really
real0d_temp = field_1dreal_ptr % array(j)
else
real1d_temp(i:i+ownedSize-1) = field_1dreal_ptr % array(1:ownedSize)
Expand Down Expand Up @@ -3891,24 +3891,24 @@ subroutine MPAS_writeStream(stream, frame, ierr)

else if (field_cursor % field_type == FIELD_0D_CHAR) then

!call mpas_log_write('Writing out field '//trim(field_cursor % char0dField % fieldName))
!call mpas_log_write(' > is the field decomposed? $l', logicArgs=(/field_cursor % isDecomposed/))
!call mpas_log_write(' > outer dimension size $i', intArgs=(/field_cursor % totalDimSize/))
!call mpas_log_write('Writing out field '//trim(field_cursor % char0dField % fieldName))
!call mpas_log_write(' > is the field decomposed? $l', logicArgs=(/field_cursor % isDecomposed/))
!call mpas_log_write(' > outer dimension size $i', intArgs=(/field_cursor % totalDimSize/))

!call mpas_log_write('Copying field from first block')
!call mpas_log_write('MGD calling MPAS_io_put_var now...')
!call mpas_log_write('Copying field from first block')
!call mpas_log_write('MGD calling MPAS_io_put_var now...')
call MPAS_io_put_var(stream % fileHandle, field_cursor % char0dField % fieldName, field_cursor % char0dField % scalar, io_err)
call MPAS_io_err_mesg(stream % fileHandle % ioContext, io_err, .false.)
if (io_err /= MPAS_IO_NOERR .and. present(ierr)) ierr = MPAS_IO_ERR

else if (field_cursor % field_type == FIELD_1D_CHAR) then

!call mpas_log_write('Writing out field '//trim(field_cursor % char1dField % fieldName))
!call mpas_log_write(' > is the field decomposed? $l', logicArgs=(/field_cursor % isDecomposed/))
!call mpas_log_write(' > outer dimension size $i', intArgs=(/field_cursor % totalDimSize/))
!call mpas_log_write('Writing out field '//trim(field_cursor % char1dField % fieldName))
!call mpas_log_write(' > is the field decomposed? $l', logicArgs=(/field_cursor % isDecomposed/))
!call mpas_log_write(' > outer dimension size $i', intArgs=(/field_cursor % totalDimSize/))

!call mpas_log_write('Copying field from first block')
!call mpas_log_write('MGD calling MPAS_io_put_var now...')
!call mpas_log_write('Copying field from first block')
!call mpas_log_write('MGD calling MPAS_io_put_var now...')
call MPAS_io_put_var(stream % fileHandle, field_cursor % char1dField % fieldName, field_cursor % char1dField % array, io_err)
call MPAS_io_err_mesg(stream % fileHandle % ioContext, io_err, .false.)
if (io_err /= MPAS_IO_NOERR .and. present(ierr)) ierr = MPAS_IO_ERR
Expand Down
166 changes: 166 additions & 0 deletions src/framework/mpas_stream_manager.F
Original file line number Diff line number Diff line change
Expand Up @@ -3088,6 +3088,8 @@ subroutine write_stream(manager, stream, blockID, timeLevel, mgLevel, forceWrite
logical :: clobberRecords, clobberFiles, truncateFiles
integer :: maxRecords, tempRecord
integer :: local_ierr, threadNum
integer(kind=I8KIND) :: max_var_size_bytes
character(len=StrKIND):: message

threadNum = mpas_threading_get_thread_num()

Expand Down Expand Up @@ -3182,6 +3184,9 @@ subroutine write_stream(manager, stream, blockID, timeLevel, mgLevel, forceWrite
!
! Build stream from pools of fields and attributes
!
max_var_size_bytes = stream_max_var_size(stream % field_pool, manager % allFields)
write(message,fmt='(A,i18)') 'final max_var_size_bytes =',max_var_size_bytes
call mpas_log_write(message)
allocate(stream % stream)
call MPAS_createStream(stream % stream, manager % ioContext, stream % filename, stream % io_type, MPAS_IO_WRITE, &
precision=stream % precision, clobberRecords=clobberRecords, &
Expand Down Expand Up @@ -4325,6 +4330,7 @@ end subroutine gen_random
timeLevel = 1
end if


select case (info % fieldType)
case (MPAS_POOL_REAL)
select case (info % nDims)
Expand Down Expand Up @@ -4495,6 +4501,111 @@ subroutine update_stream(stream, allFields, timeLevelIn, mgLevelIn, ierr) !{{{
end subroutine update_stream !}}}


integer(kind=I8KIND) function stream_max_var_size(field_pool, allFields) !{{{
use iso_c_binding, only : c_sizeof

implicit none

type (mpas_pool_type), intent(inout) :: field_pool
type (MPAS_Pool_type), intent(in) :: allFields

type (MPAS_Pool_iterator_type) :: itr
type (mpas_pool_field_info_type) :: info
integer :: timeLevel

type (field5DReal), pointer :: real5d
type (field4DReal), pointer :: real4d
type (field3DReal), pointer :: real3d
type (field2DReal), pointer :: real2d
type (field1DReal), pointer :: real1d
type (field0DReal), pointer :: real0d

type (field3DInteger), pointer :: int3d
type (field2DInteger), pointer :: int2d
type (field1DInteger), pointer :: int1d
type (field0DInteger), pointer :: int0d

type (field1DChar), pointer :: char1d
type (field0DChar), pointer :: char0d

character :: tmp_char
integer(kind=I8KIND) :: field_bytes, max_field_size, field_size
integer(kind=I8KIND), parameter :: int_size = c_sizeof(1)
integer(kind=I8KIND), parameter :: real_size = c_sizeof(1.0_RKIND)
integer(kind=I8KIND), parameter :: char_size = c_sizeof(tmp_char)
character(len=StrKIND):: message

call mpas_pool_begin_iteration(field_pool)
stream_max_var_size = 0
field_bytes = 0
do while ( mpas_pool_get_next_member(field_pool, itr) )

if (itr % memberType == MPAS_POOL_CONFIG) then

! To avoid accidentally matching in case statements below...
info % fieldType = -1
call mpas_pool_get_field_info(allFields, itr % memberName, info)
! Reading first time level
timeLevel = 1
call mpas_log_write('In check_max_var_size, field '//trim(itr % memberName)//' ndims: $i',intArgs=(/info % nDims/))
select case (info % fieldType)
case (MPAS_POOL_REAL)
select case (info % nDims)
case (0)
field_size = 1
case (1)
call mpas_pool_get_field(allFields, itr % memberName, real1d, timeLevel)
field_size = global_dim_size(real1d % block, real1d % dimNames, real1d % isVarArray)
case (2)
call mpas_pool_get_field(allFields, itr % memberName, real2d, timeLevel)
field_size = global_dim_size(real2d % block, real2d % dimNames, real2d % isVarArray)
case (3)
call mpas_log_write('before get field')
call mpas_pool_get_field(allFields, itr % memberName, real3d, timeLevel)
call mpas_log_write('after get field')
field_size = global_dim_size(real3d % block, real3d % dimNames, real3d % isVarArray)
case (4)
call mpas_pool_get_field(allFields, itr % memberName, real4d, timeLevel)
field_size = global_dim_size(real4d % block, real4d % dimNames, real4d % isVarArray)
case (5)
call mpas_pool_get_field(allFields, itr % memberName, real5d, timeLevel)
field_size = global_dim_size(real5d % block, real5d % dimNames, real5d % isVarArray)
end select
field_bytes = int(field_size, kind=I8KIND) * real_size
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perhaps field_size should be declared as an I8KIND, and this type conversion should be eliminated.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Isn't there a type conversion involved whether it's in this function or inside global_dim_size . Not sure if mpas_pool_get_dimension can work with I8KIND int pointers?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've now changed this. The only type conversion happens in global_dim_size , due to mpas_pool_get_dimension forcing a regular int. And I've changed to mpas_dmpar_sum_int8 now

case (MPAS_POOL_INTEGER)
select case (info % nDims)
case (0)
field_size = 1
case (1)
call mpas_pool_get_field(allFields, itr % memberName, int1d, timeLevel)
field_size = global_dim_size(int1d % block, int1d % dimNames, int1d % isVarArray)
case (2)
call mpas_pool_get_field(allFields, itr % memberName, int2d, timeLevel)
field_size = global_dim_size(int2d % block, int2d % dimNames, int2d % isVarArray)
case (3)
call mpas_pool_get_field(allFields, itr % memberName, int3d, timeLevel)
field_size = global_dim_size(int3d % block, int3d % dimNames, int3d % isVarArray)
end select
field_bytes = field_size * int_size
case (MPAS_POOL_CHARACTER)
select case (info % nDims)
case (0)
field_size = 1
case (1)
! call mpas_pool_get_field(allFields, itr % memberName, char1d, timeLevel)
call mpas_log_write('In check_max_var_size, unsupported type field1DChar.', MPAS_LOG_ERR)
end select
field_bytes = field_size * char_size
end select
stream_max_var_size = merge(field_bytes, stream_max_var_size, field_bytes > stream_max_var_size)
write(message,fmt='(A,i14,A,i18)') 'check_max_var_size.. field_bytes =',field_bytes,' stream_max_var_size =',stream_max_var_size
call mpas_log_write(message)
end if
end do

end function stream_max_var_size !}}}


!-----------------------------------------------------------------------
! routine stream_active_pkg_check
!
Expand Down Expand Up @@ -4845,6 +4956,61 @@ logical function is_decomposed_dim(dimName) !{{{

end function is_decomposed_dim !}}}



integer(kind=I8KIND) function global_dim_size(block, dimNames, isVarArray) !{{{

implicit none

character(len=*), intent(in), dimension(:) :: dimNames
type(block_type), intent(in) :: block
logical, intent(in) :: isVarArray
integer, pointer :: block_dim_size
integer(kind=I8KIND):: sum_block_dim_size, block_dim_size_int8
integer :: iDim, iDimStart, iDimEnd
character(len=StrKIND):: message

call mpas_log_write('----- from global_dim_size: size $i',intArgs=(/size(dimNames)/))
global_dim_size = 1
! Skip left-most dimension, as constituent elements of varArrays are written out separately
iDimStart = merge(2, 1, isVarArray)
iDimEnd = size(dimNames)
call mpas_log_write('----- from global_dim_size: iDimStart $i iDimEnd $i',intArgs=(/iDimStart,iDimEnd/))
do iDim = iDimStart, iDimEnd
if ( is_decomposed_dim(dimNames(iDim))) then
if (trim(dimNames(iDim)) == 'nCells') then
call mpas_pool_get_dimension(block % dimensions, 'nCellsSolve', block_dim_size)
else if (trim(dimNames(iDim)) == 'nEdges') then
call mpas_pool_get_dimension(block % dimensions, 'nEdgesSolve', block_dim_size)
else if (trim(dimNames(iDim)) == 'nVertices') then
call mpas_pool_get_dimension(block % dimensions, 'nVerticesSolve', block_dim_size)
else
global_dim_size = -1
end if

block_dim_size_int8 = int(block_dim_size, kind=I8KIND)

call mpas_dmpar_sum_int8(block % domain % dminfo, block_dim_size_int8, sum_block_dim_size)
write(message,fmt='(A,i18,A,i18)') '----- from global_dim_size: Dimname is decomposed '//trim(dimNames(iDim))//' local size =',block_dim_size,' sum_block_dim_size=', sum_block_dim_size
call mpas_log_write(message)

global_dim_size = global_dim_size * sum_block_dim_size
else
call mpas_log_write('----- from global_dim_size ----- before get dim... Dimname is not decomposed '//trim(dimNames(iDim)))
call mpas_pool_get_dimension(block % dimensions, dimNames(iDim), block_dim_size)
block_dim_size_int8 = int(block_dim_size, kind=I8KIND)
global_dim_size = global_dim_size * block_dim_size_int8
call mpas_log_write('----- from global_dim_size ----- Dimname is not decomposed '//trim(dimNames(iDim))//' Size $i',intArgs=(/block_dim_size/))
end if
end do

write(message,fmt='(A,i18)') '----- from global_dim_size: cumulative global_dim_size =',global_dim_size
call mpas_log_write(message)

end function global_dim_size !}}}




!-----------------------------------------------------------------------
! routine prewrite_reindex
Expand Down