PaStiX Handbook 6.4.0
Loading...
Searching...
No Matches
core_stqrcp.c
Go to the documentation of this file.
1/**
2 *
3 * @file core_stqrcp.c
4 *
5 * PaStiX implementation of the truncated rank-revealing QR with column pivoting
6 * based on Lapack GEQP3.
7 *
8 * @copyright 2016-2024 Bordeaux INP, CNRS (LaBRI UMR 5800), Inria,
9 * Univ. Bordeaux. All rights reserved.
10 *
11 * @version 6.4.0
12 * @author Alfredo Buttari
13 * @author Gregoire Pichon
14 * @author Esragul Korkmaz
15 * @author Mathieu Faverge
16 * @date 2024-07-05
17 * @generated from /builds/2mk6rsew/0/solverstack/pastix/kernels/core_ztqrcp.c, normal z -> s, Tue Feb 25 14:34:57 2025
18 *
19 **/
20#include "common.h"
21#include <cblas.h>
22#include <lapacke.h>
23#include "blend/solver.h"
24#include "pastix_scores.h"
25#include "pastix_slrcores.h"
26#include "s_nan_check.h"
27
#ifndef DOXYGEN_SHOULD_SKIP_THIS
/*
 * Scalar coefficients passed by value to the CBLAS calls of this file.
 * They are only ever read, so they are const to prevent accidental writes.
 */
