Commit 94fbaa2f authored by Nicolas Tancogne-Dejean's avatar Nicolas Tancogne-Dejean Committed by Micael Oliveira

Consolidation and bugfixes of the Blas interface:

- Adding ASSERTs to check for array dimensions
- Bugfix in the routine X(symv_2): the leading dimension of b was wrong
- Specifying lead_dim instead of explicit sizes in Blas 2 and 3 routines if missing
- Bugfix: gemm_2 and gemmt_2 routines was wrong
- The routine symm_2 is remove as it is not used and does not make really sense
parent dba63290
......@@ -62,7 +62,7 @@ subroutine X(mesh_batch_dotp_matrix)(mesh, aa, bb, dot, symm, reduce)
ldaa = size(aa%X(ff), dim = 1)
ldbb = size(bb%X(ff), dim = 1)
call lalg_gemmt(aa%nst, bb%nst, mesh%np, R_TOTYPE(mesh%volume_element), &
call lalg_gemmt(aa%nst, aa%dim, bb%nst, bb%dim, mesh%np, R_TOTYPE(mesh%volume_element), &
aa%X(ff), bb%X(ff), R_TOTYPE(M_ZERO), dd)
else
......
......@@ -145,8 +145,6 @@ module lalg_basic_oct_m
interface lalg_symm
module procedure symm_1_2
module procedure symm_1_4
module procedure symm_2_2
module procedure symm_2_4
end interface lalg_symm
!> Matrix-matrix multiplication.
......
......@@ -525,8 +525,10 @@ subroutine FNAME(symv_1)(n, alpha, a, x, beta, y)
! no push_sub, called too frequently
ASSERT(ubound(a, dim=1) >= n)
call profiling_in(symv_profile, 'BLAS_SYMV')
call blas_symv('U', n, alpha, a(1, 1), n, x(1), 1, beta, y(1), 1)
call blas_symv('U', n, alpha, a(1, 1), lead_dim(a), x(1), 1, beta, y(1), 1)
call profiling_out(symv_profile)
end subroutine FNAME(symv_1)
......@@ -540,8 +542,13 @@ subroutine FNAME(symv_2)(n1, n2, alpha, a, x, beta, y)
PUSH_SUB(FNAME(symv_2))
ASSERT(ubound(a, dim=1) == n1)
ASSERT(ubound(a, dim=2) == n2)
ASSERT(ubound(y, dim=1) == n1)
ASSERT(ubound(y, dim=2) >= n2)
call profiling_in(symv_profile, 'BLAS_SYMV')
call blas_symv('U', n1*n2, alpha, a(1, 1, 1), n1*2, x(1), 1, beta, y(1, 1), 1)
call blas_symv('U', n1*n2, alpha, a(1, 1, 1), n1*n2, x(1), 1, beta, y(1, 1), 1)
call profiling_out(symv_profile)
POP_SUB(FNAME(symv_2))
......@@ -556,8 +563,10 @@ subroutine FNAME(gemv_1)(m, n, alpha, a, x, beta, y)
PUSH_SUB(FNAME(gemv_1))
ASSERT(ubound(a, dim=1) >= m)
call profiling_in(gemv_profile, "BLAS_GEMV")
call blas_gemv('N', m, n, alpha, a(1,1), m, x(1), 1, beta, y(1), 1)
call blas_gemv('N', m, n, alpha, a(1,1), lead_dim(a), x(1), 1, beta, y(1), 1)
call profiling_out(gemv_profile)
POP_SUB(FNAME(gemv_1))
......@@ -572,6 +581,11 @@ subroutine FNAME(gemv_2)(m1, m2, n, alpha, a, x, beta, y)
PUSH_SUB(FNAME(gemv_2))
ASSERT(ubound(a, dim=1) == m1)
ASSERT(ubound(a, dim=2) == m2)
ASSERT(ubound(y, dim=1) == m1)
ASSERT(ubound(y, dim=2) >= m2)
call profiling_in(gemv_profile, "BLAS_GEMV")
call blas_gemv('N', m1*m2, n, alpha, a(1,1,1), m1*m2, x(1), 1, beta, y(1,1), 1)
call profiling_out(gemv_profile)
......@@ -597,21 +611,36 @@ subroutine FNAME(gemm_1)(m, n, k, alpha, a, b, beta, c)
! no PUSH SUB, called too often
ASSERT(ubound(a, dim=1) >= m)
ASSERT(ubound(a, dim=2) >= n)
ASSERT(ubound(b, dim=1) >= k)
ASSERT(ubound(c, dim=1) >= m)
ASSERT(ubound(c, dim=2) >= n)
call blas_gemm('N', 'N', m, n, k, alpha, a(1, 1), lead_dim(a), b(1, 1), lead_dim(b), beta, c(1, 1), lead_dim(c))
end subroutine FNAME(gemm_1)
subroutine FNAME(gemm_2)(m, n, k, alpha, a, b, beta, c)
integer, intent(in) :: m, n, k
subroutine FNAME(gemm_2)(m1, m2, n, k, alpha, a, b, beta, c)
integer, intent(in) :: m1, m2, n, k
TYPE1, intent(in) :: alpha, beta
TYPE1, intent(in) :: a(:, :, :) !< a(m, k)
TYPE1, intent(in) :: a(:, :, :) !< a(m1, m2, k)
TYPE1, intent(in) :: b(:, :) !< b(k, n)
TYPE1, intent(inout) :: c(:, :, :) !< c(m, n)
TYPE1, intent(inout) :: c(:, :, :) !< c(m1, m2, n)
PUSH_SUB(FNAME(gemm_2))
call blas_gemm('N', 'N', m, n, k, alpha, a(1, 1, 1), lead_dim(a), &
b(1, 1), lead_dim(b), beta, c(1, 1, 1), lead_dim(c))
ASSERT(ubound(a, dim=1) == m1)
ASSERT(ubound(a, dim=2) == m2)
ASSERT(ubound(a, dim=3) >= k)
ASSERT(ubound(b, dim=1) >= k)
ASSERT(ubound(c, dim=1) == m1)
ASSERT(ubound(c, dim=2) == m2)
ASSERT(ubound(c, dim=3) >= n)
call blas_gemm('N', 'N', m1*m2, n, k, alpha, a(1, 1, 1), m1*m2, &
b(1, 1), lead_dim(b), beta, c(1, 1, 1), m1*m2)
POP_SUB(FNAME(gemm_2))
end subroutine FNAME(gemm_2)
......@@ -626,20 +655,36 @@ subroutine FNAME(gemmt_1)(m, n, k, alpha, a, b, beta, c)
! no PUSH_SUB, called too often
ASSERT(ubound(a, dim=1) >= k)
ASSERT(ubound(a, dim=2) == m)
ASSERT(ubound(b, dim=1) >= k)
ASSERT(ubound(b, dim=2) >= n)
ASSERT(ubound(c, dim=1) >= m)
ASSERT(ubound(c, dim=2) >= n)
call blas_gemm('C', 'N', m, n, k, alpha, a(1, 1), lead_dim(a), b(1, 1), lead_dim(b), beta, c(1, 1), lead_dim(c))
end subroutine FNAME(gemmt_1)
subroutine FNAME(gemmt_2)(m, n, k, alpha, a, b, beta, c)
integer, intent(in) :: m, n, k
subroutine FNAME(gemmt_2)(m1, m2, n1, n2, k, alpha, a, b, beta, c)
integer, intent(in) :: m1, m2, n1, n2, k
TYPE1, intent(in) :: alpha, beta
TYPE1, intent(in) :: a(:, :, :) !< a((k), m)
TYPE1, intent(in) :: b(:, :, :) !< b((k), n)
TYPE1, intent(inout) :: c(:, :) !< c(m, n)
TYPE1, intent(in) :: a(:, :, :) !< a(k, m2, m1)
TYPE1, intent(in) :: b(:, :, :) !< b(k, n2, n1)
TYPE1, intent(inout) :: c(:, :) !< c(m1*m2, n1*n2)
PUSH_SUB(FNAME(gemmt_2))
call blas_gemm('C', 'N', m, n, k, alpha, a(1, 1, 1), lead_dim(a), &
ASSERT(ubound(a, dim=1) >= k)
ASSERT(ubound(a, dim=2) == m2)
ASSERT(ubound(a, dim=3) == m1)
ASSERT(ubound(b, dim=1) >= k)
ASSERT(ubound(b, dim=2) == n2)
ASSERT(ubound(b, dim=3) == n1)
ASSERT(ubound(c, dim=1) >= m1*m2)
ASSERT(ubound(c, dim=2) >= n1*n2)
call blas_gemm('C', 'N', m1*m2, n1*n2, k, alpha, a(1, 1, 1), lead_dim(a), &
b(1, 1, 1), lead_dim(b), beta, c(1, 1), lead_dim(c))
POP_SUB(FNAME(gemmt_2))
......@@ -651,46 +696,28 @@ subroutine FNAME(symm_1)(m, n, side, alpha, a, b, beta, c)
integer, intent(in) :: m, n
character(1), intent(in) :: side
TYPE1, intent(in) :: alpha, beta, a(:, :), b(:, :)
TYPE1, intent(inout) :: c(:, :)
TYPE1, intent(inout) :: c(:, :) !c(m, n)
integer :: lda
! no push_sub, called too frequently
!The size specified are for the matrix C
ASSERT(ubound(c, dim=1) >= m)
ASSERT(ubound(c, dim=2) >= n)
select case(side)
case('l', 'L')
lda = max(1, m)
case('r', 'R')
lda = max(1, n)
case('l', 'L') ! Here we compute C := alpha*A*B + beta*C
ASSERT(ubound(a, dim=1) >= m)
ASSERT(ubound(b, dim=1) >= n)
case('r', 'R') ! Here we compute C := alpha*B*A + beta*C
ASSERT(ubound(a, dim=1) >= n)
ASSERT(ubound(b, dim=1) >= m)
end select
call blas_symm(side, 'U', m, n, alpha, a(1, 1), lda, b(1, 1), m, beta, c(1, 1), m)
call blas_symm(side, 'U', m, n, alpha, a(1, 1), lead_dim(a), b(1, 1), lead_dim(b), beta, c(1, 1), lead_dim(c))
end subroutine FNAME(symm_1)
subroutine FNAME(symm_2)(m, n, side, alpha, a, b, beta, c)
integer, intent(in) :: m, n
character(1), intent(in) :: side
TYPE1, intent(in) :: alpha, beta, a(:, :, :), b(:, :)
TYPE1, intent(inout) :: c(:, :, :)
integer :: lda
PUSH_SUB(FNAME(symm_2))
select case(side)
case('l', 'L')
lda = max(1, m)
case('r', 'R')
lda = max(1, n)
end select
call blas_symm(side, 'U', m, n, alpha, a(1, 1, 1), lda, b(1, 1), m, beta, c(1, 1, 1), m)
POP_SUB(FNAME(symm_2))
end subroutine FNAME(symm_2)
!> ------------------------------------------------------------------
!! Matrix-matrix multiplication.
!! ------------------------------------------------------------------
......@@ -706,14 +733,19 @@ subroutine FNAME(trmm_1)(m, n, uplo, transa, side, alpha, a, b)
! no push_sub, called too frequently
ASSERT(ubound(b, dim=1) >= m)
ASSERT(ubound(b, dim=2) >= n)
select case(side)
case('L', 'l')
lda = max(1, m)
ASSERT(ubound(a, dim=1) >= m)
ASSERT(ubound(a, dim=2) >= m)
case('R', 'r')
lda = max(1, n)
ASSERT(ubound(a, dim=1) >= n)
ASSERT(ubound(a, dim=2) >= n)
end select
call blas_trmm(side, uplo, transa, 'N', m, n, alpha, a(1, 1), lda, b(1, 1), m)
call blas_trmm(side, uplo, transa, 'N', m, n, alpha, a(1, 1), lead_dim(a), b(1, 1), lead_dim(b))
end subroutine FNAME(trmm_1)
......
......@@ -137,7 +137,9 @@ subroutine X(states_elec_blockt_mul)(mesh, st, psi1_start, psi2_start, &
SAFE_ALLOCATE(res_local(1:xpsi1_count(rank), 1:sendcnt))
call profiling_in(C_PROFILING_BLOCKT_MM, 'BLOCKT_MM')
call lalg_gemmt(xpsi1_count(rank), sendcnt, mesh%np*st%d%dim, R_TOTYPE(mesh%vol_pp(1)), &
!Due to the definition of the gemmt routine, the dim is set to 1 and the number of
!grid points to np*dim. Otherwise the code won't work for spinors
call lalg_gemmt(xpsi1_count(rank), 1, sendcnt, 1, mesh%np*st%d%dim, R_TOTYPE(mesh%vol_pp(1)), &
psi1_block, sendbuf, R_TOTYPE(M_ZERO), res_local)
call profiling_out(C_PROFILING_BLOCKT_MM)
......@@ -331,7 +333,7 @@ subroutine X(states_elec_block_matr_mul_add)(mesh, st, alpha, psi_start, res_sta
matr_col_offset+1:matr_col_offset+xres_count(rank))
call profiling_out(C_PROFILING_BLOCK_MATR_CP)
call profiling_in(C_PROFILING_BLOCK_MATR_MM, 'BLOCK_MATR_MM')
call lalg_gemm(mesh%np * st%d%dim, xres_count(rank), sendcnt, alpha, &
call lalg_gemm(mesh%np, st%d%dim, xres_count(rank), sendcnt, alpha, &
sendbuf, matr_block, R_TOTYPE(M_ONE), res_block)
call profiling_out(C_PROFILING_BLOCK_MATR_MM)
end if
......@@ -366,7 +368,9 @@ subroutine X(states_elec_block_matr_mul_add)(mesh, st, alpha, psi_start, res_sta
! matr_block is needed because matr may be an assumed-shape array.
SAFE_ALLOCATE(matr_block(1:psi_col, 1:matr_col))
matr_block = matr
call lalg_gemm(mesh%np * st%d%dim, matr_col, psi_col, alpha, &
ASSERT(matr_col == res_col)
call lalg_gemm(mesh%np, st%d%dim, matr_col, psi_col, alpha, &
psi_block, matr_block, beta, res_block)
SAFE_DEALLOCATE_A(matr_block)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment