Commit 2c52336a authored by Yingzhou Li's avatar Yingzhou Li Committed by Victor Yu

Added functionality of unit overlapping matrix

parent 0c77c087
......@@ -185,6 +185,7 @@ contains
real(r8), save, allocatable :: Vec_ev(:)
real(r8), save, allocatable :: Vec_evold(:)
logical, save :: ovlp_is_unit ! ovlp is unit flag
logical :: convflag ! convergence flag.
!**********************************************!
......@@ -196,6 +197,7 @@ contains
n = r_h%n_state
nact = r_h%n_state
n_state = r_h%n_state
ovlp_is_unit = r_h%ovlp_is_unit
max_iter = r_h%max_iter
max_inneriter = r_h%cheb_max_inneriter
tol_iter = r_h%tol_iter
......@@ -317,7 +319,11 @@ contains
! -- SPsi = S*Psi
if (ijob == SID_ORTH + 3) then
call rci_op_s_multi(iS, task, 'N', m, n, MID_Psi, MID_WORK)
if (ovlp_is_unit) then
call rci_op_copy(iS, task, 'N', MID_Psi, MID_WORK)
else
call rci_op_s_multi(iS, task, 'N', m, n, MID_Psi, MID_WORK)
end if
ijob = ijob + 1
return
end if
......
......@@ -27,8 +27,8 @@ module ELSI_RCI_DAVIDSON
integer(i4), parameter :: MID_PsiExt = 2
integer(i4), parameter :: MID_HPsi = 3
integer(i4), parameter :: MID_HPsiExt = 4
integer(i4), parameter :: MID_SPsi = 5
integer(i4), parameter :: MID_SPsiExt = 6
integer(i4), parameter :: MID_SPsiExt = 5 ! also serve temp space
integer(i4), parameter :: MID_SPsi = 6
!&>
! Matrix ID of size n by n
......@@ -87,6 +87,10 @@ contains
ijob = ijob + 1
iter = 20
call rci_op_null(task)
elseif ((iter > 5) .and. (r_h%ovlp_is_unit)) then
ijob = ijob + 1
iter = 20
call rci_op_null(task)
else
call rci_op_allocate(iS, task, m, n, iter)
end if
......@@ -135,6 +139,10 @@ contains
ijob = ijob + 1
iter = 20
call rci_op_null(task)
elseif ((iter > 5) .and. (r_h%ovlp_is_unit)) then
ijob = ijob + 1
iter = 20
call rci_op_null(task)
else
call rci_op_deallocate(iS, task, iter)
end if
......@@ -182,6 +190,8 @@ contains
real(r8), save :: tol_iter ! convergence tolerance
logical, save :: ovlp_is_unit ! ovlp is unit flag
real(r8), save, allocatable :: Vec_conv(:)
real(r8), save, allocatable :: Vec_ev(:)
real(r8), save, allocatable :: Vec_evsub(:)
......@@ -195,6 +205,7 @@ contains
m = r_h%n_basis
n = r_h%n_state
n_state = r_h%n_state
ovlp_is_unit = r_h%ovlp_is_unit
max_iter = r_h%max_iter
tol_iter = r_h%tol_iter
max_n = r_h%max_n
......@@ -223,7 +234,11 @@ contains
! -- SPsi = S*Psi
if (ijob == SID_INIT + 2) then
call rci_op_s_multi(iS, task, 'N', m, n, MID_Psi, MID_SPsi)
if (ovlp_is_unit) then
call rci_op_null(task)
else
call rci_op_s_multi(iS, task, 'N', m, n, MID_Psi, MID_SPsi)
end if
ijob = ijob + 1
return
end if
......@@ -239,9 +254,15 @@ contains
! -- SR = Psi'*SPsi
if (ijob == SID_INIT + 4) then
call rci_op_gemm(iS, task, 'C', 'N', n, n, m, 1.0_r8, &
MID_Psi, m, MID_SPsi, m, 0.0_r8, &
MID_SR, max_n)
if (ovlp_is_unit) then
call rci_op_gemm(iS, task, 'C', 'N', n, n, m, 1.0_r8, &
MID_Psi, m, MID_Psi, m, 0.0_r8, &
MID_SR, max_n)
else
call rci_op_gemm(iS, task, 'C', 'N', n, n, m, 1.0_r8, &
MID_Psi, m, MID_SPsi, m, 0.0_r8, &
MID_SR, max_n)
end if
ijob = ijob + 1
return
end if
......@@ -303,9 +324,15 @@ contains
! - SPsiExt = SPsi * VRsub
if (ijob == SID_ITER + 2) then
call rci_op_gemm(iS, task, 'N', 'N', m, next, n, 1.0_r8, &
MID_SPsi, m, MID_VRsub, max_n, 0.0_r8, &
MID_SPsiExt, m)
if (ovlp_is_unit) then
call rci_op_gemm(iS, task, 'N', 'N', m, next, n, 1.0_r8, &
MID_Psi, m, MID_VRsub, max_n, 0.0_r8, &
MID_SPsiExt, m)
else
call rci_op_gemm(iS, task, 'N', 'N', m, next, n, 1.0_r8, &
MID_SPsi, m, MID_VRsub, max_n, 0.0_r8, &
MID_SPsiExt, m)
end if
ijob = ijob + 1
return
end if
......@@ -452,17 +479,27 @@ contains
! -- SPsiExt = S*PsiExt
if (ijob == SID_RECONST + 7) then
call rci_op_s_multi(iS, task, 'N', m, next, &
MID_PsiExt, MID_SPsiExt)
if (ovlp_is_unit) then
call rci_op_null(task)
else
call rci_op_s_multi(iS, task, 'N', m, next, &
MID_PsiExt, MID_SPsiExt)
end if
ijob = ijob + 1
return
end if
! -- SR12 = Psi' * SPsiExt
if (ijob == SID_RECONST + 8) then
call rci_op_gemm(iS, task, 'C', 'N', n, next, m, 1.0_r8, &
MID_Psi, m, MID_SPsiExt, m, 0.0_r8, &
MID_WORK1, max_n)
if (ovlp_is_unit) then
call rci_op_gemm(iS, task, 'C', 'N', n, next, m, 1.0_r8, &
MID_Psi, m, MID_PsiExt, m, 0.0_r8, &
MID_WORK1, max_n)
else
call rci_op_gemm(iS, task, 'C', 'N', n, next, m, 1.0_r8, &
MID_Psi, m, MID_SPsiExt, m, 0.0_r8, &
MID_WORK1, max_n)
end if
ijob = ijob + 1
return
end if
......@@ -478,9 +515,15 @@ contains
! -- SR21 = SPsiExt' * Psi
if (ijob == SID_RECONST + 10) then
call rci_op_gemm(iS, task, 'C', 'N', next, n, m, 1.0_r8, &
MID_SPsiExt, m, MID_Psi, m, 0.0_r8, &
MID_WORK1, max_n)
if (ovlp_is_unit) then
call rci_op_gemm(iS, task, 'C', 'N', next, n, m, 1.0_r8, &
MID_PsiExt, m, MID_Psi, m, 0.0_r8, &
MID_WORK1, max_n)
else
call rci_op_gemm(iS, task, 'C', 'N', next, n, m, 1.0_r8, &
MID_SPsiExt, m, MID_Psi, m, 0.0_r8, &
MID_WORK1, max_n)
end if
ijob = ijob + 1
return
end if
......@@ -496,9 +539,15 @@ contains
! -- SR22 = PsiExt' * SPsiExt
if (ijob == SID_RECONST + 12) then
call rci_op_gemm(iS, task, 'C', 'N', next, next, m, 1.0_r8, &
MID_PsiExt, m, MID_SPsiExt, m, 0.0_r8, &
MID_WORK1, max_n)
if (ovlp_is_unit) then
call rci_op_gemm(iS, task, 'C', 'N', next, next, m, 1.0_r8, &
MID_PsiExt, m, MID_PsiExt, m, 0.0_r8, &
MID_WORK1, max_n)
else
call rci_op_gemm(iS, task, 'C', 'N', next, next, m, 1.0_r8, &
MID_PsiExt, m, MID_SPsiExt, m, 0.0_r8, &
MID_WORK1, max_n)
end if
ijob = ijob + 1
return
end if
......@@ -553,9 +602,13 @@ contains
! -- SPsi(1:m, n+(1:next)) = SPsiExt
if (ijob == SID_ITER + 25) then
call rci_op_subcopy(iS, task, m, next, &
MID_SPsiExt, m, 0, 0, &
MID_SPsi, m, 0, n)
if (ovlp_is_unit) then
call rci_op_null(task)
else
call rci_op_subcopy(iS, task, m, next, &
MID_SPsiExt, m, 0, 0, &
MID_SPsi, m, 0, n)
end if
ijob = ijob + 1
return
end if
......@@ -667,16 +720,24 @@ contains
! - SPsiExt = SPsi * VRsub
if (ijob == SID_RESTART + 6) then
call rci_op_gemm(iS, task, 'N', 'N', m, n_state, n, 1.0_r8, &
MID_SPsi, m, MID_VRsub, max_n, 0.0_r8, &
MID_SPsiExt, m)
if (ovlp_is_unit) then
call rci_op_null(task)
else
call rci_op_gemm(iS, task, 'N', 'N', m, n_state, n, 1.0_r8, &
MID_SPsi, m, MID_VRsub, max_n, 0.0_r8, &
MID_SPsiExt, m)
end if
ijob = ijob + 1
return
end if
! - SPsi = SPsiExt
if (ijob == SID_RESTART + 7) then
call rci_op_copy(iS, task, 'N', MID_SPsiExt, MID_SPsi)
if (ovlp_is_unit) then
call rci_op_null(task)
else
call rci_op_copy(iS, task, 'N', MID_SPsiExt, MID_SPsi)
end if
ijob = ijob + 1
return
end if
......@@ -745,7 +806,11 @@ contains
! - SPsiExt = SPsi(:,not conv)
if (ijob == SID_RESTART + 14) then
call rci_op_subcol(iS, task, m, n_state, MID_SPsi, MID_SPsiExt)
if (ovlp_is_unit) then
call rci_op_subcol(iS, task, m, n_state, MID_Psi, MID_SPsiExt)
else
call rci_op_subcol(iS, task, m, n_state, MID_SPsi, MID_SPsiExt)
end if
ijob = ijob + 1
return
end if
......
......@@ -25,15 +25,15 @@ module ELSI_RCI_OMM
!&<
integer(i4), parameter :: MID_C = 1
integer(i4), parameter :: MID_HC = 2
integer(i4), parameter :: MID_SC = 3
integer(i4), parameter :: MID_G = 4
integer(i4), parameter :: MID_PG = 5
integer(i4), parameter :: MID_G = 3
integer(i4), parameter :: MID_PG = 4
integer(i4), parameter :: MID_D = 5
integer(i4), parameter :: MID_HD = 6
integer(i4), parameter :: MID_SD = 7
integer(i4), parameter :: MID_D = 8
integer(i4), parameter :: MID_Gp = 9
integer(i4), parameter :: MID_PGp = 10
integer(i4), parameter :: MID_WORK = 11
integer(i4), parameter :: MID_Gp = 7
integer(i4), parameter :: MID_PGp = 8
integer(i4), parameter :: MID_WORK = 9
integer(i4), parameter :: MID_SC = 10
integer(i4), parameter :: MID_SD = 11
!&>
! Matrix ID of size n by n
......@@ -95,6 +95,10 @@ contains
ijob = ijob + 1
iter = 20
call rci_op_null(task)
elseif ((iter > 9) .and. (r_h%ovlp_is_unit)) then
ijob = ijob + 1
iter = 20
call rci_op_null(task)
else
call rci_op_allocate(iS, task, m, n, iter)
end if
......@@ -143,6 +147,10 @@ contains
ijob = ijob + 1
iter = 20
call rci_op_null(task)
elseif ((iter > 9) .and. (r_h%ovlp_is_unit)) then
ijob = ijob + 1
iter = 20
call rci_op_null(task)
else
call rci_op_deallocate(iS, task, iter)
end if
......@@ -210,6 +218,7 @@ contains
real(r8), save :: TrHdSdd
real(r8), save :: TrHddSdd
logical, save :: ovlp_is_unit ! ovlp is unit flag
logical, save :: conv
!**********************************************!
......@@ -229,6 +238,7 @@ contains
m = r_h%n_basis
n = r_h%n_state
max_n = r_h%max_n
ovlp_is_unit = r_h%ovlp_is_unit
icg = 0
lambda = 0.0_r8
call rci_op_null(task)
......@@ -251,14 +261,23 @@ contains
! -calculate the overlap matrix in WF basis: SW=C^T*S*C
! -- SC = S*C
if (ijob == SID_INIT + 3) then
call rci_op_s_multi(iS, task, 'N', m, n, MID_C, MID_SC)
if (ovlp_is_unit) then
call rci_op_null(task)
else
call rci_op_s_multi(iS, task, 'N', m, n, MID_C, MID_SC)
end if
ijob = ijob + 1
return
end if
! -- SW = C'*SC
if (ijob == SID_INIT + 4) then
call rci_op_gemm(iS, task, 'C', 'N', n, n, m, 1.0_r8, &
MID_C, m, MID_SC, m, 0.0_r8, MID_SW, max_n)
if (ovlp_is_unit) then
call rci_op_gemm(iS, task, 'C', 'N', n, n, m, 1.0_r8, &
MID_C, m, MID_C, m, 0.0_r8, MID_SW, max_n)
else
call rci_op_gemm(iS, task, 'C', 'N', n, n, m, 1.0_r8, &
MID_C, m, MID_SC, m, 0.0_r8, MID_SW, max_n)
end if
ijob = ijob + 1
return
end if
......@@ -280,8 +299,13 @@ contains
end if
! -- G = -2 SC*HW + G
if (ijob == SID_INIT + 7) then
call rci_op_gemm(iS, task, 'N', 'N', m, n, n, -2.0_r8, &
MID_SC, m, MID_HW, max_n, 1.0_r8, MID_G, m)
if (ovlp_is_unit) then
call rci_op_gemm(iS, task, 'N', 'N', m, n, n, -2.0_r8, &
MID_C, m, MID_HW, max_n, 1.0_r8, MID_G, m)
else
call rci_op_gemm(iS, task, 'N', 'N', m, n, n, -2.0_r8, &
MID_SC, m, MID_HW, max_n, 1.0_r8, MID_G, m)
end if
ijob = ijob + 1
return
end if
......@@ -304,8 +328,13 @@ contains
! - calculate SWd = G'*S*C = PG'*SC
if (ijob == SID_INIT + 10) then
call rci_op_gemm(iS, task, 'C', 'N', n, n, m, 1.0_r8, &
MID_PG, m, MID_SC, m, 0.0_r8, MID_SWd, max_n)
if (ovlp_is_unit) then
call rci_op_gemm(iS, task, 'C', 'N', n, n, m, 1.0_r8, &
MID_PG, m, MID_C, m, 0.0_r8, MID_SWd, max_n)
else
call rci_op_gemm(iS, task, 'C', 'N', n, n, m, 1.0_r8, &
MID_PG, m, MID_SC, m, 0.0_r8, MID_SWd, max_n)
end if
ijob = ijob + 1
return
end if
......@@ -330,15 +359,25 @@ contains
! - calculate SWdd = G'*S*G = PG'*(S*PG)
! -- SD = S*PG
if (ijob == SID_INIT + 13) then
call rci_op_s_multi(iS, task, 'N', m, n, MID_PG, MID_SD)
if (ovlp_is_unit) then
call rci_op_null(task)
else
call rci_op_s_multi(iS, task, 'N', m, n, MID_PG, MID_SD)
end if
ijob = ijob + 1
return
end if
! -- SWdd = PG'*SD
if (ijob == SID_INIT + 14) then
call rci_op_gemm(iS, task, 'C', 'N', n, n, m, 1.0_r8, &
MID_PG, m, MID_SD, m, 0.0_r8, &
MID_SWdd, max_n)
if (ovlp_is_unit) then
call rci_op_gemm(iS, task, 'C', 'N', n, n, m, 1.0_r8, &
MID_PG, m, MID_PG, m, 0.0_r8, &
MID_SWdd, max_n)
else
call rci_op_gemm(iS, task, 'C', 'N', n, n, m, 1.0_r8, &
MID_PG, m, MID_SD, m, 0.0_r8, &
MID_SWdd, max_n)
end if
ijob = SID_COEFF
return
end if
......@@ -522,8 +561,13 @@ contains
end if
! SWd = D'*SC
if (ijob == SID_ITER + 6) then
call rci_op_gemm(iS, task, 'C', 'N', n, n, m, 1.0_r8, &
MID_D, m, MID_SC, m, 0.0_r8, MID_SWd, max_n)
if (ovlp_is_unit) then
call rci_op_gemm(iS, task, 'C', 'N', n, n, m, 1.0_r8, &
MID_D, m, MID_C, m, 0.0_r8, MID_SWd, max_n)
else
call rci_op_gemm(iS, task, 'C', 'N', n, n, m, 1.0_r8, &
MID_D, m, MID_SC, m, 0.0_r8, MID_SWd, max_n)
end if
ijob = ijob + 1
return
end if
......@@ -548,15 +592,25 @@ contains
! - calculate SWdd = D'*S*D
! -- SD = S*D
if (ijob == SID_ITER + 9) then
call rci_op_s_multi(iS, task, 'N', m, n, MID_D, MID_SD)
if (ovlp_is_unit) then
call rci_op_null(task)
else
call rci_op_s_multi(iS, task, 'N', m, n, MID_D, MID_SD)
end if
ijob = ijob + 1
return
end if
! -- SWdd = D'*SD
if (ijob == SID_ITER + 10) then
call rci_op_gemm(iS, task, 'C', 'N', n, n, m, 1.0_r8, &
MID_D, m, MID_SD, m, 0.0_r8, &
MID_SWdd, max_n)
if (ovlp_is_unit) then
call rci_op_gemm(iS, task, 'C', 'N', n, n, m, 1.0_r8, &
MID_D, m, MID_D, m, 0.0_r8, &
MID_SWdd, max_n)
else
call rci_op_gemm(iS, task, 'C', 'N', n, n, m, 1.0_r8, &
MID_D, m, MID_SD, m, 0.0_r8, &
MID_SWdd, max_n)
end if
ijob = SID_COEFF
return
end if
......@@ -672,7 +726,11 @@ contains
! - SC = x_min*SD + SC
if (ijob == SID_UPDATE + 11) then
call rci_op_axpy(iS, task, m, n, x_min, MID_SD, m, MID_SC, m)
if (ovlp_is_unit) then
call rci_op_null(task)
else
call rci_op_axpy(iS, task, m, n, x_min, MID_SD, m, MID_SC, m)
end if
ijob = ijob + 1
return
end if
......@@ -686,8 +744,13 @@ contains
end if
! - G = -2 SC*HW + G
if (ijob == SID_UPDATE + 13) then
call rci_op_gemm(iS, task, 'N', 'N', m, n, n, -2.0_r8, &
MID_SC, m, MID_HW, max_n, 1.0_r8, MID_G, m)
if (ovlp_is_unit) then
call rci_op_gemm(iS, task, 'N', 'N', m, n, n, -2.0_r8, &
MID_C, m, MID_HW, max_n, 1.0_r8, MID_G, m)
else
call rci_op_gemm(iS, task, 'N', 'N', m, n, n, -2.0_r8, &
MID_SC, m, MID_HW, max_n, 1.0_r8, MID_G, m)
end if
ijob = ijob + 1
return
end if
......@@ -772,10 +835,17 @@ contains
! - SW = C'*SC
if (ijob == SID_FINISH + 2) then
call rci_op_gemm(iS, task, 'C', 'N', n, n, &
m, 1.0_r8, &
MID_C, m, MID_SC, m, 0.0_r8, &
MID_SW, max_n)
if (ovlp_is_unit) then
call rci_op_gemm(iS, task, 'C', 'N', n, n, &
m, 1.0_r8, &
MID_C, m, MID_C, m, 0.0_r8, &
MID_SW, max_n)
else
call rci_op_gemm(iS, task, 'C', 'N', n, n, &
m, 1.0_r8, &
MID_C, m, MID_SC, m, 0.0_r8, &
MID_SW, max_n)
end if
ijob = ijob + 1
return
end if
......
This diff is collapsed.
......@@ -9,7 +9,8 @@ LIST(APPEND ftest_rci_src
test_rci_ev_real_den.f90
# test_rci_ev_real_csc.f90
test_rci_ev_cmplx_den.f90
test_rci_ev_cmplx_pw.f90)
test_rci_ev_cmplx_pw.f90
test_rci_ev_cmplx_pw_unit.f90)
ADD_EXECUTABLE(elsi_rci_test ${ftest_rci_src})
TARGET_LINK_LIBRARIES(elsi_rci_test PRIVATE elsi_rci)
......
......@@ -80,7 +80,7 @@ program elsi_rci_test
case("3") ! Planewave
select case(arg3(1:1))
case("c") ! complex
call test_rci_ev_cmplx_pw(solver,arg5)
call test_rci_ev_cmplx_pw_unit(solver,arg5)
case default
call test_die()
end select
......
This diff is collapsed.
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment