27 #include "kernels_trace.h"
29 #ifndef DOXYGEN_SHOULD_SKIP_THIS
/* Single-precision complex constant one; handed to cblas_cgemm as a scalar
 * argument through CBLAS_SADDR(cone) elsewhere in this file. */
30 static pastix_complex32_t cone = 1.0;
/* Single-precision complex constant zero; handed to cblas_cgemm as a scalar
 * argument through CBLAS_SADDR(czero) elsewhere in this file. */
31 static pastix_complex32_t czero = 0.0;
61 pastix_complex32_t *u, *v;
64 u = malloc( M * N *
sizeof(pastix_complex32_t) );
65 memset( u, 0, M * N *
sizeof(pastix_complex32_t) );
71 else if ( rkmax == 0 ) {
78 pastix_int_t rk = pastix_imin( M, N );
79 rkmax = pastix_imin( rkmax, rk );
81 #if defined(PASTIX_DEBUG_LR)
82 u = malloc( M * rkmax *
sizeof(pastix_complex32_t) );
83 v = malloc( N * rkmax *
sizeof(pastix_complex32_t) );
86 memset(u, 0, M * rkmax *
sizeof(pastix_complex32_t));
87 memset(v, 0, N * rkmax *
sizeof(pastix_complex32_t));
89 u = malloc( (M+N) * rkmax *
sizeof(pastix_complex32_t));
92 memset(u, 0, (M+N) * rkmax *
sizeof(pastix_complex32_t));
124 #if defined(PASTIX_DEBUG_LR)
176 pastix_int_t newrkmax,
177 pastix_int_t rklimit )
183 newrkmax = (newrkmax == -1) ? newrk : newrkmax;
184 newrkmax = pastix_imax( newrkmax, newrk );
189 if ( (newrk > rklimit) || (newrk == -1) )
191 A->u = realloc( A->u, M * N *
sizeof(pastix_complex32_t) );
192 #if defined(PASTIX_DEBUG_LR)
203 else if (newrkmax == 0)
209 #if defined(PASTIX_DEBUG_LR)
222 pastix_complex32_t *u, *v;
225 if ( ( A->rk == -1 ) ||
226 (( A->rk != -1 ) && (newrkmax != A->rkmax)) )
228 #if defined(PASTIX_DEBUG_LR)
229 u = malloc( M * newrkmax *
sizeof(pastix_complex32_t) );
230 v = malloc( N * newrkmax *
sizeof(pastix_complex32_t) );
232 u = malloc( (M+N) * newrkmax *
sizeof(pastix_complex32_t) );
233 v = u + M * newrkmax;
236 assert( A->rk != -1 );
237 ret = LAPACKE_clacpy_work( LAPACK_COL_MAJOR,
'A', M, newrk,
240 ret = LAPACKE_clacpy_work( LAPACK_COL_MAJOR,
'A', newrk, N,
241 A->v, A->rkmax, v, newrkmax );
245 #if defined(PASTIX_DEBUG_LR)
258 assert( A->rk <= A->rkmax);
302 pastix_complex32_t *A, pastix_int_t lda )
313 if (Alr == NULL || Alr->
rk > Alr->
rkmax) {
321 if ( Alr->
rk == -1 ) {
322 if (Alr->
u == NULL || Alr->
v != NULL || (Alr->
rkmax < m))
327 else if ( Alr->
rk != 0){
328 if (Alr->
u == NULL || Alr->
v == NULL) {
335 if ( Alr->
rk == -1 ) {
336 ret = LAPACKE_clacpy_work( LAPACK_COL_MAJOR,
'A', m, n,
337 Alr->
u, Alr->
rkmax, A, lda );
340 else if ( Alr->
rk == 0 ) {
341 ret = LAPACKE_claset_work( LAPACK_COL_MAJOR,
'A', m, n,
346 cblas_cgemm(CblasColMajor, CblasNoTrans, CblasNoTrans,
348 CBLAS_SADDR(cone), Alr->
u, m,
350 CBLAS_SADDR(czero), A, lda);
354 if ( Alr->
rk == -1 ) {
357 else if ( Alr->
rk == 0 ) {
358 ret = LAPACKE_claset_work( LAPACK_COL_MAJOR,
'A', n, m,
363 cblas_cgemm(CblasColMajor, CblasTrans, CblasTrans,
365 CBLAS_SADDR(cone), Alr->
v, Alr->
rkmax,
367 CBLAS_SADDR(czero), A, lda);
422 pastix_int_t offx, pastix_int_t offy )
424 pastix_complex32_t *u, *v;
425 pastix_int_t ldau, ldav;
428 assert( (M1 + offx) <= M2 );
429 assert( (N1 + offy) <= N2 );
431 ldau = (A->rk == -1) ? A->rkmax : M1;
442 if ( (M1 != M2) || (N1 != N2) ) {
443 LAPACKE_claset_work( LAPACK_COL_MAJOR,
'A', M2, N2,
448 0.0, u + M2 * offy + offx, M2 );
453 LAPACKE_claset_work( LAPACK_COL_MAJOR,
'A', M2, B->
rk,
456 ret = LAPACKE_clacpy_work( LAPACK_COL_MAJOR,
'A', M1, A->rk,
462 LAPACKE_claset_work( LAPACK_COL_MAJOR,
'A', B->
rk, N2,
463 0.0, 0.0, v, B->
rkmax );
473 pastix_complex32_t *work = malloc( M2 * N2 *
sizeof(pastix_complex32_t) );
525 pastix_complex32_t *u1u2 )
527 pastix_complex32_t *tmp;
528 pastix_int_t i, ret, rank;
529 pastix_int_t ldau, ldbu;
531 rank = (A->rk == -1) ? pastix_imin(M1, N1) : A->rk;
534 ldau = (A->rk == -1) ? A->rkmax : M1;
537 ret = LAPACKE_clacpy_work( LAPACK_COL_MAJOR,
'A', M2, B->
rk,
538 B->
u, ldbu, u1u2, M2 );
541 tmp = u1u2 + B->
rk * M2;
549 memset(tmp, 0, M2 * M1 *
sizeof(pastix_complex32_t));
553 for (i=0; i<M1; i++, tmp += M2+1) {
559 ret = LAPACKE_claset_work( LAPACK_COL_MAJOR,
'A', M2, M1,
569 memset(tmp, 0, M2 * N1 *
sizeof(pastix_complex32_t));
571 ret = LAPACKE_clacpy_work( LAPACK_COL_MAJOR,
'A', M1, N1,
572 A->u, ldau, tmp + offx, M2 );
581 memset(tmp, 0, M2 * A->rk *
sizeof(pastix_complex32_t));
583 ret = LAPACKE_clacpy_work( LAPACK_COL_MAJOR,
'A', M1, A->rk,
584 A->u, ldau, tmp + offx, M2 );
633 pastix_complex32_t *v1v2 )
635 pastix_complex32_t *tmp;
636 pastix_int_t i, ret, rank;
637 pastix_int_t ldau, ldav, ldbv;
639 rank = (A->rk == -1) ? pastix_imin(M1, N1) : A->rk;
642 ldau = (A->rk == -1) ? A->rkmax : M1;
646 ret = LAPACKE_clacpy_work( LAPACK_COL_MAJOR,
'A', B->
rk, N2,
647 B->
v, ldbv, v1v2, rank );
658 ret = LAPACKE_claset_work( LAPACK_COL_MAJOR,
'A', M1, N2,
659 0.0, 0.0, tmp, rank );
664 0.0, tmp + offy * rank, rank );
672 ret = LAPACKE_claset_work( LAPACK_COL_MAJOR,
'A', N1, N2,
673 0.0, 0.0, tmp, rank );
678 for (i=0; i<N1; i++, tmp += rank+1) {
684 ret = LAPACKE_claset_work( LAPACK_COL_MAJOR,
'A', N1, N2,
685 0.0, alpha, tmp + offy * rank, rank );
695 ret = LAPACKE_claset_work( LAPACK_COL_MAJOR,
'A', A->rk, N2,
696 0.0, 0.0, tmp, rank );
701 0.0, tmp + offy * rank, rank );
750 int use_reltol, pastix_fixdbl_t tol, pastix_int_t rklimit,
751 pastix_int_t m, pastix_int_t n,
752 const void *Avoid, pastix_int_t lda,
756 pastix_int_t nb = 32;
757 pastix_complex32_t *A = (pastix_complex32_t*)Avoid;
758 pastix_complex32_t *Acpy;
760 pastix_complex32_t *work, *tau, zzsize;
763 pastix_int_t zsize, rsize;
764 pastix_complex32_t *zwork;
765 pastix_fixdbl_t flops;
767 float norm = LAPACKE_clange_work( LAPACK_COL_MAJOR,
'f', m, n,
770 if ( (norm == 0.) && (tol >= 0.)) {
780 else if ( use_reltol ) {
784 ret = rrqrfct( tol, rklimit, 0, nb,
789 lwork = (pastix_int_t)zzsize;
798 #if defined(PASTIX_DEBUG_LR)
800 Acpy = malloc( m * n *
sizeof(pastix_complex32_t) );
801 tau = malloc( n *
sizeof(pastix_complex32_t) );
802 work = malloc( lwork *
sizeof(pastix_complex32_t) );
803 rwork = malloc( rsize *
sizeof(
float) );
805 zwork = malloc( zsize *
sizeof(pastix_complex32_t) + rsize *
sizeof(
float) );
809 rwork = (
float*)(work + lwork);
812 jpvt = malloc( n *
sizeof(pastix_int_t) );
817 ret = LAPACKE_clacpy_work( LAPACK_COL_MAJOR,
'A', m, n,
821 newrk = rrqrfct( tol, rklimit, 0, nb,
824 work, lwork, rwork );
826 flops = FLOPS_CGEQRF( m, n );
829 flops = FLOPS_CGEQRF( m, newrk ) + FLOPS_CUNMQR( m, n-newrk, newrk,
PastixLeft );
839 ret = LAPACKE_clacpy_work( LAPACK_COL_MAJOR,
'A', m, n,
840 A, lda, Alr->
u, Alr->
rkmax );
843 else if ( newrk > 0 ) {
848 pastix_complex32_t *U, *V;
854 ret = LAPACKE_clacpy_work( LAPACK_COL_MAJOR,
'A', m, Alr->
rk,
858 ret = LAPACKE_cungqr_work( LAPACK_COL_MAJOR, m, Alr->
rk, Alr->
rk,
859 U, m, tau, work, lwork );
861 flops += FLOPS_CUNGQR( m, Alr->
rk, Alr->
rk );
864 ret = LAPACKE_claset_work( LAPACK_COL_MAJOR,
'L', Alr->
rk-1, Alr->
rk-1,
865 0.0, 0.0, Acpy + 1, m );
868 memcpy( V + jpvt[i] * Alr->
rk,
870 Alr->
rk *
sizeof(pastix_complex32_t) );
874 #if defined(PASTIX_DEBUG_LR)
878 fprintf(stderr,
"Failed to compress a matrix and generate an orthogonal u\n" );
885 #if defined(PASTIX_DEBUG_LR)
938 int use_reltol, pastix_fixdbl_t tol, pastix_int_t rklimit,
939 pastix_int_t m, pastix_int_t n,
940 const void *Avoid, pastix_int_t lda,
944 pastix_int_t nb = 32;
945 pastix_complex32_t *A = (pastix_complex32_t*)Avoid;
946 pastix_complex32_t *Acpy;
948 pastix_complex32_t *work, *tau, *B, *tau_b, zzsize;
950 pastix_int_t zsize, bsize;
951 pastix_complex32_t *zwork;
952 pastix_fixdbl_t flops;
955 #if defined(PRECISION_c) || defined(PRECISION_z)
961 float norm = LAPACKE_clange_work( LAPACK_COL_MAJOR,
'f', m, n,
964 if ( (norm == 0.) && (tol >= 0.)) {
974 else if ( use_reltol ) {
978 ret = rrqrfct( tol, rklimit, nb,
984 lwork = (pastix_int_t)zzsize;
994 #if defined(PASTIX_DEBUG_LR)
996 Acpy = malloc( m * n *
sizeof(pastix_complex32_t) );
997 tau = malloc( n *
sizeof(pastix_complex32_t) );
998 B = malloc( bsize *
sizeof(pastix_complex32_t) );
999 tau_b = malloc( n *
sizeof(pastix_complex32_t) );
1000 work = malloc( lwork *
sizeof(pastix_complex32_t) );
1002 zwork = malloc( zsize *
sizeof(pastix_complex32_t) );
1010 jpvt = malloc( n *
sizeof(pastix_int_t) );
1015 ret = LAPACKE_clacpy_work( LAPACK_COL_MAJOR,
'A', m, n,
1019 newrk = rrqrfct( tol, rklimit, nb,
1023 work, lwork, norm );
1025 flops = FLOPS_CGEQRF( m, n );
1028 flops = FLOPS_CGEQRF( m, newrk ) + FLOPS_CUNMQR( m, n-newrk, newrk,
PastixLeft );
1037 if ( newrk == -1 ) {
1038 ret = LAPACKE_clacpy_work( LAPACK_COL_MAJOR,
'A', m, n,
1039 A, lda, Alr->
u, Alr->
rkmax );
1042 else if ( newrk > 0 ) {
1046 pastix_complex32_t *U, *V;
1047 pastix_int_t d, rk = 0;
1053 ret = LAPACKE_clacpy_work( LAPACK_COL_MAJOR,
'A', m, Alr->
rk,
1057 ret = LAPACKE_cungqr_work( LAPACK_COL_MAJOR, m, Alr->
rk, Alr->
rk,
1058 U, m, tau, work, lwork );
1060 flops += FLOPS_CUNGQR( m, Alr->
rk, Alr->
rk );
1063 ret = LAPACKE_clacpy_work( LAPACK_COL_MAJOR,
'U', Alr->
rk, n,
1064 Acpy, m, V, Alr->
rk );
1066 ret = LAPACKE_claset_work( LAPACK_COL_MAJOR,
'L', Alr->
rk-1, Alr->
rk-1,
1067 0.0, 0.0, V + 1, Alr->
rk );
1076 rk = (Alr->
rk / nb) * nb;
1078 d = pastix_imin( nb, Alr->
rk - rk );
1079 ret = LAPACKE_cunmqr_work( LAPACK_COL_MAJOR,
'R', trans,
1080 Alr->
rk - rk, n - rk, d,
1081 B + rk * n + rk, n, tau_b + rk,
1082 V + rk * Alr->
rk + rk, Alr->
rk,
1090 #if defined(PASTIX_DEBUG_LR)
1091 if ( Alr->
rk > 0 ) {
1094 fprintf(stderr,
"Failed to compress a matrix and generate an orthogonal u\n" );
1101 #if defined(PASTIX_DEBUG_LR)
1178 pastix_int_t offx, pastix_int_t offy )
1180 pastix_int_t rankA, rank, M, N, minV;
1181 pastix_int_t i, ret, new_rank, rklimit;
1182 pastix_int_t ldau, ldav, ldbu, ldbv, ldu, ldv;
1183 pastix_complex32_t *u1u2, *v1v2, *u;
1184 pastix_complex32_t *zbuf, *tauV;
1189 pastix_int_t nb = 32;
1192 pastix_complex32_t *zwork, zzsize;
1194 pastix_complex32_t alpha = *((pastix_complex32_t*)alphaptr);
1195 pastix_fixdbl_t flops, total_flops = 0.;
1197 #if defined(PASTIX_DEBUG_LR)
1201 fprintf(stderr,
"Failed to have B->u orthogonal in entry of rradd\n" );
1206 rankA = (A->rk == -1) ? pastix_imin(M1, N1) : A->rk;
1207 rank = rankA + B->
rk;
1208 M = pastix_imax(M2, M1);
1209 N = pastix_imax(N2, N1);
1211 minV = pastix_imin(N, rank);
1213 assert(M2 == M && N2 == N);
1214 assert(B->
rk != -1);
1216 assert( A->rk <= A->rkmax);
1219 if ( ((M1 + offx) > M2) ||
1220 ((N1 + offy) > N2) )
1222 pastix_print_error(
"Dimensions are not correct" );
1240 M1, N1, A, M2, N2, B,
1248 if ( rank > pastix_imin( M, N ) ) {
1255 ldau = (A->rk == -1) ? A->rkmax : M1;
1266 wzsize = (M+N) * rank;
1272 rrqrfct( tol, rklimit, 1, nb,
1275 &zzsize, -1, NULL );
1276 lwork = (pastix_int_t)(zzsize);
1279 #if defined(PASTIX_DEBUG_LR)
1281 u1u2 = malloc( ldu * rank *
sizeof(pastix_complex32_t) );
1282 v1v2 = malloc( ldv * N *
sizeof(pastix_complex32_t) );
1283 tauV = malloc( rank *
sizeof(pastix_complex32_t) );
1284 zwork = malloc( lwork *
sizeof(pastix_complex32_t) );
1286 rwork = malloc( 2 * pastix_imax( rank, N ) *
sizeof(
float) );
1288 zbuf = malloc( wzsize *
sizeof(pastix_complex32_t) + 2 * pastix_imax(rank, N) *
sizeof(
float) );
1291 v1v2 = u1u2 + ldu * rank;
1292 tauV = v1v2 + ldv * N;
1293 zwork = tauV + rank;
1295 rwork = (
float*)(zwork + lwork);
1328 kernel_trace_start_lvl2( PastixKernelLvl2_LR_add2C_rradd_orthogonalize );
1332 u1u2, ldu, v1v2, ldv );
1337 u1u2, ldu, v1v2, ldv );
1341 pastix_attr_fallthrough;
1345 u1u2, ldu, v1v2, ldv );
1347 kernel_trace_stop_lvl2( flops );
1349 total_flops += flops;
1352 rank = B->
rk + rankA;
1359 LAPACKE_clacpy_work( LAPACK_COL_MAJOR,
'A', M, B->
rk, u1u2, ldu, B->
u, ldbu );
1360 LAPACKE_clacpy_work( LAPACK_COL_MAJOR,
'A', B->
rk, N, v1v2, ldv, B->
v, ldbv );
1363 #if defined(PASTIX_DEBUG_LR)
1373 MALLOC_INTERN( jpvt, pastix_imax(rank, N), pastix_int_t );
1401 tol = tol * ( cabsf(alpha) * normA + normB );
1408 kernel_trace_start_lvl2( PastixKernelLvl2_LR_add2C_rradd_recompression );
1409 rklimit = pastix_imin( rklimit, rank );
1410 new_rank = rrqrfct( tol, rklimit, 1, nb,
1413 zwork, lwork, rwork );
1414 flops = (new_rank == -1) ? FLOPS_CGEQRF( rank, N )
1415 : (FLOPS_CGEQRF( rank, new_rank ) +
1416 FLOPS_CUNMQR( rank, N-new_rank, new_rank,
PastixLeft ));
1417 kernel_trace_stop_lvl2_rank( flops, new_rank );
1418 total_flops += flops;
1423 if ( (new_rank > rklimit) ||
1432 flops = FLOPS_CGEMM( M, N, Bbackup.
rk );
1433 kernel_trace_start_lvl2( PastixKernelLvl2_LR_add2C_uncompress );
1434 cblas_cgemm(CblasColMajor, CblasNoTrans, CblasNoTrans,
1436 CBLAS_SADDR(cone), Bbackup.
u, ldbu,
1438 CBLAS_SADDR(czero), u, M );
1439 kernel_trace_stop_lvl2( flops );
1440 total_flops += flops;
1443 if ( A->rk == -1 ) {
1444 flops = 2 * M1 * N1;
1445 kernel_trace_start_lvl2( PastixKernelLvl2_FR_GEMM );
1448 cone, u + offy * M + offx, M);
1449 kernel_trace_stop_lvl2( flops );
1452 flops = FLOPS_CGEMM( M1, N1, A->rk );
1453 kernel_trace_start_lvl2( PastixKernelLvl2_FR_GEMM );
1454 cblas_cgemm(CblasColMajor, CblasNoTrans, (CBLAS_TRANSPOSE)transA1,
1456 CBLAS_SADDR(alpha), A->u, ldau,
1458 CBLAS_SADDR(cone), u + offy * M + offx, M);
1459 kernel_trace_stop_lvl2( flops );
1461 total_flops += flops;
1465 #if defined(PASTIX_DEBUG_LR)
1474 else if ( new_rank == 0 ) {
1478 #if defined(PASTIX_DEBUG_LR)
1492 ret =
core_clrsze( 0, M, N, B, new_rank, -1, -1 );
1493 assert( ret != -1 );
1494 assert( B->
rkmax >= new_rank );
1501 pastix_complex32_t *tmpV;
1504 memset(B->
v, 0, N * ldbv *
sizeof(pastix_complex32_t));
1506 for (i=0; i<N; i++){
1507 lm = pastix_imin( new_rank, i+1 );
1508 memcpy(tmpV + jpvt[i] * ldbv,
1510 lm *
sizeof(pastix_complex32_t));
1516 flops = FLOPS_CUNGQR( rank, new_rank, new_rank )
1517 + FLOPS_CGEMM( M, new_rank, rank );
1519 kernel_trace_start_lvl2( PastixKernelLvl2_LR_add2C_rradd_computeNewU );
1520 ret = LAPACKE_cungqr_work( LAPACK_COL_MAJOR, rank, new_rank, new_rank,
1521 v1v2, ldv, tauV, zwork, lwork );
1524 cblas_cgemm(CblasColMajor, CblasNoTrans, CblasNoTrans,
1526 CBLAS_SADDR(cone), u1u2, ldu,
1528 CBLAS_SADDR(czero), B->
u, ldbu);
1529 kernel_trace_stop_lvl2( flops );
1530 total_flops += flops;
1536 #if defined(PASTIX_DEBUG_LR)
1572 if ( A->rk != -1 ) {
1573 return A->rk * ( M + N );
1608 int rkmax = A->rkmax;
1613 memcpy( buffer, &rk,
sizeof(
int ) );
1614 buffer +=
sizeof( int );
1618 memcpy( buffer, u, rk * M *
sizeof( pastix_complex32_t ) );
1619 buffer += rk * M *
sizeof( pastix_complex32_t );
1622 if ( rk == rkmax ) {
1623 memcpy( buffer, v, rk * N *
sizeof( pastix_complex32_t ) );
1624 buffer += rk * N *
sizeof( pastix_complex32_t );
1627 LAPACKE_clacpy_work(
1628 LAPACK_COL_MAJOR,
'A', rk, N, v, rkmax, (pastix_complex32_t *)buffer, rk );
1629 buffer += rk * N *
sizeof( pastix_complex32_t );
1633 memcpy( buffer, u, M * N *
sizeof( pastix_complex32_t ) );
1634 buffer += M * N *
sizeof( pastix_complex32_t );
1668 memcpy( &rk, buffer,
sizeof(
int ) );
1669 buffer +=
sizeof( int );
1676 memcpy( A->u, buffer, M * rk *
sizeof( pastix_complex32_t ) );
1677 buffer += M * rk *
sizeof( pastix_complex32_t );
1680 memcpy( A->v, buffer, N * rk *
sizeof( pastix_complex32_t ) );
1681 buffer += N * rk *
sizeof( pastix_complex32_t );
1685 memcpy( A->u, buffer, M * N *
sizeof( pastix_complex32_t ) );
1686 buffer += M * N *
sizeof( pastix_complex32_t );
1718 const char *input,
char **outptr )
1720 char *output = *outptr;
1724 rk = *((
int *)input);
1725 input +=
sizeof( int );
1732 size = M * rk *
sizeof( pastix_complex32_t );
1735 memcpy( A->u, input, size );
1740 size = N * rk *
sizeof( pastix_complex32_t );
1743 memcpy( A->v, input, size );
1753 size = M * N *
sizeof( pastix_complex32_t );
1756 memcpy( A->u, input, size );