Skip to content

Commit

Permalink
CCSD trpdrv GPU profiling
Browse files Browse the repository at this point in the history
Former-commit-id: c90f41c
  • Loading branch information
jeffhammond committed Apr 11, 2022
1 parent d61c390 commit e17533c
Show file tree
Hide file tree
Showing 3 changed files with 93 additions and 21 deletions.
6 changes: 6 additions & 0 deletions src/ccsd/ccsd_pstat.F
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,10 @@ subroutine ccsd_pstat_init(rtdb)
$ ps_dovvv)) call errquit('ccsd: ccsd_pstat_init', 0,0)
if (.not. pstat_allocate('ccsd:doxxx', pstat_qstat, 0, junk,
$ ps_doxxx)) call errquit('ccsd: ccsd_pstat_init', 0,0)
if (.not. pstat_allocate('ccsd:gpumove', pstat_qstat, 0, junk,
$ ps_gpumove)) call errquit('ccsd: ccsd_pstat_init', 0,0)
if (.not. pstat_allocate('ccsd:accwait', pstat_qstat, 0, junk,
$ ps_accwait)) call errquit('ccsd: ccsd_pstat_init', 0,0)
if (.not. pstat_allocate('ccsd:z2pm', pstat_qstat, 0, junk,
$ ps_z2pm)) call errquit('ccsd: ccsd_pstat_init', 0,0)
if (.not. pstat_allocate('ccsd:hz2pm', pstat_qstat, 0, junk,
Expand Down Expand Up @@ -97,6 +101,8 @@ subroutine ccsd_pstat_print()
if(.not.pstat_free(ps_idx34))call errquit('ccsd_pstat',0,0)
if(.not.pstat_free(ps_t2eri))call errquit('ccsd_pstat',0,0)
if(.not.pstat_free(ps_trpdrv))call errquit('ccsd_pstat',0,0)
if(.not.pstat_free(ps_accwait))call errquit('ccsd_pstat',0,0)
if(.not.pstat_free(ps_gpumove))call errquit('ccsd_pstat',0,0)
if(.not.pstat_free(ps_doxxx))call errquit('ccsd_pstat',0,0)
if(.not.pstat_free(ps_dovvv))call errquit('ccsd_pstat',0,0)
if(.not.pstat_free(ps_doooo))call errquit('ccsd_pstat',0,0)
Expand Down
104 changes: 85 additions & 19 deletions src/ccsd/ccsd_trpdrv_openacc.F
Original file line number Diff line number Diff line change
Expand Up @@ -121,8 +121,9 @@ subroutine ccsd_trpdrv_openacc(t1,xeorb,
write(6,99)
endif
99 format(2x,'Using Fortran OpenACC+CUBLAS in CCSD(T)')
tt0 = util_wallsec()
agg_flops = 0
!
tt0 = util_wallsec()
! setup CUDA streams
do shi=1,8
err = cudaStreamCreate(stream(shi))
Expand All @@ -132,11 +133,6 @@ subroutine ccsd_trpdrv_openacc(t1,xeorb,
err = cublasSetStream(handle(shi), stream(shi))
if (err.ne.0) call errquit('cublasSetStream',err,UNKNOWN_ERR)
end do
tt1 = util_wallsec()
if (me.eq.0) then
write(6,500) tt1-tt0
500 format('CU init took ',e15.5,' seconds')
endif
!
! device-only temp arrays
! produced by DGEMM, consumed by TENGY
Expand Down Expand Up @@ -183,6 +179,12 @@ subroutine ccsd_trpdrv_openacc(t1,xeorb,
& xKij(1:lnov*nocc), xKkj(1:kchunk*lnov),
& stat=alloc_error)
if (alloc_error.ne.0) call errquit('TKJKD GPU alloc',1,MA_ERR)

tt1 = util_wallsec()
if (me.eq.0) then
write(6,500) tt1-tt0
500 format('CU+MEM init took ',e15.5,' seconds')
endif
!
! call ga_sync() ! ga_sync called just before trpdrv in aoccsd2
!
Expand Down Expand Up @@ -261,8 +263,19 @@ subroutine ccsd_trpdrv_openacc(t1,xeorb,
call ga_nbget(g_objv,1+2*lnoov+(i-1)*lnov,
& 2*lnoov+i*lnov,av,av,Xia,lnov,nbh_objv7)

if (occsdps) then
call pstat_on(ps_accwait)
else
call qenter('accwait',0)
endif
!$acc wait(9)
!$acc wait(10)
if (occsdps) then
call pstat_off(ps_accwait)
else
call qexit('accwait',0)
endif

t1v2(:) = t1((i-1)*nvir+1:i*nvir)
if(i.eq.1) then
call ga_nbwait(nbh_objv1) ! Dja
Expand All @@ -272,26 +285,37 @@ subroutine ccsd_trpdrv_openacc(t1,xeorb,
dintx1(:) = Djia(1:nvir)

do k=klo,min(khi,i)
!$acc wait(9)
!$acc wait(10)
if (occsdps) then
call pstat_on(ps_accwait)
else
call qenter('accwait',0)
endif
!$acc wait(9)
!$acc wait(10)
if (occsdps) then
call pstat_off(ps_accwait)
else
call qexit('accwait',0)
endif

t1v1(:) = t1((k-1)*nvir+1:k*nvir)
dintc2(:) = Dja(1+(k-1)*nvir:k*nvir)
if(i.eq.1) then
call ga_nbwait(nbh_objv4(k)) ! Djka
endif
dintx2(:) = Djka(1+(k-klo)*nvir:(k-klo+1)*nvir)
if (occsdps) then
call pstat_on(ps_doxxx)
else
call qenter('doxxx',0)
endif
!
! These are the input dependencies for the DGEMM calls below.
! We wait on all of them here because GA is not even remotely thread-safe.
! All of these are independent of k, so we wait on them only
! at the first trip of the loop.
!
if (k.eq.klo) then
if (occsdps) then
call pstat_on(ps_gpumove)
else
call qenter('gpumove',0)
endif
call ga_nbwait(nbh_coul2)
!xJia = Jia
err = cudaMemcpyAsync(xJia,Jia,size(Jia),stream(1))
Expand Down Expand Up @@ -389,14 +413,37 @@ subroutine ccsd_trpdrv_openacc(t1,xeorb,
call errquit('cudaStreamSync',err,UNKNOWN_ERR)
endif
enddo
if (occsdps) then
call pstat_off(ps_gpumove)
else
call qexit('gpumove',0)
endif
endif ! k==klo

if (occsdps) then
call pstat_on(ps_doxxx)
else
call qenter('doxxx',0)
endif

tc0 = util_wallsec()

nv4 = nvir ! no possibility of overflow
no4 = nocc
!$acc wait(9)
!$acc wait(10)

if (occsdps) then
call pstat_on(ps_accwait)
else
call qenter('accwait',0)
endif
!$acc wait(9)
!$acc wait(10)
if (occsdps) then
call pstat_off(ps_accwait)
else
call qexit('accwait',0)
endif

err = cublasDgemm_v2(handle(1),
& cu_op_n,cu_op_t,
& nv4,nv4,nv4,1.0d0,
Expand Down Expand Up @@ -628,8 +675,8 @@ subroutine ccsd_trpdrv_openacc(t1,xeorb,
end do
end do
end if ! (i.ne.k)
tengy_flops = nvir*nvir*( 3 + 2*( 12 + 11 + 11 ) + 2*27 )
agg_flops = agg_flops + tengy_flops
tengy_flops = nvir*nvir*( 3 + 2*( 12 + 11 + 11 ) + 2*27 )
agg_flops = agg_flops + tengy_flops

tc1 = util_wallsec()

Expand All @@ -641,8 +688,20 @@ subroutine ccsd_trpdrv_openacc(t1,xeorb,

end do ! k
end do ! i
!$acc wait(9)
!$acc wait(10)

if (occsdps) then
call pstat_on(ps_accwait)
else
call qenter('accwait',0)
endif
!$acc wait(9)
!$acc wait(10)
if (occsdps) then
call pstat_off(ps_accwait)
else
call qexit('accwait',0)
endif

emp4 = emp4 + emp4i
emp5 = emp5 + emp5i
emp4 = emp4 + emp4k
Expand Down Expand Up @@ -690,6 +749,7 @@ subroutine ccsd_trpdrv_openacc(t1,xeorb,
call qexit('trpdrv',0)
endif
!
tt0 = util_wallsec()
deallocate( f1n, f1t, f2n, f2t, f3n, f3t, f4n, f4t,
& stat=alloc_error)
if (alloc_error.ne.0) call errquit('free f[1234][tn]',8,MA_ERR)
Expand All @@ -715,5 +775,11 @@ subroutine ccsd_trpdrv_openacc(t1,xeorb,
err = cudaStreamDestroy(stream(shi))
if (err.ne.0) call errquit('cudaStreamDestroy',err,UNKNOWN_ERR)
end do
!
tt1 = util_wallsec()
if (me.eq.0) then
write(6,501) tt1-tt0
501 format('CU+MEM free took ',e15.5,' seconds')
endif
!
end
4 changes: 2 additions & 2 deletions src/ccsd/ccsdps.fh
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ c
, ps_trpdrv,ps_doooo,ps_dovvv,ps_z2pm,ps_hz2pm,ps_zvecs,
, ps_pdiis,ps_ht2pm,ps_t2pm,ps_sxy,ps_itrdrv,ps_tengy,
, ps_t2eriw,ps_t2erin,ps_t2erih,ps_doxxx,ps_rdtrpo,
, ps_trpmos
, ps_trpmos,ps_gpumove,ps_accwait


logical occsdps ! True if gathering stats
Expand All @@ -17,4 +17,4 @@ c
$ ps_trpdrv,ps_doooo,ps_dovvv,ps_z2pm,ps_hz2pm,ps_zvecs,
$ ps_pdiis,ps_ht2pm,ps_t2pm,ps_sxy,ps_itrdrv,ps_tengy,
$ ps_t2eriw,ps_t2erin,ps_t2erih,ps_doxxx,ps_rdtrpo,
$ ps_trpmos
$ ps_trpmos,ps_gpumove,ps_accwait

0 comments on commit e17533c

Please sign in to comment.