!
! Copyright (C) 2017 Mitsuaki Kawamura
! This file is distributed under the terms of the
! GNU General Public License. See the file `License'
! in the root directory of the present distribution,
! or http://www.gnu.org/copyleft/gpl.txt .
!
MODULE sctk_invert
  !
  ! In-place inversion of the matrices wscr(:,:,imf) of sctk_val,
  ! distributed by columns over the processes of world_comm.
  !
  IMPLICIT NONE
  !
  PRIVATE :: exchange_columns
  !
CONTAINS
!
! Inversion of matrices
!
! For every imf = 0, ..., nmf the ngv x ngv matrix wscr(:,:,imf) is replaced
! by its inverse.  Each process reads and writes only its own columns
! dsp+1 ... dsp+cnt (as returned by cnt_and_dsp).
!
! The algorithm is a Gauss-Jordan elimination acting on COLUMNS:
! column pivoting (swap), column scaling and column elimination.  The same
! column operations are applied to the identity (wscr2), which therefore ends
! up holding the inverse without any final un-permutation.
!
! Every process of world_comm must call this routine: it performs collective
! mp_sum reductions at every elimination step.  Assumes the column blocks of
! all processes together cover 1 ... ngv -- TODO confirm against cnt_and_dsp.
!
SUBROUTINE invert()
  !
  USE kinds, ONLY : DP
  USE mp_world, ONLY : nproc, mpime, world_comm
  USE mp, ONLY : mp_sum
  USE sctk_val, ONLY : ngv, nmf, wscr
  !
  USE sctk_cnt_dsp, ONLY : cnt_and_dsp
  !
  IMPLICIT NONE
  !
  ! A pivot whose modulus is below piv_tol is treated as zero (singular matrix)
  REAL(dp), PARAMETER :: piv_tol = 1.0e-12_dp
  !
  INTEGER :: cnt, dsp, imf, ig, jg, me, iproc, imaxpiv(nproc)
  REAL(dp) :: maxpiv(nproc)
  COMPLEX(dp) :: key1(ngv), key2(ngv), piv
  COMPLEX(dp),ALLOCATABLE :: wscr1(:,:), wscr2(:,:)
  !
  ! mpime is 0-based in QE's mp_world (0 ... nproc-1), while maxpiv/imaxpiv
  ! are indexed 1 ... nproc.
  !
  me = mpime + 1
  !
  CALL cnt_and_dsp(ngv, cnt, dsp)
  ALLOCATE(wscr1(ngv,dsp + 1:dsp + cnt), wscr2(ngv,dsp + 1:dsp + cnt))
  !
  DO imf = 0, nmf
     !
     ! wscr1 : working copy of the local columns of the matrix
     ! wscr2 : local columns of the identity, turned into the inverse
     !
     wscr1(1:ngv, dsp + 1:dsp + cnt) = &
     &    wscr(1:ngv, dsp + 1:dsp + cnt,imf)
     !
     wscr2(1:ngv,dsp + 1:dsp + cnt) = CMPLX(0.0_dp, 0.0_dp, KIND=dp)
     DO ig = dsp + 1, dsp + cnt
        wscr2(ig,ig) = CMPLX(1.0_dp, 0.0_dp, KIND=dp)
     END DO
     !
     DO ig = 1, ngv
        !
        ! Partial (column) pivoting: look for the largest |wscr1(ig, jg)|^2
        ! with jg >= ig over all processes.
        !
        jg = MAX(ig, dsp + 1)
        maxpiv( 1:nproc) = 0.0_dp
        imaxpiv(1:nproc) = 0
        IF(jg > dsp + cnt) THEN
           ! No candidate column on this process
           maxpiv( me) = - 1e10_dp
           imaxpiv(me) =   1
        ELSE
           maxpiv( me) = MAXVAL(REAL(CONJG(wscr1(ig, jg:dsp + cnt)) &
           &                            * wscr1(ig, jg:dsp + cnt), dp ))
           imaxpiv(me) = MAXLOC(REAL(CONJG(wscr1(ig, jg:dsp + cnt)) &
           &                            * wscr1(ig, jg:dsp + cnt), dp ), 1) &
           &           + jg - 1
        END IF
        !
        CALL mp_sum(maxpiv, world_comm)
        CALL mp_sum(imaxpiv, world_comm)
        !
        ! The reduced arrays are identical on every process, so the
        ! singularity test and the pivot choice agree everywhere.
        ! Test BEFORE dividing by the pivot.  imf + 1 is passed because
        ! errore does nothing for a non-positive error code.
        !
        iproc = MAXLOC(maxpiv, 1)
        IF(maxpiv(iproc) < piv_tol**2) &
        &  CALL errore ('invert', 'Singular matrix, imf + 1 = ', imf + 1)
        jg = imaxpiv(iproc)
        !
        ! Move the pivot column jg to position ig in both wscr1 and wscr2
        !
        CALL exchange_columns(ngv, dsp, cnt, ig, jg, wscr1)
        CALL exchange_columns(ngv, dsp, cnt, ig, jg, wscr2)
        !
        ! Ordinary Gauss-Jordan step: the owner normalises column ig
        ! (so that wscr1(ig,ig) = 1) and shares it with all processes.
        !
        IF(dsp + 1 <= ig .AND. ig <= dsp + cnt) THEN
           !
           piv = CMPLX(1.0_dp, 0.0_dp, KIND=dp) / wscr1(ig, ig)
           !
           CALL zscal(ngv, piv, wscr1(1:ngv, ig), 1)
           CALL zscal(ngv, piv, wscr2(1:ngv, ig), 1)
           CALL zcopy(ngv, wscr1(1:ngv, ig), 1, key1, 1)
           CALL zcopy(ngv, wscr2(1:ngv, ig), 1, key2, 1)
           !
        ELSE
           !
           key1(1:ngv) = CMPLX(0.0_dp, 0.0_dp, KIND=dp)
           key2(1:ngv) = CMPLX(0.0_dp, 0.0_dp, KIND=dp)
           !
        END IF ! IF(dsp + 1 <= ig .AND. ig <= dsp + cnt)
        !
        CALL mp_sum( key1, world_comm )
        !
        CALL mp_sum( key2, world_comm )
        !
        ! Zero row ig in every other local column:
        ! column(jg) -= wscr1(ig, jg) * column(ig)
        !
        !$OMP PARALLEL DEFAULT(NONE) &
        !$OMP & SHARED(dsp,cnt,ig,wscr1,wscr2,ngv,key1,key2) &
        !$OMP & PRIVATE(jg,piv)
        !
        !$OMP DO
        DO jg = dsp + 1, dsp + cnt
           !
           IF(jg == ig) CYCLE
           !
           piv = - wscr1(ig, jg)
           CALL zaxpy(ngv, piv, key1, 1, wscr1(1:ngv, jg), 1)
           CALL zaxpy(ngv, piv, key2, 1, wscr2(1:ngv, jg), 1)
           !
        END DO ! jg
        !$OMP END DO
        !$OMP END PARALLEL
        !
     END DO ! ig
     !
     wscr(1:ngv,dsp + 1:dsp + cnt, imf) = &
     &  wscr2(1:ngv,dsp + 1:dsp + cnt)
     !
  END DO ! imf
  !
  DEALLOCATE(wscr1,wscr2)
  !
END SUBROUTINE invert
!
! Swap global columns ig and jg of a column-distributed matrix.
!
! mat holds the local columns dsp+1 ... dsp+cnt.  The two columns are
! gathered with mp_sum (non-owners contribute zeros), and each owner then
! stores the other column.  Collective: every process of world_comm must
! call it with the same ig and jg.  When ig == jg nothing changes, so the
! collectives are skipped; all processes take the same branch.
!
SUBROUTINE exchange_columns(ngv, dsp, cnt, ig, jg, mat)
  !
  USE kinds, ONLY : DP
  USE mp_world, ONLY : world_comm
  USE mp, ONLY : mp_sum
  !
  IMPLICIT NONE
  !
  INTEGER, INTENT(IN) :: ngv, dsp, cnt, ig, jg
  COMPLEX(DP), INTENT(INOUT) :: mat(ngv, dsp + 1:dsp + cnt)
  !
  LOGICAL :: own_ig, own_jg
  COMPLEX(DP) :: col_ig(ngv), col_jg(ngv)
  !
  IF(ig == jg) RETURN
  !
  own_ig = dsp + 1 <= ig .AND. ig <= dsp + cnt
  own_jg = dsp + 1 <= jg .AND. jg <= dsp + cnt
  !
  col_ig(1:ngv) = CMPLX(0.0_dp, 0.0_dp, KIND=dp)
  col_jg(1:ngv) = CMPLX(0.0_dp, 0.0_dp, KIND=dp)
  IF(own_ig) col_ig(1:ngv) = mat(1:ngv, ig)
  IF(own_jg) col_jg(1:ngv) = mat(1:ngv, jg)
  !
  CALL mp_sum( col_ig, world_comm )
  CALL mp_sum( col_jg, world_comm )
  !
  IF(own_ig) mat(1:ngv, ig) = col_jg(1:ngv)
  IF(own_jg) mat(1:ngv, jg) = col_ig(1:ngv)
  !
END SUBROUTINE exchange_columns
!
END MODULE sctk_invert