static const float msone = -1.0f;
static const float sone  =  1.0f;
static const float szero =  0.0f;
#endif /* DOXYGEN_SHOULD_SKIP_THIS */
33
34/**
35 *******************************************************************************
36 *
37 * @brief Compute a randomized QR factorization with truncated updates.
38 *
39 * This routine is derivated from "Randomized QR with Column Pivoting",
40 * J. A. Duersch and M. Gu, SIAM Journal on Scientific Computing, vol. 39,
41 * no. 4, pp. C263-C291, 2017.
42 *
43 *******************************************************************************
44 *
45 * @param[in] tol
46 * The relative tolerance criterion. Computations are stopped when the
47 * frobenius norm of the residual matrix is lower than tol.
48 * If tol < 0, then maxrank reflectors are computed.
49 *
50 * @param[in] maxrank
51 * Maximum number of reflectors computed. Computations are stopped when
52 * the rank exceeds maxrank. If maxrank < 0, all reflectors are computed
53 * or up to the tolerance criterion.
54 *
55 * @param[in] refine
56 * TODO
57 *
58 * @param[in] nb
59 * Tuning parameter for the GEMM blocking size. if nb < 0, nb is set to
60 * 32.
61 *
62 * @param[in] m
63 * Number of rows of the matrix A.
64 *
65 * @param[in] n
66 * Number of columns of the matrix A.
67 *
68 * @param[in] A
69 * The matrix of dimension lda-by-n that needs to be compressed.
70 *
71 * @param[in] lda
72 * The leading dimension of the matrix A. lda >= max(1, m).
73 *
74 * @param[out] jpvt
75 * The array that describes the permutation of A.
76 *
77 * @param[out] tau
78 * Contains scalar factors of the elementary reflectors for the matrix
79 * Q.
80 *
81 * @param[in] work
82 * Workspace array of size lwork.
83 *
84 * @param[in] lwork
85 * The dimension of the work area. lwork >= (nb * n + max(n, m) )
86 * If lwork == -1, the functions returns immediately and work[0]
87 * contains the optimal size of work.
88 *
89 * @param[in] rwork
90 * Workspace array used to store partial and exact column norms (2-by-n)
91 *
92 *******************************************************************************
93 *
94 * @return This routine will return the rank of A (>=0) or -1 if it didn't
95 * manage to compress within the margins of tolerance and maximum rank.
96 *
97 *******************************************************************************/
98int
99core_stqrcp( float tol,
100 pastix_int_t maxrank,
101 int refine,
102 pastix_int_t nb,
103 pastix_int_t m,
104 pastix_int_t n,
105 float *A,
106 pastix_int_t lda,
107 pastix_int_t *jpvt,
108 float *tau,
109 float *work,
110 pastix_int_t lwork,
111 float *rwork )
112{
113 int SEED[4] = {26, 67, 52, 197};
114 pastix_int_t j, k, in, itmp, d, ib, loop = 1;
115 int ret;
116 pastix_int_t minMN, lwkopt;
117 pastix_int_t p = 5;
118 pastix_int_t bp = ( nb < p ) ? 32 : nb;
119 pastix_int_t b = bp - p;
120 pastix_int_t size_B, size_O, size_W, size_Y, size_A, size_T, sublw;
121 pastix_int_t ldb, ldw, ldy;
122 pastix_int_t *jpvt_b;
123 pastix_int_t rk;
124 float tolB = sqrtf( (float)(bp) ) * tol;
125 float *AP, *Y, *WT, *T, *B, *tau_b, *omega, *subw;
126
127 minMN = pastix_imin(m, n);
128 if ( maxrank < 0 ) {
129 maxrank = minMN;
130 }
131 maxrank = pastix_imin( maxrank, minMN );
132
133 ldb = bp;
134 ldw = maxrank;
135
136 size_B = ldb * n;
137 size_O = ldb * m;
138 size_W = n * maxrank;
139 size_Y = b * b;
140 ldy = b;
141 size_A = m * n;
142 size_T = b * b;
143
144 sublw = n * bp + pastix_imax( bp, n ); /* pqrcp */
145 sublw = pastix_imax( sublw, size_O ); /* Omega */
146 sublw = pastix_imax( sublw, b * maxrank ); /* update */
147
148 lwkopt = size_A + size_Y + size_W
149 + size_T + size_B + n + sublw;
150
151 if ( lwork == -1 ) {
152 work[0] = (float)lwkopt;
153 return 0;
154 }
155#if !defined(NDEBUG)
156 if (m < 0) {
157 return -1;
158 }
159 if (n < 0) {
160 return -2;
161 }
162 if (lda < pastix_imax(1, m)) {
163 return -4;
164 }
165 if( lwork < lwkopt ) {
166 return -8;
167 }
168#endif
169
170 /**
171 * If maximum rank is 0, then either the matrix norm is below the tolerance,
172 * and we can return a null rank matrix, or it is not and we need to return
173 * a full rank matrix.
174 */
175 if ( maxrank == 0 ) {
176 float norm;
177 if ( tol < 0. ) {
178 return 0;
179 }
180 norm = LAPACKE_slange_work( LAPACK_COL_MAJOR, 'f', m, n,
181 A, lda, NULL );
182 if ( norm < tol ) {
183 return 0;
184 }
185 return -1;
186 }
187
188 jpvt_b = malloc( n * sizeof(pastix_int_t) );
189
190 AP = work;
191 Y = AP + size_A;
192 WT = Y + size_Y;
193 T = WT + size_W;
194 B = T + size_T;
195 tau_b = B + size_B;
196 omega = tau_b + n;
197 subw = tau_b + n;
198
199 /* Initialize diagonal block of Housholders reflectors */
200 ret = LAPACKE_slaset_work( LAPACK_COL_MAJOR, 'A', b, b,
201 0., 1., Y, ldy );
202 assert( ret == 0 );
203
204 /* Initialize T */
205 memset(T, 0, size_T * sizeof(float));
206
207 /* Backup A */
208 ret = LAPACKE_slacpy_work( LAPACK_COL_MAJOR, 'A', m, n,
209 A, lda, AP, m );
210 assert( ret == 0 );
211
212 /* Initialize pivots */
213 for (j=0; j<n; j++) jpvt[j] = j;
214
215 /*
216 * Computation of the Gaussian matrix
217 */
218 ret = LAPACKE_slarnv_work(3, SEED, size_O, omega);
219 assert( ret == 0 );
220 cblas_sgemm( CblasColMajor, CblasNoTrans, CblasNoTrans,
221 bp, n, m,
222 (sone), omega, bp,
223 A, lda,
224 (szero), B, ldb );
225
226 rk = 0;
227 d = 0;
228 while ( loop )
229 {
230 ib = pastix_imin( b, minMN-rk );
231 d = core_spqrcp( tolB, ib, 1, bp,
232 bp, n-rk,
233 B + rk * ldb, ldb,
234 jpvt_b + rk, tau_b,
235 subw, sublw, rwork );
236
237 /* If fails to reach the tolerance before maxrank, let's restore the max value */
238 if ( d == -1 ) {
239 d = ib;
240 }
241 /* If smaller than ib, we reached the threshold */
242 if ( d < ib ) {
243 loop = 0;
244 }
245 if ( d == 0 ) {
246 break;
247 }
248 /* If we exceeded the max rank, let's stop now */
249 if ( (rk + d) > maxrank ) {
250 rk = -1;
251 break;
252 }
253
254 /* Updating jpvt, A, and AP */
255 for (j = rk; j < rk + d; j++) {
256 if (jpvt_b[j] >= 0) {
257 k = j;
258 in = jpvt_b[k] + rk;
259
260 /* Mark as done */
261 jpvt_b[k] = - jpvt_b[k] - 1;
262
263 while( jpvt_b[in] >= 0 ) {
264
265 if (k != in) {
266 cblas_sswap( m, A + k * lda, 1,
267 A + in * lda, 1 );
268 cblas_sswap( m, AP + k * m, 1,
269 AP + in * m, 1 );
270
271 itmp = jpvt[k];
272 jpvt[k] = jpvt[in];
273 jpvt[in] = itmp;
274
275 if (rk > 0) {
276 cblas_sswap( rk, WT + k * ldw, 1,
277 WT + in * ldw, 1 );
278 }
279 }
280 itmp = jpvt_b[in];
281 jpvt_b[in] = - jpvt_b[in] - 1;
282 k = in;
283 in = itmp + rk;
284 }
285 }
286 }
287
288 if (rk > 0) {
289 /* Update the selected columns before factorization */
290 cblas_sgemm( CblasColMajor, CblasNoTrans, CblasNoTrans,
291 m-rk, d, rk,
292 (msone), A + rk, lda,
293 WT + rk * ldw, ldw,
294 (sone), A + rk * lda + rk, lda );
295 }
296
297 /*
298 * Factorize the d selected columns of A without pivoting
299 */
300 ret = LAPACKE_sgeqrf_work( LAPACK_COL_MAJOR, m-rk, d,
301 A + rk * lda + rk, lda, tau + rk,
302 work, lwork );
303 assert( ret == 0 );
304
305 ret = LAPACKE_slarft_work( LAPACK_COL_MAJOR, 'F', 'C', m-rk, d,
306 A + rk * lda + rk, lda, tau + rk, T, b );
307 assert( ret == 0 );
308
309 /*
310 * Compute the update line 11 of algorithm 6 in "Randomized QR with
311 * Column pivoting" from Duersch and Gu
312 *
313 * W_2^h = T^h ( Y_2^h * A - (Y_2^h * Y) * W_1^h )
314 *
315 * Step 1: Y_2^h * A
316 * a) W[rk:rk+d] <- A
317 * b) W[rk:rk+d] <- Y_2^h * A, split in triangular part + rectangular part
318 */
319 ret = LAPACKE_slacpy_work( LAPACK_COL_MAJOR, 'L', d-1, d-1,
320 A + lda * rk + rk + 1, lda,
321 Y + 1, ldy );
322 assert( ret == 0 );
323
324 /* Triangular part */
325 cblas_sgemm( CblasColMajor, CblasTrans, CblasNoTrans,
326 d, n, d,
327 (sone), Y, ldy,
328 AP + rk, m,
329 (szero), WT + rk, ldw );
330
331 /* Rectangular part */
332 if ( rk + d < m ) {
333 cblas_sgemm( CblasColMajor, CblasTrans, CblasNoTrans,
334 d, n, m-rk-d,
335 (sone), A + rk * lda + rk + d, lda,
336 AP + rk + d, m,
337 (sone), WT + rk, ldw );
338 }
339
340 /*
341 * Step 2: (Y_2^h * A) - (Y_2^h * Y) * W_1^h
342 * a) work = (Y_2^h * Y)
343 * b) (Y_2^h * A) - work * W_1^h
344 */
345 if ( rk > 0 ) {
346 /* Triangular part */
347 cblas_sgemm( CblasColMajor, CblasTrans, CblasNoTrans,
348 d, rk, d,
349 (sone), Y, ldy,
350 A + rk, lda,
351 (szero), subw, d );
352
353 /* Rectangular part */
354 if ( rk + d < m ) {
355 cblas_sgemm( CblasColMajor, CblasTrans, CblasNoTrans,
356 d, rk, m-rk-d,
357 (sone), A + rk * lda + rk + d, lda,
358 A + rk + d, lda,
359 (sone), subw, d );
360 }
361
362 cblas_sgemm( CblasColMajor, CblasNoTrans, CblasNoTrans,
363 d, n, rk,
364 (msone), subw, d,
365 WT, ldw,
366 (sone), WT + rk, ldw );
367 }
368
369 /*
370 * Step 3: W_2^h = T^h ( Y_2^h * A - (Y_2^h * Y) * W_1^h )
371 * W_2^h = T^h W_2^h
372 */
373 cblas_strmm( CblasColMajor, CblasLeft, CblasUpper, CblasTrans, CblasNonUnit,
374 d, n, (sone),
375 T, b,
376 WT + rk, ldw );
377
378 /* Update current d rows of R */
379 if ( rk+d < n ) {
380 cblas_sgemm( CblasColMajor, CblasNoTrans, CblasNoTrans,
381 d, n-rk-d, rk,
382 (msone), A + rk, lda,
383 WT + (rk+d)*ldw, ldw,
384 (sone), A + rk + (rk+d)*lda, lda );
385
386 cblas_sgemm( CblasColMajor, CblasNoTrans, CblasNoTrans,
387 d, n-rk-d, d,
388 (msone), Y, ldy,
389 WT + rk + (rk+d)*ldw, ldw,
390 (sone), A + rk + (rk+d)*lda, lda );
391 }
392
393 if ( loop && (rk+d < maxrank) ) {
394 /*
395 * The Q from partial QRCP is stored in the lower part of the matrix,
396 * we need to remove it
397 */
398 ret = LAPACKE_slaset_work( LAPACK_COL_MAJOR, 'L', d-1, d-1,
399 0, 0, B + rk*ldb + 1, ldb );
400 assert( ret == 0 );
401
402 /* Updating B */
403 /* Solving S_11 * R_11^{-1} */
404 cblas_strsm( CblasColMajor, CblasRight, CblasUpper,
405 CblasNoTrans, CblasNonUnit,
406 d, d,
407 (sone), A + rk*lda + rk, lda,
408 B + rk*ldb, ldb );
409
410 /* Updating S_12 = S_12 - (S_11 * R_11^{-1}) * R_12 */
411 cblas_sgemm( CblasColMajor, CblasNoTrans, CblasNoTrans,
412 d, n - (rk+d), d,
413 (msone), B + rk *ldb, ldb,
414 A + (rk+d)*lda + rk, lda,
415 (sone), B + (rk+d)*ldb, ldb );
416 }
417 rk += d;
418 }
419 free( jpvt_b );
420
421 (void)ret;
422 (void)refine;
423 return rk;
424}
425
426/**
427 *******************************************************************************
428 *
429 * @brief Convert a full rank matrix in a low rank matrix, using TQRCP.
430 *
431 *******************************************************************************
432 *
433 * @param[in] use_reltol
434 * Defines if the kernel should use relative tolerance (tol *||A||), or
435 * absolute tolerance (tol).
436 *
437 * @param[in] tol
438 * The tolerance used as a criterion to eliminate information from the
439 * full rank matrix
440 *
441 * @param[in] rklimit
442 * The maximum rank to store the matrix in low-rank format. If
443 * -1, set to min(m, n) / PASTIX_LR_MINRATIO.
444 *
445 * @param[in] m
446 * Number of rows of the matrix A, and of the low rank matrix Alr.
447 *
448 * @param[in] n
449 * Number of columns of the matrix A, and of the low rank matrix Alr.
450 *
451 * @param[in] A
452 * The matrix of dimension lda-by-n that needs to be compressed
453 *
454 * @param[in] lda
455 * The leading dimension of the matrix A. lda >= max(1, m)
456 *
457 * @param[out] Alr
458 * The low rank matrix structure that will store the low rank
459 * representation of A
460 *
461 *******************************************************************************
462 *
463 * @return TODO
464 *
465 *******************************************************************************/
467core_sge2lr_tqrcp( int use_reltol,
468 pastix_fixdbl_t tol,
469 pastix_int_t rklimit,
470 pastix_int_t m,
471 pastix_int_t n,
472 const void *A,
473 pastix_int_t lda,
474 pastix_lrblock_t *Alr )
475{
476 return core_sge2lr_qrcp( core_stqrcp, use_reltol, tol, rklimit,
477 m, n, A, lda, Alr );
478}
479
480
481/**
482 *******************************************************************************
483 *
484 * @brief Add two LR structures A=(-u1) v1^T and B=u2 v2^T into u2 v2^T
485 *
486 * u2v2^T - u1v1^T = (u2 u1) (v2 v1)^T
487 * Orthogonalize (u2 u1) = (u2, u1 - u2(u2^T u1)) * (I u2^T u1)
488 * (0 I )
489 * Compute TQRCP decomposition of (I u2^T u1) * (v2 v1)^T
490 * (0 I )
491 *
492 *******************************************************************************
493 *
494 * @param[in] lowrank
495 * The structure with low-rank parameters.
496 *
497 * @param[in] transA1
498 * @arg PastixNoTrans: No transpose, op( A ) = A;
499 * @arg PastixTrans: Transpose, op( A ) = A';
500 *
501 * @param[in] alphaptr
502 * alpha * A is add to B
503 *
504 * @param[in] M1
505 * The number of rows of the matrix A.
506 *
507 * @param[in] N1
508 * The number of columns of the matrix A.
509 *
510 * @param[in] A
511 * The low-rank representation of the matrix A.
512 *
513 * @param[in] M2
514 * The number of rows of the matrix B.
515 *
516 * @param[in] N2
517 * The number of columns of the matrix B.
518 *
519 * @param[in] B
520 * The low-rank representation of the matrix B.
521 *
522 * @param[in] offx
523 * The horizontal offset of A with respect to B.
524 *
525 * @param[in] offy
526 * The vertical offset of A with respect to B.
527 *
528 *******************************************************************************
529 *
530 * @return The new rank of u2 v2^T or -1 if ranks are too large for
531 * recompression
532 *
533 *******************************************************************************/
536 pastix_trans_t transA1,
537 const void *alphaptr,
538 pastix_int_t M1,
539 pastix_int_t N1,
540 const pastix_lrblock_t *A,
541 pastix_int_t M2,
542 pastix_int_t N2,
544 pastix_int_t offx,
545 pastix_int_t offy)
546{
547 return core_srradd_qr( core_stqrcp, lowrank, transA1, alphaptr,
548 M1, N1, A, M2, N2, B, offx, offy );
549}
BEGIN_C_DECLS typedef int pastix_int_t
Definition datatypes.h:51
double pastix_fixdbl_t
Definition datatypes.h:65
int core_stqrcp(float tol, pastix_int_t maxrank, int refine, pastix_int_t nb, pastix_int_t m, pastix_int_t n, float *A, pastix_int_t lda, pastix_int_t *jpvt, float *tau, float *work, pastix_int_t lwork, float *rwork)
Compute a randomized QR factorization with truncated updates.
Definition core_stqrcp.c:99
int core_spqrcp(float tol, pastix_int_t maxrank, int full_update, pastix_int_t nb, pastix_int_t m, pastix_int_t n, float *A, pastix_int_t lda, pastix_int_t *jpvt, float *tau, float *work, pastix_int_t lwork, float *rwork)
Compute a rank-revealing QR factorization.
Structure to define the type of function to use for the low-rank kernels and their parameters.
The block low-rank structure to hold a matrix in low-rank form.
pastix_fixdbl_t core_sge2lr_tqrcp(int use_reltol, pastix_fixdbl_t tol, pastix_int_t rklimit, pastix_int_t m, pastix_int_t n, const void *A, pastix_int_t lda, pastix_lrblock_t *Alr)
Convert a full rank matrix in a low rank matrix, using TQRCP.
pastix_fixdbl_t core_sge2lr_qrcp(core_srrqr_cp_t rrqrfct, int use_reltol, pastix_fixdbl_t tol, pastix_int_t rklimit, pastix_int_t m, pastix_int_t n, const void *Avoid, pastix_int_t lda, pastix_lrblock_t *Alr)
Template to convert a full rank matrix into a low rank matrix through QR decompositions.
pastix_fixdbl_t core_srradd_tqrcp(const pastix_lr_t *lowrank, pastix_trans_t transA1, const void *alphaptr, pastix_int_t M1, pastix_int_t N1, const pastix_lrblock_t *A, pastix_int_t M2, pastix_int_t N2, pastix_lrblock_t *B, pastix_int_t offx, pastix_int_t offy)
Add two LR structures A=(-u1) v1^T and B=u2 v2^T into u2 v2^T.
pastix_fixdbl_t core_srradd_qr(core_srrqr_cp_t rrqrfct, const pastix_lr_t *lowrank, pastix_trans_t transA1, const void *alphaptr, pastix_int_t M1, pastix_int_t N1, const pastix_lrblock_t *A, pastix_int_t M2, pastix_int_t N2, pastix_lrblock_t *B, pastix_int_t offx, pastix_int_t offy)
Template to perform the addition of two low-rank structures with a compression kernel based on QR decomposition.
enum pastix_trans_e pastix_trans_t
Transposition.
Manage nancheck for lowrank kernels. This header describes all the LAPACKE functions used for low-ran...