PaStiX Handbook 6.4.0
Loading...
Searching...
No Matches
core_slrothu.c
Go to the documentation of this file.
1/**
2 *
3 * @file core_slrothu.c
4 *
5 * PaStiX low-rank kernel routines to orthogonalize the U matrix with QR approximations.
6 *
7 * @copyright 2016-2024 Bordeaux INP, CNRS (LaBRI UMR 5800), Inria,
8 * Univ. Bordeaux. All rights reserved.
9 *
10 * @version 6.4.0
11 * @author Alfredo Buttari
12 * @author Gregoire Pichon
13 * @author Esragul Korkmaz
14 * @author Mathieu Faverge
15 * @date 2024-07-05
16 * @generated from /builds/2mk6rsew/0/solverstack/pastix/kernels/core_zlrothu.c, normal z -> s, Tue Feb 25 14:34:56 2025
17 *
18 **/
19#include "common.h"
20#include <cblas.h>
21#include <lapacke.h>
22#include "flops.h"
23#include "kernels_trace.h"
24#include "blend/solver.h"
25#include "pastix_scores.h"
26#include "pastix_slrcores.h"
27#include "s_nan_check.h"
28#include "pastix_lowrank.h"
29
30#ifndef DOXYGEN_SHOULD_SKIP_THIS
31static float msone = -1.0;
32static float sone = 1.0;
33static float szero = 0.0;
34#endif /* DOXYGEN_SHOULD_SKIP_THIS */
35
36/**
37 *******************************************************************************
38 *
39 * @brief Try to orthogonalize the u part of the low-rank form, and update the v
40 * part accordingly using full QR.
41 *
42 * This function considers a low-rank matrix resulting from the addition of two
43 * matrices B += A, with A of smaller or equal size to B.
44 * The product has the form: U * V^t
45 *
46 * The U part of the low-rank form must be orthogonalized to get the smallest
47 * possible rank during the rradd operation. This function performs this by
48 * applying a full QR factorization on the U part.
49 *
50 * U = Q R, then U' = Q, and V' = R * V
51 *
52 *******************************************************************************
53 *
54 * @param[in] M
55 * The number of rows of the u1u2 matrix.
56 *
57 * @param[in] N
58 * The number of columns of the v1v2 matrix.
59 *
60 * @param[in] rank
61 * The number of columns of the U matrix, and the number of rows of the
62 * V part in the v1v2 matrix.
63 *
64 * @param[inout] U
65 * The U matrix of size ldu -by- rank. On exit, Q from U = Q R.
66 *
67 * @param[in] ldu
68 * The leading dimension of the U matrix. ldu >= max(1, M)
69 *
70 * @param[inout] V
71 * The V matrix of size ldv -by- N.
72 * On exit, R * V, with R from U = Q R.
73 *
74 * @param[in] ldv
75 * The leading dimension of the V matrix. ldv >= max(1, rank)
76 *
77 *******************************************************************************
78 *
79 * @return The number of flops required to perform the operation.
80 *
81 *******************************************************************************/
85 pastix_int_t rank,
86 float *U,
87 pastix_int_t ldu,
88 float *V,
89 pastix_int_t ldv )
90{
91 pastix_int_t minMK = pastix_imin( M, rank );
92 pastix_int_t lwork = M * 32 + minMK;
93 pastix_int_t ret;
94 float *W = malloc( lwork * sizeof(float) );
95 float *tau, *work;
96 pastix_fixdbl_t flops = 0.;
97
98 tau = W;
99 work = W + minMK;
100 lwork -= minMK;
101
102 assert( M >= rank );
103
104 /* Compute U = Q * R */
105 ret = LAPACKE_sgeqrf_work( LAPACK_COL_MAJOR, M, rank,
106 U, ldu, tau, work, lwork );
107 assert( ret == 0 );
108 flops += FLOPS_SGEQRF( M, rank );
109
110 /* Compute V' = R * V' */
111 cblas_strmm( CblasColMajor,
112 CblasLeft, CblasUpper,
113 CblasNoTrans, CblasNonUnit,
114 rank, N, (sone),
115 U, ldu, V, ldv );
116 flops += FLOPS_STRMM( PastixLeft, rank, N );
117
118 /* Generate the Q */
119 ret = LAPACKE_sorgqr_work( LAPACK_COL_MAJOR, M, rank, rank,
120 U, ldu, tau, work, lwork );
121 assert( ret == 0 );
122 flops += FLOPS_SORGQR( M, rank, rank );
123
124 free(W);
125
126 (void)ret;
127 return flops;
128}
129
130/**
131 *******************************************************************************
132 *
133 * @brief Try to orthogonalize the U part of the low-rank form, and update the V
134 * part accordingly using partial QR.
135 *
136 * This function considers a low-rank matrix resulting from the addition of two
137 * matrices B += A, with A of smaller or equal size to B.
138 * The product has the form: U * V^t
139 *
140 * The U part of the low-rank form must be orthogonalized to get the smallest
141 * possible rank during the rradd operation. This function performs this by
142 * applying a full QR factorization on the U part.
143 *
144 * In that case, it takes benefit from the fact that U = [ u1, u2 ], and V = [
145 * v1, v2 ] with u2 and v2 which are matrices of respective sizes M2-by-r2, and
146 * r2-by-N2, offset by offx and offy
147 *
148 * The steps are:
149 * - Scaling of u2 with removal of the null columns
150 * - Orthogonalization of u2 relatively to u1
151 * - Application of the update to v2
152 * - orthogonalization through QR of u2
153 * - Update of V
154 *
155 *******************************************************************************
156 *
157 * @param[in] M
158 * The number of rows of the u1u2 matrix.
159 *
160 * @param[in] N
161 * The number of columns of the v1v2 matrix.
162 *
163 * @param[in] r1
164 * The number of columns of the U matrix in the u1 part, and the number
165 * of rows of the V part in the v1 part.
166 *
167 * @param[inout] r2ptr
168 * The number of columns of the U matrix in the u2 part, and the number
169 * of rows of the V part in the v2 part. On exit, this rank is reduced
170 * by the number of null columns found in U.
171 *
172 * @param[in] offx
173 * The row offset of the matrix u2 in U.
174 *
175 * @param[in] offy
176 * The column offset of the matrix v2 in V.
177 *
178 * @param[inout] U
179 * The U matrix of size ldu -by- rank. On exit, the orthogonalized U.
180 *
181 * @param[in] ldu
182 * The leading dimension of the U matrix. ldu >= max(1, M)
183 *
184 * @param[inout] V
185 * The V matrix of size ldv -by- N.
186 * On exit, the updated V matrix.
187 *
188 * @param[in] ldv
189 * The leading dimension of the V matrix. ldv >= max(1, rank)
190 *
191 *******************************************************************************
192 *
193 * @return The number of flops required to perform the operation.
194 *
195 *******************************************************************************/
198 pastix_int_t N,
199 pastix_int_t r1,
200 pastix_int_t *r2ptr,
201 pastix_int_t offx,
202 pastix_int_t offy,
203 float *U,
204 pastix_int_t ldu,
205 float *V,
206 pastix_int_t ldv )
207{
208 pastix_int_t r2 = *r2ptr;
209 pastix_int_t minMN = pastix_imin( M, r2 );
210 pastix_int_t ldwork = pastix_imax( r1 * r2, M * 32 + minMN );
211 pastix_int_t ret, i;
212 float *u1 = U;
213 float *u2 = U + r1 * ldu;
214 float *v1 = V;
215 float *v2 = V + r1;
216 float *W = malloc( ldwork * sizeof(float) );
217 float *tau, *work;
218 pastix_fixdbl_t flops = 0.;
219 float norm, eps;
220
221 tau = W;
222 work = W + minMN;
223 ldwork -= minMN;
224
225 eps = LAPACKE_slamch_work('e');
226
227 /* Scaling */
228 for (i=0; i<r2; i++, u2 += ldu, v2++) {
229 norm = cblas_snrm2( M, u2, 1 );
230 if ( norm > (M * eps) ) {
231 cblas_sscal( M, 1. / norm, u2, 1 );
232 cblas_sscal( N, norm, v2, ldv );
233 }
234 else {
235 if ( i < (r2-1) ) {
236 cblas_sswap( M, u2, 1, U + (r1+r2-1) * ldu, 1 );
237 memset( U + (r1+r2-1) * ldu, 0, M * sizeof(float) );
238
239 cblas_sswap( N, v2, ldv, V + (r1+r2-1), ldv );
240 ret = LAPACKE_slaset_work( LAPACK_COL_MAJOR, 'A', 1, N,
241 0., 0., V + (r1+r2-1), ldv );
242 assert( ret == 0 );
243 r2--;
244 i--;
245 u2-= ldu;
246 v2--;
247 }
248 else {
249 memset( u2, 0, M * sizeof(float) );
250 ret = LAPACKE_slaset_work( LAPACK_COL_MAJOR, 'A', 1, N,
251 0., 0., v2, ldv );
252 assert( ret == 0 );
253 r2--;
254 }
255 }
256 }
257 u2 = U + r1 * ldu;
258 v2 = V + r1;
259
260 *r2ptr = r2;
261
262 if ( r2 == 0 ) {
263 free( W );
264 return 0.;
265 }
266
267 /* Compute W = u1^t u2 */
268 cblas_sgemm( CblasColMajor, CblasTrans, CblasNoTrans,
269 r1, r2, M,
270 (sone), u1, ldu,
271 u2, ldu,
272 (szero), W, r1 );
273 flops += FLOPS_SGEMM( r1, r2, M );
274
275 /* Compute u2 = u2 - u1 ( u1^t u2 ) = u2 - u1 * W */
276 cblas_sgemm( CblasColMajor, CblasNoTrans, CblasNoTrans,
277 M, r2, r1,
278 (msone), u1, ldu,
279 W, r1,
280 (sone), u2, ldu );
281 flops += FLOPS_SGEMM( M, r2, r1 );
282
283 /* Update v1 = v1 + ( u1^t u2 ) v2 = v1 + W * v2 */
284 cblas_sgemm( CblasColMajor, CblasNoTrans, CblasNoTrans,
285 r1, N, r2,
286 (sone), W, r1,
287 v2, ldv,
288 (sone), v1, ldv );
289 flops += FLOPS_SGEMM( r1, N, r2 );
290
291#if !defined(PASTIX_LR_CGS1)
292 /* Compute W = u1^t u2 */
293 cblas_sgemm( CblasColMajor, CblasTrans, CblasNoTrans,
294 r1, r2, M,
295 (sone), u1, ldu,
296 u2, ldu,
297 (szero), W, r1 );
298 flops += FLOPS_SGEMM( r1, r2, M );
299
300 /* Compute u2 = u2 - u1 ( u1^t u2 ) = u2 - u1 * W */
301 cblas_sgemm( CblasColMajor, CblasNoTrans, CblasNoTrans,
302 M, r2, r1,
303 (msone), u1, ldu,
304 W, r1,
305 (sone), u2, ldu );
306 flops += FLOPS_SGEMM( M, r2, r1 );
307
308 /* Update v1 = v1 + ( u1^t u2 ) v2 = v1 + W * v2 */
309 cblas_sgemm( CblasColMajor, CblasNoTrans, CblasNoTrans,
310 r1, N, r2,
311 (sone), W, r1,
312 v2, ldv,
313 (sone), v1, ldv );
314 flops += FLOPS_SGEMM( r1, N, r2 );
315#endif
316
317#if defined(PASTIX_DEBUG_LR)
318 if ( core_slrdbg_check_orthogonality_AB( M, r1, r2, u1, ldu, u2, ldu ) != 0 ) {
319 fprintf(stderr, "partialQR: u2 not correctly projected with u1\n" );
320 }
321#endif
322
323 /* Compute u2 = Q * R */
324 ret = LAPACKE_sgeqrf_work( LAPACK_COL_MAJOR, M, r2,
325 u2, ldu, tau, work, ldwork );
326 assert( ret == 0 );
327 flops += FLOPS_SGEQRF( M, r2 );
328
329 /* Compute v2' = R * v2 */
330 cblas_strmm( CblasColMajor,
331 CblasLeft, CblasUpper,
332 CblasNoTrans, CblasNonUnit,
333 r2, N, (sone),
334 u2, ldu, v2, ldv);
335 flops += FLOPS_STRMM( PastixLeft, r2, N );
336
337 /* Generate the Q */
338 ret = LAPACKE_sorgqr_work( LAPACK_COL_MAJOR, M, r2, r2,
339 u2, ldu, tau, work, ldwork );
340 assert( ret == 0 );
341 flops += FLOPS_SORGQR( M, r2, r2 );
342
343#if defined(PASTIX_DEBUG_LR)
344 if ( core_slrdbg_check_orthogonality_AB( M, r1, r2, u1, ldu, u2, ldu ) != 0 ) {
345 fprintf(stderr, "partialQR: Final u2 not orthogonal to u1\n" );
346 }
347#endif
348
349 free( W );
350
351 (void)ret;
352 (void)offx;
353 (void)offy;
354
355 return flops;
356}
357
358/**
359 *******************************************************************************
360 *
361 * @brief Try to orthogonalize the U part of the low-rank form, and update the V
362 * part accordingly using CGS.
363 *
364 * This function considers a low-rank matrix resulting from the addition of two
365 * matrices B += A, with A of smaller or equal size to B.
366 * The product has the form: U * V^t
367 *
368 * The U part of the low-rank form must be orthogonalized to get the smallest
369 * possible rank during the rradd operation. This function performs this by
370 * applying a full QR factorization on the U part.
371 *
372 * In that case, it takes benefit from the fact that U = [ u1, u2 ], and V = [
373 * v1, v2 ] with u2 and v2 which are matrices of respective sizes M2-by-r2, and
374 * r2-by-N2, offset by offx and offy
375 *
376 * The steps are:
377 * - for each column of u2
378 * - Scaling of u2 with removal of the null columns
379 * - Orthogonalization of u2 relatively to u1
380 * - Remove the column if null
381 *
382 *******************************************************************************
383 *
384 * @param[in] M1
385 * The number of rows of the U matrix.
386 *
387 * @param[in] N1
388 * The number of columns of the V matrix.
389 *
390 * @param[in] M2
391 * The number of rows of the u2 part of the U matrix.
392 *
393 * @param[in] N2
394 * The number of columns of the v2 part of the V matrix.
395 *
396 * @param[in] r1
397 * The number of columns of the U matrix in the u1 part, and the number
398 * of rows of the V part in the v1 part.
399 *
400 * @param[inout] r2ptr
401 * The number of columns of the U matrix in the u2 part, and the number
402 * of rows of the V part in the v2 part. On exit, this rank is reduced
403 * by the number of null columns found in U.
404 *
405 * @param[in] offx
406 * The row offset of the matrix u2 in U.
407 *
408 * @param[in] offy
409 * The column offset of the matrix v2 in V.
410 *
411 * @param[inout] U
412 * The U matrix of size ldu -by- rank. On exit, the orthogonalized U.
413 *
414 * @param[in] ldu
415 * The leading dimension of the U matrix. ldu >= max(1, M)
416 *
417 * @param[inout] V
418 * The V matrix of size ldv -by- N.
419 * On exit, the updated V matrix.
420 *
421 * @param[in] ldv
422 * The leading dimension of the V matrix. ldv >= max(1, rank)
423 *
424 *******************************************************************************
425 *
426 * @return The number of flops required to perform the operation.
427 *
428 *******************************************************************************/
431 pastix_int_t N1,
432 pastix_int_t M2,
433 pastix_int_t N2,
434 pastix_int_t r1,
435 pastix_int_t *r2ptr,
436 pastix_int_t offx,
437 pastix_int_t offy,
438 float *U,
439 pastix_int_t ldu,
440 float *V,
441 pastix_int_t ldv )
442{
443 pastix_int_t r2 = *r2ptr;
444 float *u1 = U;
445 float *u2 = U + r1 * ldu;
446 float *v1 = V;
447 float *v2 = V + r1;
448 float *W;
449 pastix_fixdbl_t flops = 0.0;
450 pastix_int_t i, rank = r1 + r2;
451 pastix_int_t ldwork = rank;
452 pastix_int_t ret;
453 float eps, norm;
454 float norm_before, alpha;
455
456 assert( M1 >= (M2 + offx) );
457 assert( N1 >= (N2 + offy) );
458
459 W = malloc(ldwork * sizeof(float));
460 eps = LAPACKE_slamch_work( 'e' );
461 alpha = 1. / sqrtf(2);
462
463 /* Classical Gram-Schmidt */
464 for (i=r1; i<rank; i++, u2 += ldu, v2++) {
465
466 norm = cblas_snrm2( M2, u2 + offx, 1 );
467 if ( norm > ( M2 * eps ) ) {
468 cblas_sscal( M2, 1. / norm, u2 + offx, 1 );
469 cblas_sscal( N2, norm, v2 + offy * ldv, ldv );
470 }
471 else {
472 rank--; r2--;
473 if ( i < rank ) {
474 cblas_sswap( M2, u2 + offx, 1, U + rank * ldu + offx, 1 );
475#if !defined(NDEBUG)
476 memset( U + rank * ldu, 0, M1 * sizeof(float) );
477#endif
478
479 cblas_sswap( N2, v2 + offy * ldv, ldv, V + offy * ldv + rank, ldv );
480
481#if !defined(NDEBUG)
482 ret = LAPACKE_slaset_work( LAPACK_COL_MAJOR, 'A', 1, N1,
483 0., 0., V + rank, ldv );
484 assert( ret == 0 );
485#endif
486 i--;
487 u2-= ldu;
488 v2--;
489 }
490#if !defined(NDEBUG)
491 else {
492 memset( u2, 0, M1 * sizeof(float) );
493 ret = LAPACKE_slaset_work( LAPACK_COL_MAJOR, 'A', 1, N1,
494 0., 0., v2, ldv );
495 assert( ret == 0 );
496 }
497#endif
498 continue;
499 }
500
501 /* Compute W = u1^t u2 */
502 cblas_sgemv( CblasColMajor, CblasTrans,
503 M2, i,
504 (sone), u1+offx, ldu,
505 u2+offx, 1,
506 (szero), W, 1 );
507 flops += FLOPS_SGEMM( M2, i, 1 );
508
509 /* Compute u2 = u2 - u1 ( u1^t u2 ) = u2 - u1 * W */
510 cblas_sgemv( CblasColMajor, CblasNoTrans,
511 M1, i,
512 (msone), u1, ldu,
513 W, 1,
514 (sone), u2, 1 );
515 flops += FLOPS_SGEMM( M1, i, 1 );
516
517 /* Update v1 = v1 + ( u1^t u2 ) v2 = v1 + W * v2 */
518 cblas_sgemm( CblasColMajor, CblasNoTrans, CblasNoTrans,
519 i, N1, 1,
520 (sone), W, i,
521 v2, ldv,
522 (sone), v1, ldv );
523 flops += FLOPS_SGEMM( i, N1, 1 );
524
525 norm_before = cblas_snrm2( i, W, 1 );
526 norm = cblas_snrm2( M1, u2, 1 );
527
528#if !defined(PASTIX_LR_CGS1)
529 if ( norm <= (alpha * norm_before) ){
530 /* Compute W = u1^t u2 */
531 cblas_sgemv( CblasColMajor, CblasTrans,
532 M1, i,
533 (sone), u1, ldu,
534 u2, 1,
535 (szero), W, 1 );
536 flops += FLOPS_SGEMM( M1, i, 1 );
537
538 /* Compute u2 = u2 - u1 ( u1^t u2 ) = u2 - u1 * W */
539 cblas_sgemv( CblasColMajor, CblasNoTrans,
540 M1, i,
541 (msone), u1, ldu,
542 W, 1,
543 (sone), u2, 1 );
544 flops += FLOPS_SGEMM( M1, i, 1 );
545
546 /* Update v1 = v1 + ( u1^t u2 ) v2 = v1 + W * v2 */
547 cblas_sgemm( CblasColMajor, CblasNoTrans, CblasNoTrans,
548 i, N1, 1,
549 (sone), W, i,
550 v2, ldv,
551 (sone), v1, ldv );
552 flops += FLOPS_SGEMM( i, N1, 1 );
553
554 norm = cblas_snrm2( M1, u2, 1 );
555 }
556#endif
557
558 if ( norm > M1 * eps ) {
559 cblas_sscal( M1, 1. / norm, u2, 1 );
560 cblas_sscal( N1, norm, v2, ldv );
561 }
562 else {
563 rank--; r2--;
564 if ( i < rank ) {
565 cblas_sswap( M1, u2, 1, U + rank * ldu, 1 );
566 memset( U + rank * ldu, 0, M1 * sizeof(float) );
567
568 cblas_sswap( N1, v2, ldv, V + rank, ldv );
569 ret = LAPACKE_slaset_work( LAPACK_COL_MAJOR, 'A', 1, N1,
570 0., 0., V + rank, ldv );
571 assert( ret == 0 );
572 i--;
573 u2-= ldu;
574 v2--;
575 }
576 else {
577 memset( u2, 0, M1 * sizeof(float) );
578 ret = LAPACKE_slaset_work( LAPACK_COL_MAJOR, 'A', 1, N1,
579 0., 0., v2, ldv );
580 assert( ret == 0 );
581 }
582 }
583 }
584 free(W);
585
586#if defined(PASTIX_DEBUG_LR)
587 {
588 u2 = U + r1 * ldu;
589 if ( core_slrdbg_check_orthogonality_AB( M1, r1, r2, u1, ldu, u2, ldu ) != 0 ) {
590 fprintf(stderr, "cgs: Final u2 not orthogonal to u1\n" );
591 }
592 }
593#endif
594
595 *r2ptr = r2;
596
597 (void)offy;
598 (void)N2;
599 (void)ret;
600 return flops;
601}
BEGIN_C_DECLS typedef int pastix_int_t
Definition datatypes.h:51
double pastix_fixdbl_t
Definition datatypes.h:65
pastix_fixdbl_t core_slrorthu_fullqr(pastix_int_t M, pastix_int_t N, pastix_int_t rank, float *U, pastix_int_t ldu, float *V, pastix_int_t ldv)
Try to orthogonalize the u part of the low-rank form, and update the v part accordingly using full QR.
pastix_fixdbl_t core_slrorthu_cgs(pastix_int_t M1, pastix_int_t N1, pastix_int_t M2, pastix_int_t N2, pastix_int_t r1, pastix_int_t *r2ptr, pastix_int_t offx, pastix_int_t offy, float *U, pastix_int_t ldu, float *V, pastix_int_t ldv)
Try to orthogonalize the U part of the low-rank form, and update the V part accordingly using CGS.
pastix_fixdbl_t core_slrorthu_partialqr(pastix_int_t M, pastix_int_t N, pastix_int_t r1, pastix_int_t *r2ptr, pastix_int_t offx, pastix_int_t offy, float *U, pastix_int_t ldu, float *V, pastix_int_t ldv)
Try to orthogonalize the U part of the low-rank form, and update the V part accordingly using partial ...
int core_slrdbg_check_orthogonality_AB(pastix_int_t M, pastix_int_t NA, pastix_int_t NB, const float *A, pastix_int_t lda, const float *B, pastix_int_t ldb)
Check the orthogonality of the matrix A relatively to the matrix B.
@ PastixLeft
Definition api.h:495
Manage nancheck for lowrank kernels. This header describes all the LAPACKE functions used for low-ran...