Commit 11316236 authored by Victor Yu

Use ELPA routines to build density matrix

Density matrix P = C * C^T. The ELPA matrix multiplication routine is GPU
accelerated, so it is faster than pdsyrk/pzherk if GPUs are available.
parent 569f2f9e
......@@ -7,7 +7,7 @@ SET(elsi_URL "http://elsi-interchange.org")
# Project metadata.
SET(elsi_EMAIL "elsi-team@duke.edu")
SET(elsi_LICENSE "BSD 3")
SET(elsi_DESCRIPTION "Electronic Structure Infrastructure")
# Date stamp of this revision (value appears to be YYYYMMDD). Only the most
# recent value is kept; an earlier assignment would be overwritten anyway.
SET(elsi_DATESTAMP "20200612")

### CMake modules ###
# Let include()/find_package() locate modules shipped in the project's cmake/
# directory. The path is quoted so it always stays a single list element.
LIST(APPEND CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/cmake")
......
This diff is collapsed.
......@@ -236,6 +236,7 @@ static const elpa_index_int_entry_t int_entries[] = {
BOOL_ENTRY("measure_performance", "Also measure with flops (via papi) with the timings", 0, ELPA_AUTOTUNE_NOT_TUNABLE, 0, PRINT_YES),
BOOL_ENTRY("check_pd", "Check eigenvalues to be positive", 0, ELPA_AUTOTUNE_NOT_TUNABLE, 0, PRINT_YES),
BOOL_ENTRY("cannon_for_generalized", "Whether to use Cannons algorithm for the generalized EVP", 1, ELPA_AUTOTUNE_NOT_TUNABLE, 0, PRINT_YES),
BOOL_ENTRY("multiply_at_a", "Switch to A^T * A in the matrix multiplication routine", 0, ELPA_AUTOTUNE_NOT_TUNABLE, 0, PRINT_NO),
};
#define READONLY_DOUBLE_ENTRY(option_name, option_description) \
......@@ -728,118 +729,118 @@ static int skewsymmetric_is_valid(elpa_index_t index, int n, int new_value) {
}
/* Number of values this option enumerates; always 10. */
static int band_to_full_cardinality(elpa_index_t index) {
    const int n_candidates = 10;

    (void) index;  /* not consulted */
    return n_candidates;
}
/* Map the zero-based enumeration index i to its value, which is i + 1. */
static int band_to_full_enumerate(elpa_index_t index, int i) {
    const int value = i + 1;

    (void) index;  /* not consulted */
    return value;
}
// TODO shouldn't it be only for ELPA2??
/*
 * Accept new_value only when it lies in the closed range [1, max_block].
 * Returns 1 if valid, 0 otherwise. `index` and `n` are not consulted.
 */
static int band_to_full_is_valid(elpa_index_t index, int n, int new_value) {
    /* Declared exactly once: a second declaration in this scope is a
     * redeclaration error. */
    const int max_block = 10;

    (void) index;
    (void) n;
    return (1 <= new_value) && (new_value <= max_block);
}
/* Number of values this option enumerates; always 17. */
static int stripewidth_real_cardinality(elpa_index_t index) {
    const int n_candidates = 17;

    (void) index;  /* not consulted */
    return n_candidates;
}
/* Number of values this option enumerates; always 17. */
static int stripewidth_complex_cardinality(elpa_index_t index) {
    const int n_candidates = 17;

    (void) index;  /* not consulted */
    return n_candidates;
}
/*
 * Map the zero-based enumeration index i to the i-th candidate value
 * (32, 36, ..., 96; 17 values in steps of 4).
 *
 * Returns the candidate for 0 <= i <= 16. For any other i it returns 0
 * instead of running off the end of the function without a return value;
 * 0 lies outside the 32..96 range accepted by stripewidth_real_is_valid.
 * `index` is not consulted.
 */
static int stripewidth_real_enumerate(elpa_index_t index, int i) {
    static const int widths[] = {
        32, 36, 40, 44, 48, 52, 56, 60, 64,
        68, 72, 76, 80, 84, 88, 92, 96
    };
    const int n_widths = (int) (sizeof(widths) / sizeof(widths[0]));

    (void) index;
    if (i < 0 || i >= n_widths) {
        /* Out-of-range index: defined fallback rather than undefined behavior. */
        return 0;
    }
    return widths[i];
}
/*
 * Map the zero-based enumeration index i to the i-th candidate value
 * (48, 56, ..., 176; 17 values in steps of 8).
 *
 * Returns the candidate for 0 <= i <= 16. For any other i it returns 0
 * instead of running off the end of the function without a return value;
 * 0 lies outside the 48..176 range accepted by stripewidth_complex_is_valid.
 * `index` is not consulted.
 */
static int stripewidth_complex_enumerate(elpa_index_t index, int i) {
    static const int widths[] = {
        48,  56,  64,  72,  80,  88,  96, 104, 112,
        120, 128, 136, 144, 152, 160, 168, 176
    };
    const int n_widths = (int) (sizeof(widths) / sizeof(widths[0]));

    (void) index;
    if (i < 0 || i >= n_widths) {
        /* Out-of-range index: defined fallback rather than undefined behavior. */
        return 0;
    }
    return widths[i];
}
/*
 * Return 1 when new_value lies in the closed range [32, 96], else 0.
 * `index` and `n` are not consulted.
 */
static int stripewidth_real_is_valid(elpa_index_t index, int n, int new_value) {
    (void) index;
    (void) n;
    if (new_value < 32) {
        return 0;
    }
    if (new_value > 96) {
        return 0;
    }
    return 1;
}
/*
 * Return 1 when new_value lies in the closed range [48, 176], else 0.
 * `index` and `n` are not consulted.
 */
static int stripewidth_complex_is_valid(elpa_index_t index, int n, int new_value) {
    (void) index;
    (void) n;
    if (new_value < 48) {
        return 0;
    }
    if (new_value > 176) {
        return 0;
    }
    return 1;
}
/*
 * Number of thread-count values this option enumerates.
 *
 * Sets the file-level variables max_threads_glob and set_max_threads_glob
 * (both declared outside this function) to 1, then returns max_threads_glob,
 * so the result is 1. `index` is not consulted.
 * NOTE(review): the hard-coded 1 presumably reflects a build without OpenMP
 * support; confirm against the build configuration.
 */
static int omp_threads_cardinality(elpa_index_t index) {
    /* Declared exactly once: a second declaration in this scope is a
     * redeclaration error. */
    int max_threads;

    (void) index;
    max_threads_glob = 1;
    set_max_threads_glob = 1;
    max_threads = max_threads_glob;
    return max_threads;
}
static int omp_threads_enumerate(elpa_index_t index, int i) {
......@@ -848,9 +849,9 @@ static int omp_threads_enumerate(elpa_index_t index, int i) {
/*
 * Return 1 when new_value lies in [1, max_threads_glob], else 0.
 *
 * As a side effect, sets the file-level variables max_threads_glob and
 * set_max_threads_glob (declared outside this function) to 1 before the
 * comparison. `index` and `n` are not consulted.
 */
static int omp_threads_is_valid(elpa_index_t index, int n, int new_value) {
    int upper_bound;

    (void) index;
    (void) n;
    max_threads_glob = 1;
    set_max_threads_glob = 1;
    upper_bound = max_threads_glob;

    if (new_value < 1) {
        return 0;
    }
    return new_value <= upper_bound;
}
......@@ -887,28 +888,28 @@ static int valid_with_gpu_elpa2(elpa_index_t index, int n, int new_value) {
}
/* Number of values this option enumerates; always 8. */
static int max_stored_rows_cardinality(elpa_index_t index) {
    const int n_candidates = 8;

    (void) index;  /* not consulted */
    return n_candidates;
}
/*
 * Map the zero-based enumeration index i to the i-th candidate value
 * (15, 31, ..., 127; 8 values in steps of 16).
 *
 * Returns the candidate for 0 <= i <= 7. For any other i it returns 0
 * instead of running off the end of the function without a return value.
 * NOTE(review): 0 is not one of the enumerated values; confirm that the
 * matching validator and callers treat it as invalid.
 * `index` is not consulted.
 */
static int max_stored_rows_enumerate(elpa_index_t index, int i) {
    static const int rows[] = { 15, 31, 47, 63, 79, 95, 111, 127 };
    const int n_rows = (int) (sizeof(rows) / sizeof(rows[0]));

    (void) index;
    if (i < 0 || i >= n_rows) {
        /* Out-of-range index: defined fallback rather than undefined behavior. */
        return 0;
    }
    return rows[i];
}
static int max_stored_rows_is_valid(elpa_index_t index, int n, int new_value) {
......
......@@ -773,11 +773,14 @@ subroutine elsi_build_dm_edm_real(ph,bh,factor,evec,dm,which)
integer(kind=i4), intent(in) :: which
real(kind=r8) :: alpha
real(kind=r8) :: dummy(1,1)
real(kind=r8) :: t0
real(kind=r8) :: t1
integer(kind=i4) :: i
integer(kind=i4) :: gid
integer(kind=i4) :: max_state
integer(kind=i4) :: ierr
logical :: use_elpa_mult
character(len=200) :: msg
real(kind=r8), allocatable :: tmp(:,:)
......@@ -819,8 +822,37 @@ subroutine elsi_build_dm_edm_real(ph,bh,factor,evec,dm,which)
end if
end do
call pdsyrk("U","N",ph%n_basis,max_state,alpha,tmp,1,1,bh%desc,0.0_r8,dm,&
1,1,bh%desc)
if(associated(ph%elpa_aux)) then
call ph%elpa_aux%set("multiply_at_a",1,ierr)
if(ierr /= 0 .or. ph%elpa_gpu == 0 .or. ph%solver /= ELPA_SOLVER&
.or. bh%blk*(max(bh%n_prow,bh%n_pcol)-1) >= max_state) then
use_elpa_mult = .false.
else
use_elpa_mult = .true.
end if
else
use_elpa_mult = .false.
end if
if(use_elpa_mult) then
call pdtran(ph%n_basis,ph%n_basis,1.0_r8,tmp,1,1,bh%desc,0.0_r8,dm,1,&
1,bh%desc)
call ph%elpa_aux%hermitian_multiply("N","U",max_state,dm,dummy,&
bh%n_lrow,bh%n_lcol,tmp,bh%n_lrow,bh%n_lcol,ierr)
call elsi_check_err(bh,"ELPA matrix multiplication",ierr,caller)
dm(:,:) = alpha*tmp
else
call pdsyrk("U","N",ph%n_basis,max_state,alpha,tmp,1,1,bh%desc,0.0_r8,&
dm,1,1,bh%desc)
end if
if(associated(ph%elpa_aux)) then
call ph%elpa_aux%set("multiply_at_a",0,ierr)
end if
call elsi_set_full_mat(ph,bh,UT_MAT,dm)
end if
......@@ -856,11 +888,14 @@ subroutine elsi_build_dm_edm_cmplx(ph,bh,factor,evec,dm,which)
integer(kind=i4), intent(in) :: which
complex(kind=r8) :: alpha
complex(kind=r8) :: dummy(1,1)
real(kind=r8) :: t0
real(kind=r8) :: t1
integer(kind=i4) :: i
integer(kind=i4) :: gid
integer(kind=i4) :: max_state
integer(kind=i4) :: ierr
logical :: use_elpa_mult
character(len=200) :: msg
complex(kind=r8), allocatable :: tmp(:,:)
......@@ -902,8 +937,37 @@ subroutine elsi_build_dm_edm_cmplx(ph,bh,factor,evec,dm,which)
end if
end do
call pzherk("U","N",ph%n_basis,max_state,alpha,tmp,1,1,bh%desc,&
(0.0_r8,0.0_r8),dm,1,1,bh%desc)
if(associated(ph%elpa_aux)) then
call ph%elpa_aux%set("multiply_at_a",1,ierr)
if(ierr /= 0 .or. ph%elpa_gpu == 0 .or. ph%solver /= ELPA_SOLVER&
.or. bh%blk*(max(bh%n_prow,bh%n_pcol)-1) >= max_state) then
use_elpa_mult = .false.
else
use_elpa_mult = .true.
end if
else
use_elpa_mult = .false.
end if
if(use_elpa_mult) then
call pztranc(ph%n_basis,ph%n_basis,(1.0_r8,0.0_r8),tmp,1,1,bh%desc,&
(0.0_r8,0.0_r8),dm,1,1,bh%desc)
call ph%elpa_aux%hermitian_multiply("N","U",max_state,dm,dummy,&
bh%n_lrow,bh%n_lcol,tmp,bh%n_lrow,bh%n_lcol,ierr)
call elsi_check_err(bh,"ELPA matrix multiplication",ierr,caller)
dm(:,:) = alpha*tmp
else
call pzherk("U","N",ph%n_basis,max_state,alpha,tmp,1,1,bh%desc,&
(0.0_r8,0.0_r8),dm,1,1,bh%desc)
end if
if(associated(ph%elpa_aux)) then
call ph%elpa_aux%set("multiply_at_a",0,ierr)
end if
call elsi_set_full_mat(ph,bh,UT_MAT,dm)
end if
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment