Commit 9bb857a4 authored by Victor Yu's avatar Victor Yu
Browse files

Clean up density matrix construction code

Next we will try using GPU in the construction of the density matrix from
the computed occupation numbers and eigenvectors.
parent e28f5b44
......@@ -7,7 +7,7 @@ SET(elsi_URL "http://elsi-interchange.org")
SET(elsi_EMAIL "elsi-team@duke.edu")
SET(elsi_LICENSE "BSD 3")
SET(elsi_DESCRIPTION "Electronic Structure Infrastructure")
SET(elsi_DATESTAMP "20200424")
SET(elsi_DATESTAMP "20200427")
### CMake modules ###
LIST(APPEND CMAKE_MODULE_PATH ${PROJECT_SOURCE_DIR}/cmake)
......
......@@ -16,7 +16,6 @@ module ELSI_DECISION
use ELSI_MPI, only: elsi_check_mpi,MPI_SUM,MPI_INTEGER4,MPI_REAL8
use ELSI_OUTPUT, only: elsi_say
use ELSI_PRECISION, only: r8,i4
use ELSI_UTIL, only: elsi_get_nnz
implicit none
......@@ -76,7 +75,7 @@ subroutine elsi_decide_dm_real(ph,bh,mat)
if(ph%solver == AUTO_SOLVER) then
if(ph%i_spin == 1 .and. ph%i_kpt == 1) then
call elsi_get_nnz(bh%def0,bh%n_lrow,bh%n_lcol,mat,nnz_l)
nnz_l = count(abs(mat) > bh%def0)
call MPI_Allreduce(nnz_l,nnz_g,1,MPI_INTEGER4,MPI_SUM,bh%comm,ierr)
......@@ -114,7 +113,7 @@ subroutine elsi_decide_dm_cmplx(ph,bh,mat)
if(ph%solver == AUTO_SOLVER) then
if(ph%i_spin == 1 .and. ph%i_kpt == 1) then
call elsi_get_nnz(bh%def0,bh%n_lrow,bh%n_lcol,mat,nnz_l)
nnz_l = count(abs(mat) > bh%def0)
call MPI_Allreduce(nnz_l,nnz_g,1,MPI_INTEGER4,MPI_SUM,bh%comm,ierr)
......
......@@ -18,7 +18,6 @@ module ELSI_EIGENEXA
use ELSI_OUTPUT, only: elsi_say,elsi_get_time
use ELSI_PRECISION, only: r8,i4
use ELSI_REDIST, only: elsi_blacs_to_eigenexa_h,elsi_eigenexa_to_blacs_ev
use ELSI_UTIL, only: elsi_get_nnz
use EIGEN_LIBS_MOD, only: eigen_init,eigen_get_procs,eigen_get_id,&
eigen_get_matdims,eigen_s,eigen_sx,eigen_free
......@@ -90,7 +89,7 @@ subroutine elsi_solve_eigenexa_real(ph,bh,ham,ovlp,eval,evec)
! Compute sparsity
if(bh%nnz_g == UNSET) then
if(bh%nnz_l == UNSET) then
call elsi_get_nnz(bh%def0,bh%n_lrow,bh%n_lcol,ham,bh%nnz_l)
bh%nnz_l = count(abs(ham) > bh%def0)
end if
call MPI_Allreduce(bh%nnz_l,bh%nnz_g,1,MPI_INTEGER4,MPI_SUM,bh%comm,ierr)
......
......@@ -16,7 +16,7 @@ module ELSI_ELPA
MPI_INTEGER4
use ELSI_OUTPUT, only: elsi_say,elsi_get_time
use ELSI_PRECISION, only: r4,r8,i4
use ELSI_UTIL, only: elsi_get_nnz,elsi_get_gid,elsi_set_full_mat
use ELSI_UTIL, only: elsi_get_gid,elsi_set_full_mat
use ELPA, only: elpa_init,elpa_allocate,elpa_deallocate,&
elpa_autotune_deallocate,ELPA_2STAGE_REAL_GPU,ELPA_2STAGE_COMPLEX_GPU,&
ELPA_AUTOTUNE_FAST,ELPA_AUTOTUNE_MEDIUM,ELPA_AUTOTUNE_DOMAIN_REAL,&
......@@ -351,7 +351,7 @@ subroutine elsi_solve_elpa_real(ph,bh,ham,ovlp,eval,evec)
! Compute sparsity
if(bh%nnz_g == UNSET) then
if(bh%nnz_l == UNSET) then
call elsi_get_nnz(bh%def0,bh%n_lrow,bh%n_lcol,ham,bh%nnz_l)
bh%nnz_l = count(abs(ham) > bh%def0)
end if
call MPI_Allreduce(bh%nnz_l,bh%nnz_g,1,MPI_INTEGER4,MPI_SUM,bh%comm,ierr)
......@@ -694,7 +694,7 @@ subroutine elsi_solve_elpa_cmplx(ph,bh,ham,ovlp,eval,evec)
! Compute sparsity
if(bh%nnz_g == UNSET) then
if(bh%nnz_l == UNSET) then
call elsi_get_nnz(bh%def0,bh%n_lrow,bh%n_lcol,ham,bh%nnz_l)
bh%nnz_l = count(abs(ham) > bh%def0)
end if
call MPI_Allreduce(bh%nnz_l,bh%nnz_g,1,MPI_INTEGER4,MPI_SUM,bh%comm,ierr)
......
......@@ -49,41 +49,20 @@ subroutine elsi_mu_and_occ(ph,bh,n_electron,n_state,n_spin,n_kpt,k_wt,eval,occ,&
real(kind=r8), intent(out) :: occ(n_state,n_spin,n_kpt)
real(kind=r8), intent(out) :: mu
real(kind=r8) :: e_min
real(kind=r8) :: e_max
real(kind=r8) :: mu_min
real(kind=r8) :: mu_max
real(kind=r8) :: buf
real(kind=r8) :: diff_min ! Error on lower bound
real(kind=r8) :: diff_max ! Error on upper bound
integer(kind=i4) :: i_state
integer(kind=i4) :: i_kpt
integer(kind=i4) :: i_spin
integer(kind=i4) :: i_step
character(len=200) :: msg
character(len=*), parameter :: caller = "elsi_mu_and_occ"
! Determine smallest and largest eivenvalues
e_min = eval(1,1,1)
e_max = eval(n_state,1,1)
do i_kpt = 1,n_kpt
do i_spin = 1,n_spin
do i_state = 1,n_state
if(eval(i_state,i_spin,i_kpt) < e_min) then
e_min = eval(i_state,i_spin,i_kpt)
end if
if(eval(i_state,i_spin,i_kpt) > e_max) then
e_max = eval(i_state,i_spin,i_kpt)
end if
end do
end do
end do
! Determine upper and lower bounds of mu
mu_min = e_min
mu_max = e_max
mu_min = minval(eval)
mu_max = maxval(eval)
buf = 0.5_r8*abs(mu_max-mu_min)
if(mu_max - mu_min < ph%mu_tol) then
mu_min = mu_min-1.0_r8
......@@ -109,8 +88,8 @@ subroutine elsi_mu_and_occ(ph,bh,n_electron,n_state,n_spin,n_kpt,k_wt,eval,occ,&
call elsi_stop(bh,msg,caller)
end if
mu_min = mu_min-0.5_r8*abs(e_max-e_min)
mu_max = mu_max+0.5_r8*abs(e_max-e_min)
mu_min = mu_min-buf
mu_max = mu_max+buf
call elsi_check_electrons(ph,n_electron,n_state,n_spin,n_kpt,k_wt,eval,&
occ,mu_min,diff_min)
......
......@@ -15,7 +15,6 @@ module ELSI_OMM
use ELSI_MPI, only: elsi_check_mpi,MPI_SUM,MPI_INTEGER4
use ELSI_OUTPUT, only: elsi_say,elsi_get_time
use ELSI_PRECISION, only: r8,i4
use ELSI_UTIL, only: elsi_get_nnz
use MATRIXSWITCH, only: matrix,m_register_pdbc,ms_scalapack_setup,&
m_deallocate
......@@ -111,7 +110,7 @@ subroutine elsi_solve_omm_real(ph,bh,ham,ovlp,coeff,dm)
! Compute sparsity
if(bh%nnz_g == UNSET) then
if(bh%nnz_l == UNSET) then
call elsi_get_nnz(bh%def0,bh%n_lrow,bh%n_lcol,ham,bh%nnz_l)
bh%nnz_l = count(abs(ham) > bh%def0)
end if
call MPI_Allreduce(bh%nnz_l,bh%nnz_g,1,MPI_INTEGER4,MPI_SUM,bh%comm,ierr)
......@@ -269,7 +268,7 @@ subroutine elsi_solve_omm_cmplx(ph,bh,ham,ovlp,coeff,dm)
! Compute sparsity
if(bh%nnz_g == UNSET) then
if(bh%nnz_l == UNSET) then
call elsi_get_nnz(bh%def0,bh%n_lrow,bh%n_lcol,ham,bh%nnz_l)
bh%nnz_l = count(abs(ham) > bh%def0)
end if
call MPI_Allreduce(bh%nnz_l,bh%nnz_g,1,MPI_INTEGER4,MPI_SUM,bh%comm,ierr)
......
......@@ -21,7 +21,7 @@ module ELSI_REDIST
use ELSI_OUTPUT, only: elsi_say,elsi_get_time
use ELSI_PRECISION, only: r8,i4,i8
use ELSI_SORT, only: elsi_heapsort,elsi_permute,elsi_unpermute
use ELSI_UTIL, only: elsi_get_nnz,elsi_get_gid,elsi_get_lid
use ELSI_UTIL, only: elsi_get_gid,elsi_get_lid
implicit none
......@@ -1682,8 +1682,7 @@ subroutine elsi_sips_to_blacs_ev_real(ph,bh,evec_sips,evec)
call elsi_get_time(t0)
n_lrow_aux = ph%n_basis/bh%n_procs
call elsi_get_nnz(bh%def0,bh%n_lcol_sp,ph%n_states,evec_sips,nnz_l_before)
nnz_l_before = count(abs(evec_sips) > bh%def0)
call elsi_allocate(bh,dest,nnz_l_before,"dest",caller)
call elsi_allocate(bh,perm,nnz_l_before,"perm",caller)
......@@ -1846,7 +1845,7 @@ subroutine elsi_blacs_to_sips_dm_real(ph,bh,dm_den,dm_sp,row_ind,col_ptr)
call elsi_get_time(t0)
call elsi_get_nnz(bh%def0,bh%n_lrow,bh%n_lcol,dm_den,bh%nnz_l)
bh%nnz_l = count(abs(dm_den) > bh%def0)
call elsi_allocate(bh,val_send,bh%nnz_l,"val_send",caller)
call elsi_allocate(bh,row_send,bh%nnz_l,"row_send",caller)
......@@ -2002,7 +2001,7 @@ subroutine elsi_blacs_to_sips_dm_cmplx(ph,bh,dm_den,dm_sp,row_ind,col_ptr)
call elsi_get_time(t0)
call elsi_get_nnz(bh%def0,bh%n_lrow,bh%n_lcol,dm_den,bh%nnz_l)
bh%nnz_l = count(abs(dm_den) > bh%def0)
call elsi_allocate(bh,val_send,bh%nnz_l,"val_send",caller)
call elsi_allocate(bh,row_send,bh%nnz_l,"row_send",caller)
......@@ -2582,7 +2581,7 @@ subroutine elsi_blacs_to_siesta_dm_real(bh,dm_den,dm_sp,row_ind,col_ptr)
call elsi_get_time(t0)
call elsi_get_nnz(bh%def0,bh%n_lrow,bh%n_lcol,dm_den,bh%nnz_l)
bh%nnz_l = count(abs(dm_den) > bh%def0)
call elsi_allocate(bh,dest,bh%nnz_l,"dest",caller)
call elsi_allocate(bh,perm,bh%nnz_l,"perm",caller)
......@@ -2748,7 +2747,7 @@ subroutine elsi_blacs_to_siesta_dm_cmplx(bh,dm_den,dm_sp,row_ind,col_ptr)
call elsi_get_time(t0)
call elsi_get_nnz(bh%def0,bh%n_lrow,bh%n_lcol,dm_den,bh%nnz_l)
bh%nnz_l = count(abs(dm_den) > bh%def0)
call elsi_allocate(bh,dest,bh%nnz_l,"dest",caller)
call elsi_allocate(bh,perm,bh%nnz_l,"perm",caller)
......@@ -7662,7 +7661,7 @@ subroutine elsi_blacs_to_eigenexa_h_real(ph,bh,ham_den,ham_exa)
call elsi_get_time(t0)
call elsi_get_nnz(bh%def0,bh%n_lrow,bh%n_lcol,ham_den,nnz_l_before)
nnz_l_before = count(abs(ham_den) > bh%def0)
call elsi_allocate(bh,dest,nnz_l_before,"dest",caller)
call elsi_allocate(bh,perm,nnz_l_before,"perm",caller)
......
......@@ -20,7 +20,7 @@ module ELSI_RW
use ELSI_PRECISION, only: r8,i4,i8
use ELSI_REDIST, only: elsi_blacs_to_mask,elsi_blacs_to_sips_hs_dim,&
elsi_blacs_to_sips_hs,elsi_sips_to_blacs_dm
use ELSI_UTIL, only: elsi_get_nnz,elsi_reset_basic,elsi_check_init
use ELSI_UTIL, only: elsi_reset_basic,elsi_check_init
implicit none
......@@ -1667,7 +1667,7 @@ subroutine elsi_write_mat_sp_real(rwh,f_name,mat)
character(len=*), parameter :: caller = "elsi_write_mat_sp_real"
! Compute nnz
call elsi_get_nnz(rwh%bh%def0,rwh%bh%n_lrow,rwh%bh%n_lcol,mat,nnz_g)
nnz_g = count(abs(mat) > rwh%bh%def0)
! Convert to CSC
call elsi_allocate(rwh%bh,col_ptr,rwh%n_basis+1,"col_ptr",caller)
......@@ -1765,7 +1765,7 @@ subroutine elsi_write_mat_sp_cmplx(rwh,f_name,mat)
character(len=*), parameter :: caller = "elsi_write_mat_sp_cmplx"
! Compute nnz
call elsi_get_nnz(rwh%bh%def0,rwh%bh%n_lrow,rwh%bh%n_lcol,mat,nnz_g)
nnz_g = count(abs(mat) > rwh%bh%def0)
! Convert to CSC
call elsi_allocate(rwh%bh,col_ptr,rwh%n_basis+1,"col_ptr",caller)
......
......@@ -29,18 +29,12 @@ module ELSI_UTIL
public :: elsi_reset_basic
public :: elsi_get_gid
public :: elsi_get_lid
public :: elsi_get_nnz
public :: elsi_reduce_energy
public :: elsi_set_full_mat
public :: elsi_build_dm
public :: elsi_build_edm
public :: elsi_gram_schmidt
interface elsi_get_nnz
module procedure elsi_get_nnz_real
module procedure elsi_get_nnz_cmplx
end interface
interface elsi_set_full_mat
module procedure elsi_set_full_mat_real
module procedure elsi_set_full_mat_cmplx
......@@ -576,66 +570,6 @@ subroutine elsi_get_lid(n_procs,blk,gid,lid)
end subroutine
!>
!! Count the number of nonzero elements in a matrix.
!!
subroutine elsi_get_nnz_real(def0,n_row,n_col,mat,nnz)
implicit none
real(kind=r8), intent(in) :: def0
integer(kind=i4), intent(in) :: n_row
integer(kind=i4), intent(in) :: n_col
real(kind=r8), intent(in) :: mat(n_row,n_col)
integer(kind=i4), intent(out) :: nnz
integer(kind=i4) :: i_row
integer(kind=i4) :: i_col
character(len=*), parameter :: caller = "elsi_get_nnz_real"
nnz = 0
do i_col = 1,n_col
do i_row = 1,n_row
if(abs(mat(i_row,i_col)) > def0) then
nnz = nnz+1
end if
end do
end do
end subroutine
!>
!! Count the number of nonzero elements in a matrix.
!!
subroutine elsi_get_nnz_cmplx(def0,n_row,n_col,mat,nnz)
implicit none
real(kind=r8), intent(in) :: def0
integer(kind=i4), intent(in) :: n_row
integer(kind=i4), intent(in) :: n_col
complex(kind=r8), intent(in) :: mat(n_row,n_col)
integer(kind=i4), intent(out) :: nnz
integer(kind=i4) :: i_row
integer(kind=i4) :: i_col
character(len=*), parameter :: caller = "elsi_get_nnz_cmplx"
nnz = 0
do i_col = 1,n_col
do i_row = 1,n_row
if(abs(mat(i_row,i_col)) > def0) then
nnz = nnz+1
end if
end do
end do
end subroutine
!>
!! Reduce energy over spin channels and k-points.
!!
......@@ -719,6 +653,7 @@ subroutine elsi_set_full_mat_real(ph,bh,uplo,mat)
end do
end if
! Allocate slightly more to work around crashes in some PBLAS
call elsi_allocate(bh,tmp,bh%n_lrow,bh%n_lcol+2*bh%blk,"tmp",caller)
call pdtran(ph%n_basis,ph%n_basis,1.0_r8,mat,1,1,bh%desc,0.0_r8,tmp,1,1,&
......@@ -781,6 +716,7 @@ subroutine elsi_set_full_mat_cmplx(ph,bh,uplo,mat)
end do
end if
! Allocate slightly more to work around crashes in some PBLAS
call elsi_allocate(bh,tmp,bh%n_lrow,bh%n_lcol+2*bh%blk,"tmp",caller)
call pztranc(ph%n_basis,ph%n_basis,(1.0_r8,0.0_r8),mat,1,1,bh%desc,&
......@@ -823,59 +759,38 @@ subroutine elsi_build_dm_real(ph,bh,occ,evec,dm)
integer(kind=i4) :: i
integer(kind=i4) :: gid
integer(kind=i4) :: max_state
logical :: use_gemm
character(len=200) :: msg
real(kind=r8), allocatable :: factor(:)
real(kind=r8), allocatable :: tmp(:,:)
character(len=*), parameter :: caller = "elsi_build_dm_real"
call elsi_get_time(t0)
call elsi_allocate(bh,factor,ph%n_states_solve,"factor",caller)
call elsi_allocate(bh,tmp,bh%n_lrow,bh%n_lcol,"tmp",caller)
tmp(:,:) = evec
dm(:,:) = 0.0_r8
max_state = 0
use_gemm = .false.
do i = 1,ph%n_states_solve
if(occ(i) > 0.0_r8) then
factor(i) = sqrt(occ(i))
max_state = i
else if(occ(i) < 0.0_r8) then
use_gemm = .true.
exit
end if
end do
! Compute density matrix
if(use_gemm) then
if(any(occ(1:ph%n_states_solve) < 0.0_r8)) then
do i = 1,bh%n_lcol
call elsi_get_gid(bh%my_pcol,bh%n_pcol,bh%blk,i,gid)
if(gid <= ph%n_states_solve) then
tmp(:,i) = tmp(:,i)*occ(gid)
else
tmp(:,i) = 0.0_r8
tmp(:,i) = evec(:,i)*occ(gid)
end if
end do
call pdgemm("N","T",ph%n_basis,ph%n_basis,ph%n_states_solve,1.0_r8,tmp,1,&
1,bh%desc,evec,1,1,bh%desc,0.0_r8,dm,1,1,bh%desc)
else
max_state = count(occ(1:ph%n_states_solve) > 0.0_r8)
do i = 1,bh%n_lcol
call elsi_get_gid(bh%my_pcol,bh%n_pcol,bh%blk,i,gid)
if(gid <= ph%n_states_solve) then
if(factor(gid) > 0.0_r8) then
tmp(:,i) = tmp(:,i)*factor(gid)
else
tmp(:,i) = 0.0_r8
end if
if(gid <= max_state) then
tmp(:,i) = evec(:,i)*sqrt(occ(gid))
end if
end do
......@@ -885,7 +800,6 @@ subroutine elsi_build_dm_real(ph,bh,occ,evec,dm)
call elsi_set_full_mat(ph,bh,UT_MAT,dm)
end if
call elsi_deallocate(bh,factor,"factor")
call elsi_deallocate(bh,tmp,"tmp")
call elsi_get_time(t1)
......@@ -915,44 +829,25 @@ subroutine elsi_build_dm_cmplx(ph,bh,occ,evec,dm)
integer(kind=i4) :: i
integer(kind=i4) :: gid
integer(kind=i4) :: max_state
logical :: use_gemm
character(len=200) :: msg
real(kind=r8), allocatable :: factor(:)
complex(kind=r8), allocatable :: tmp(:,:)
character(len=*), parameter :: caller = "elsi_build_dm_cmplx"
call elsi_get_time(t0)
call elsi_allocate(bh,factor,ph%n_states_solve,"factor",caller)
call elsi_allocate(bh,tmp,bh%n_lrow,bh%n_lcol,"tmp",caller)
tmp(:,:) = evec
dm(:,:) = (0.0_r8,0.0_r8)
max_state = 0
use_gemm = .false.
do i = 1,ph%n_states_solve
if(occ(i) > 0.0_r8) then
factor(i) = sqrt(occ(i))
max_state = i
else if(occ(i) < 0.0_r8) then
use_gemm = .true.
exit
end if
end do
! Compute density matrix
if(use_gemm) then
if(any(occ(1:ph%n_states_solve) < 0.0_r8)) then
do i = 1,bh%n_lcol
call elsi_get_gid(bh%my_pcol,bh%n_pcol,bh%blk,i,gid)
if(gid <= ph%n_states_solve) then
tmp(:,i) = tmp(:,i)*occ(gid)
else
tmp(:,i) = (0.0_r8,0.0_r8)
tmp(:,i) = evec(:,i)*occ(gid)
end if
end do
......@@ -960,15 +855,13 @@ subroutine elsi_build_dm_cmplx(ph,bh,occ,evec,dm)
(1.0_r8,0.0_r8),tmp,1,1,bh%desc,evec,1,1,bh%desc,(0.0_r8,0.0_r8),dm,&
1,1,bh%desc)
else
max_state = count(occ(1:ph%n_states_solve) > 0.0_r8)
do i = 1,bh%n_lcol
call elsi_get_gid(bh%my_pcol,bh%n_pcol,bh%blk,i,gid)
if(gid <= ph%n_states_solve) then
if(factor(gid) > 0.0_r8) then
tmp(:,i) = tmp(:,i)*factor(gid)
else
tmp(:,i) = (0.0_r8,0.0_r8)
end if
if(gid <= max_state) then
tmp(:,i) = evec(:,i)*sqrt(occ(gid))
end if
end do
......@@ -978,7 +871,6 @@ subroutine elsi_build_dm_cmplx(ph,bh,occ,evec,dm)
call elsi_set_full_mat(ph,bh,UT_MAT,dm)
end if
call elsi_deallocate(bh,factor,"factor")
call elsi_deallocate(bh,tmp,"tmp")
call elsi_get_time(t1)
......@@ -1010,7 +902,6 @@ subroutine elsi_build_edm_real(ph,bh,occ,eval,evec,edm)
integer(kind=i4) :: i
integer(kind=i4) :: gid
integer(kind=i4) :: max_state
logical :: use_gemm
character(len=200) :: msg
real(kind=r8), allocatable :: factor(:)
......@@ -1023,47 +914,30 @@ subroutine elsi_build_edm_real(ph,bh,occ,eval,evec,edm)
call elsi_allocate(bh,factor,ph%n_states_solve,"factor",caller)
call elsi_allocate(bh,tmp,bh%n_lrow,bh%n_lcol,"tmp",caller)
max_state = 0
tmp(:,:) = evec
edm(:,:) = 0.0_r8
use_gemm = .false.
do i = 1,ph%n_states_solve
factor(i) = -occ(i)*eval(i)
if(factor(i) > 0.0_r8) then
max_state = i
else if(factor(i) < 0.0_r8) then
use_gemm = .true.
end if
end do
factor(:) = -occ(1:ph%n_states_solve)*eval(1:ph%n_states_solve)
! Compute density matrix
if(use_gemm) then
if(any(factor < 0.0_r8)) then
do i = 1,bh%n_lcol
call elsi_get_gid(bh%my_pcol,bh%n_pcol,bh%blk,i,gid)
if(gid <= ph%n_states_solve) then
tmp(:,i) = tmp(:,i)*factor(gid)
else
tmp(:,i) = 0.0_r8
tmp(:,i) = evec(:,i)*factor(gid)
end if
end do
call pdgemm("N","T",ph%n_basis,ph%n_basis,ph%n_states_solve,-1.0_r8,tmp,&
1,1,bh%desc,evec,1,1,bh%desc,0.0_r8,edm,1,1,bh%desc)
else
factor = sqrt(factor)
max_state = count(factor > 0.0_r8)
do i = 1,bh%n_lcol
call elsi_get_gid(bh%my_pcol,bh%n_pcol,bh%blk,i,gid)
if(gid <= ph%n_states_solve) then
if(factor(gid) > 0.0_r8) then
tmp(:,i) = tmp(:,i)*factor(gid)
else
tmp(:,i) = 0.0_r8
end if
if(gid <= max_state) then
tmp(:,i) = evec(:,i)*sqrt(factor(gid))
end if
end do
......@@ -1105,7 +979,6 @@ subroutine elsi_build_edm_cmplx(ph,bh,occ,eval,evec,edm)
integer(kind=i4) :: i
integer(kind=i4) :: gid
integer(kind=i4) :: max_state
logical :: use_gemm
character(len=200) :: msg
real(kind=r8), allocatable :: factor(:)
......@@ -1118,30 +991,17 @@ subroutine elsi_build_edm_cmplx(ph,bh,occ,eval,evec,edm)
call elsi_allocate(bh,factor,ph%n_states_solve,"factor",caller)
call elsi_allocate(bh,tmp,bh%n_lrow,bh%n_lcol,"tmp",caller)
max_state = 0
tmp(:,:) = evec
edm(:,:) = (0.0_r8,0.0_r8)
use_gemm = .false.
do i = 1,ph%n_states_solve
factor(i) = -occ(i)*eval(i)
if(factor(i) > 0.0_r8) then