PaStiX Handbook  6.4.0
core_slrothu.c
Go to the documentation of this file.
1 /**
2  *
3  * @file core_slrothu.c
4  *
5  * PaStiX low-rank kernel routines to orthogonalize the U matrix with QR approximations.
6  *
7  * @copyright 2016-2024 Bordeaux INP, CNRS (LaBRI UMR 5800), Inria,
8  * Univ. Bordeaux. All rights reserved.
9  *
10  * @version 6.4.0
11  * @author Alfredo Buttari
12  * @author Gregoire Pichon
13  * @author Esragul Korkmaz
14  * @author Mathieu Faverge
15  * @date 2024-07-05
16  * @generated from /builds/solverstack/pastix/kernels/core_zlrothu.c, normal z -> s, Tue Oct 8 14:17:22 2024
17  *
18  **/
19 #include "common.h"
20 #include <cblas.h>
21 #include <lapacke.h>
22 #include "flops.h"
23 #include "kernels_trace.h"
24 #include "blend/solver.h"
25 #include "pastix_scores.h"
26 #include "pastix_slrcores.h"
27 #include "s_nan_check.h"
28 #include "pastix_lowrank.h"
29 
30 #ifndef DOXYGEN_SHOULD_SKIP_THIS
31 static float msone = -1.0;
32 static float sone = 1.0;
33 static float szero = 0.0;
34 #endif /* DOXYGEN_SHOULD_SKIP_THIS */
35 
36 /**
37  *******************************************************************************
38  *
39  * @brief Try to orthogonalize the u part of the low-rank form, and update the v
40  * part accordingly using full QR.
41  *
42  * This function considers a low-rank matrix resulting from the addition of two
43  * matrices B += A, with A of smaller or equal size to B.
44  * The product has the form: U * V^t
45  *
46  * The U part of the low-rank form must be orthogonalized to get the smallest
47  * possible rank during the rradd operation. This function performs this by
48  * applying a full QR factorization on the U part.
49  *
50  * U = Q R, then U' = Q, and V' = R * V
51  *
52  *******************************************************************************
53  *
54  * @param[in] M
55  * The number of rows of the u1u2 matrix.
56  *
57  * @param[in] N
58  * The number of columns of the v1v2 matrix.
59  *
60  * @param[in] rank
61  * The number of columns of the U matrix, and the number of rows of the
62  * V part in the v1v2 matrix.
63  *
64  * @param[inout] U
65  * The U matrix of size ldu -by- rank. On exit, Q from U = Q R.
66  *
67  * @param[in] ldu
68  * The leading dimension of the U matrix. ldu >= max(1, M)
69  *
70  * @param[inout] V
71  * The V matrix of size ldv -by- N.
72  * On exit, R * V, with R from U = Q R.
73  *
74  * @param[in] ldv
75  * The leading dimension of the V matrix. ldv >= max(1, rank)
76  *
77  *******************************************************************************
78  *
79  * @return The number of flops required to perform the operation.
80  *
81  *******************************************************************************/
84  pastix_int_t N,
85  pastix_int_t rank,
86  float *U,
87  pastix_int_t ldu,
88  float *V,
89  pastix_int_t ldv )
90 {
91  pastix_int_t minMK = pastix_imin( M, rank );
92  pastix_int_t lwork = M * 32 + minMK; /* one buffer: minMK entries for tau, then M * 32 entries of LAPACK workspace */
93  pastix_int_t ret;
94  float *W = malloc( lwork * sizeof(float) ); /* NOTE(review): the result of malloc is not checked before use */
95  float *tau, *work;
96  pastix_fixdbl_t flops = 0.;
97 
98  tau = W; /* first minMK entries of W */
99  work = W + minMK; /* the rest of W is the scratch space handed to LAPACK */
100  lwork -= minMK; /* lwork now counts the scratch part only */
101 
102  assert( M >= rank ); /* U is expected to have at least as many rows as columns */
103 
104  /* Compute U = Q * R in place; the upper triangle of U is then used as R below */
105  ret = LAPACKE_sgeqrf_work( LAPACK_COL_MAJOR, M, rank,
106  U, ldu, tau, work, lwork );
107  assert( ret == 0 );
108  flops += FLOPS_SGEQRF( M, rank );
109 
110  /* Compute V = R * V in place, reading R from the upper triangle of U */
111  cblas_strmm( CblasColMajor,
112  CblasLeft, CblasUpper,
113  CblasNoTrans, CblasNonUnit,
114  rank, N, (sone),
115  U, ldu, V, ldv );
116  flops += FLOPS_STRMM( PastixLeft, rank, N );
117 
118  /* Generate the Q: U is overwritten by the explicit M-by-rank factor */
119  ret = LAPACKE_sorgqr_work( LAPACK_COL_MAJOR, M, rank, rank,
120  U, ldu, tau, work, lwork );
121  assert( ret == 0 );
122  flops += FLOPS_SORGQR( M, rank, rank );
123 
124  free(W);
125 
126  (void)ret; /* ret is otherwise only read inside assert() */
127  return flops;
128 }
129 
130 /**
131  *******************************************************************************
132  *
133  * @brief Try to orthogonalize the U part of the low-rank form, and update the V
134  * part accordingly using partial QR.
135  *
136  * This function considers a low-rank matrix resulting from the addition of two
137  * matrices B += A, with A of smaller or equal size to B.
138  * The product has the form: U * V^t
139  *
140  * The U part of the low-rank form must be orthogonalized to get the smallest
141  * possible rank during the rradd operation. This function performs this by
142  * applying a full QR factorization on the U part.
143  *
144  * In that case, it takes advantage of the fact that U = [ u1, u2 ], and V = [
145  * v1, v2 ] with u2 and v2 which are matrices of respective size M2-by-r2, and
146  * r2-by-N2, offset by offx and offy.
147  *
148  * The steps are:
149  * - Scaling of u2 with removal of the null columns
150  * - Orthogonalization of u2 relatively to u1
151  * - Application of the update to v2
152  * - orthogonalization through QR of u2
153  * - Update of V
154  *
155  *******************************************************************************
156  *
157  * @param[in] M
158  * The number of rows of the u1u2 matrix.
159  *
160  * @param[in] N
161  * The number of columns of the v1v2 matrix.
162  *
163  * @param[in] r1
164  * The number of columns of the U matrix in the u1 part, and the number
165  * of rows of the V part in the v1 part.
166  *
167  * @param[inout] r2ptr
168  * The number of columns of the U matrix in the u2 part, and the number
169  * of rows of the V part in the v2 part. On exit, this rank is reduced
170  * by the number of null columns found in U.
171  *
172  * @param[in] offx
173  * The row offset of the matrix u2 in U.
174  *
175  * @param[in] offy
176  * The column offset of the matrix v2 in V.
177  *
178  * @param[inout] U
179  * The U matrix of size ldu -by- rank. On exit, the orthogonalized U.
180  *
181  * @param[in] ldu
182  * The leading dimension of the U matrix. ldu >= max(1, M)
183  *
184  * @param[inout] V
185  * The V matrix of size ldv -by- N.
186  * On exit, the updated V matrix.
187  *
188  * @param[in] ldv
189  * The leading dimension of the V matrix. ldv >= max(1, rank)
190  *
191  *******************************************************************************
192  *
193  * @return The number of flops required to perform the operation.
194  *
195  *******************************************************************************/
198  pastix_int_t N,
199  pastix_int_t r1,
200  pastix_int_t *r2ptr,
201  pastix_int_t offx,
202  pastix_int_t offy,
203  float *U,
204  pastix_int_t ldu,
205  float *V,
206  pastix_int_t ldv )
207 {
208  pastix_int_t r2 = *r2ptr;
209  pastix_int_t minMN = pastix_imin( M, r2 );
210  pastix_int_t ldwork = pastix_imax( r1 * r2, M * 32 + minMN ); /* W is used first as the r1-by-r2 projection matrix, later as tau + LAPACK workspace */
211  pastix_int_t ret, i;
212  float *u1 = U;
213  float *u2 = U + r1 * ldu; /* u2 starts at column r1 of U */
214  float *v1 = V;
215  float *v2 = V + r1; /* v2 starts at row r1 of V */
216  float *W = malloc( ldwork * sizeof(float) ); /* NOTE(review): the result of malloc is not checked before use */
217  float *tau, *work;
218  pastix_fixdbl_t flops = 0.;
219  float norm, eps;
220 
221  tau = W; /* first minMN entries of W (only meaningful from the QR step on) */
222  work = W + minMN;
223  ldwork -= minMN; /* ldwork now counts the LAPACK scratch part only */
224 
225  eps = LAPACKE_slamch_work('e');
226 
227  /* Scaling: give each column of u2 unit norm and move the norm into the matching row of v2; columns with norm <= M * eps are treated as null and removed */
228  for (i=0; i<r2; i++, u2 += ldu, v2++) {
229  norm = cblas_snrm2( M, u2, 1 );
230  if ( norm > (M * eps) ) {
231  cblas_sscal( M, 1. / norm, u2, 1 );
232  cblas_sscal( N, norm, v2, ldv );
233  }
234  else {
235  if ( i < (r2-1) ) { /* null column that is not the last one: exchange it with the last column */
236  cblas_sswap( M, u2, 1, U + (r1+r2-1) * ldu, 1 );
237  memset( U + (r1+r2-1) * ldu, 0, M * sizeof(float) ); /* zero the slot that now holds the null column */
238 
239  cblas_sswap( N, v2, ldv, V + (r1+r2-1), ldv ); /* same exchange on the matching rows of V */
240  ret = LAPACKE_slaset_work( LAPACK_COL_MAJOR, 'A', 1, N,
241  0., 0., V + (r1+r2-1), ldv );
242  assert( ret == 0 );
243  r2--;
244  i--; /* step back so the column just swapped in is examined by the next iteration */
245  u2-= ldu;
246  v2--;
247  }
248  else { /* null column is the last one: just clear it and shrink the rank */
249  memset( u2, 0, M * sizeof(float) );
250  ret = LAPACKE_slaset_work( LAPACK_COL_MAJOR, 'A', 1, N,
251  0., 0., v2, ldv );
252  assert( ret == 0 );
253  r2--;
254  }
255  }
256  }
257  u2 = U + r1 * ldu; /* rewind both pointers to the start of the second block */
258  v2 = V + r1;
259 
260  *r2ptr = r2; /* report the reduced rank to the caller */
261 
262  if ( r2 == 0 ) { /* every column of u2 was null: nothing left to orthogonalize */
263  free( W );
264  return 0.;
265  }
266 
267  /* Compute W = u1^t u2 (r1-by-r2, stored with leading dimension r1) */
268  /* NOTE(review): r1 is used as a leading dimension here; confirm callers never pass r1 == 0 */
269  cblas_sgemm( CblasColMajor, CblasTrans, CblasNoTrans,
270  r1, r2, M,
271  (sone), u1, ldu,
272  u2, ldu,
273  (szero), W, r1 );
274  flops += FLOPS_SGEMM( r1, r2, M );
275 
276  /* Compute u2 = u2 - u1 ( u1^t u2 ) = u2 - u1 * W */
277  cblas_sgemm( CblasColMajor, CblasNoTrans, CblasNoTrans,
278  M, r2, r1,
279  (msone), u1, ldu,
280  W, r1,
281  (sone), u2, ldu );
282  flops += FLOPS_SGEMM( M, r2, r1 );
283 
284  /* Update v1 = v1 + ( u1^t u2 ) v2 = v1 + W * v2 */
285  cblas_sgemm( CblasColMajor, CblasNoTrans, CblasNoTrans,
286  r1, N, r2,
287  (sone), W, r1,
288  v2, ldv,
289  (sone), v1, ldv );
290  flops += FLOPS_SGEMM( r1, N, r2 );
291 
292 #if !defined(PASTIX_LR_CGS1) /* second, identical projection pass; compiled out when PASTIX_LR_CGS1 is defined */
293  /* Compute W = u1^t u2 */
294  cblas_sgemm( CblasColMajor, CblasTrans, CblasNoTrans,
295  r1, r2, M,
296  (sone), u1, ldu,
297  u2, ldu,
298  (szero), W, r1 );
299  flops += FLOPS_SGEMM( r1, r2, M );
300 
301  /* Compute u2 = u2 - u1 ( u1^t u2 ) = u2 - u1 * W */
302  cblas_sgemm( CblasColMajor, CblasNoTrans, CblasNoTrans,
303  M, r2, r1,
304  (msone), u1, ldu,
305  W, r1,
306  (sone), u2, ldu );
307  flops += FLOPS_SGEMM( M, r2, r1 );
308 
309  /* Update v1 = v1 + ( u1^t u2 ) v2 = v1 + W * v2 */
310  cblas_sgemm( CblasColMajor, CblasNoTrans, CblasNoTrans,
311  r1, N, r2,
312  (sone), W, r1,
313  v2, ldv,
314  (sone), v1, ldv );
315  flops += FLOPS_SGEMM( r1, N, r2 );
316 #endif
317 
318 #if defined(PASTIX_DEBUG_LR)
319  if ( core_slrdbg_check_orthogonality_AB( M, r1, r2, u1, ldu, u2, ldu ) != 0 ) {
320  fprintf(stderr, "partialQR: u2 not correctly projected with u1\n" );
321  }
322 #endif
323 
324  /* Compute u2 = Q * R in place; from here on W is reused as tau + work */
325  ret = LAPACKE_sgeqrf_work( LAPACK_COL_MAJOR, M, r2,
326  u2, ldu, tau, work, ldwork );
327  assert( ret == 0 );
328  flops += FLOPS_SGEQRF( M, r2 );
329 
330  /* Compute v2' = R * v2 in place, reading R from the upper triangle of u2 */
331  cblas_strmm( CblasColMajor,
332  CblasLeft, CblasUpper,
333  CblasNoTrans, CblasNonUnit,
334  r2, N, (sone),
335  u2, ldu, v2, ldv);
336  flops += FLOPS_STRMM( PastixLeft, r2, N );
337 
338  /* Generate the Q: u2 is overwritten by the explicit M-by-r2 factor */
339  ret = LAPACKE_sorgqr_work( LAPACK_COL_MAJOR, M, r2, r2,
340  u2, ldu, tau, work, ldwork );
341  assert( ret == 0 );
342  flops += FLOPS_SORGQR( M, r2, r2 );
343 
344 #if defined(PASTIX_DEBUG_LR)
345  if ( core_slrdbg_check_orthogonality_AB( M, r1, r2, u1, ldu, u2, ldu ) != 0 ) {
346  fprintf(stderr, "partialQR: Final u2 not orthogonal to u1\n" );
347  }
348 #endif
349 
350  free( W );
351 
352  (void)ret; /* ret is otherwise only read inside assert() */
353  (void)offx; /* offx and offy are not used by this kernel */
354  (void)offy;
355 
356  return flops;
357 }
357 
358 /**
359  *******************************************************************************
360  *
361  * @brief Try to orthogonalize the U part of the low-rank form, and update the V
362  * part accordingly using CGS.
363  *
364  * This function considers a low-rank matrix resulting from the addition of two
365  * matrices B += A, with A of smaller or equal size to B.
366  * The product has the form: U * V^t
367  *
368  * The U part of the low-rank form must be orthogonalized to get the smallest
369  * possible rank during the rradd operation. This function performs this by
370  * applying a full QR factorization on the U part.
371  *
372  * In that case, it takes advantage of the fact that U = [ u1, u2 ], and V = [
373  * v1, v2 ] with u2 and v2 which are matrices of respective size M2-by-r2, and
374  * r2-by-N2, offset by offx and offy.
375  *
376  * The steps are:
377  * - for each column of u2
378  * - Scaling of u2 with removal of the null columns
379  * - Orthogonalization of u2 relatively to u1
380  * - Remove the column if null
381  *
382  *******************************************************************************
383  *
384  * @param[in] M1
385  * The number of rows of the U matrix.
386  *
387  * @param[in] N1
388  * The number of columns of the V matrix.
389  *
390  * @param[in] M2
391  * The number of rows of the u2 part of the U matrix.
392  *
393  * @param[in] N2
394  * The number of columns of the v2 part of the V matrix.
395  *
396  * @param[in] r1
397  * The number of columns of the U matrix in the u1 part, and the number
398  * of rows of the V part in the v1 part.
399  *
400  * @param[inout] r2ptr
401  * The number of columns of the U matrix in the u2 part, and the number
402  * of rows of the V part in the v2 part. On exit, this rank is reduced
403  * y the number of null columns found in U.
404  *
405  * @param[in] offx
406  * The row offset of the matrix u2 in U.
407  *
408  * @param[in] offy
409  * The column offset of the matrix v2 in V.
410  *
411  * @param[inout] U
412  * The U matrix of size ldu -by- rank. On exit, the orthogonalized U.
413  *
414  * @param[in] ldu
415  * The leading dimension of the U matrix. ldu >= max(1, M)
416  *
417  * @param[inout] V
418  * The V matrix of size ldv -by- N.
419  * On exit, the updated V matrix.
420  *
421  * @param[in] ldv
422  * The leading dimension of the V matrix. ldv >= max(1, rank)
423  *
424  *******************************************************************************
425  *
426  * @return The number of flops required to perform the operation.
427  *
428  *******************************************************************************/
431  pastix_int_t N1,
432  pastix_int_t M2,
433  pastix_int_t N2,
434  pastix_int_t r1,
435  pastix_int_t *r2ptr,
436  pastix_int_t offx,
437  pastix_int_t offy,
438  float *U,
439  pastix_int_t ldu,
440  float *V,
441  pastix_int_t ldv )
442 {
443  pastix_int_t r2 = *r2ptr;
444  float *u1 = U;
445  float *u2 = U + r1 * ldu; /* u2 starts at column r1 of U */
446  float *v1 = V;
447  float *v2 = V + r1; /* v2 starts at row r1 of V */
448  float *W;
449  pastix_fixdbl_t flops = 0.0;
450  pastix_int_t i, rank = r1 + r2;
451  pastix_int_t ldwork = rank; /* W holds one projection coefficient per column already processed (at most rank) */
452  pastix_int_t ret;
453  float eps, norm;
454  float norm_before, alpha;
455 
456  assert( M1 >= (M2 + offx) );
457  assert( N1 >= (N2 + offy) );
458 
459  W = malloc(ldwork * sizeof(float)); /* NOTE(review): the result of malloc is not checked before use */
460  eps = LAPACKE_slamch_work( 'e' );
461  alpha = 1. / sqrtf(2); /* threshold factor deciding whether a second projection pass is run */
462 
463  /* Classical Gram-Schmidt: column i of U is projected against columns 0..i-1 */
464  for (i=r1; i<rank; i++, u2 += ldu, v2++) {
465 
466  norm = cblas_snrm2( M2, u2 + offx, 1 ); /* only rows [offx, offx+M2) are inspected; presumably the incoming column is zero elsewhere - TODO confirm */
467  if ( norm > ( M2 * eps ) ) {
468  cblas_sscal( M2, 1. / norm, u2 + offx, 1 );
469  cblas_sscal( N2, norm, v2 + offy * ldv, ldv ); /* move the norm into the matching N2 entries of the row of V */
470  }
471  else {
472  rank--; r2--; /* null column: drop it from the rank */
473  if ( i < rank ) { /* not the last column: exchange it with the last one */
474  cblas_sswap( M2, u2 + offx, 1, U + rank * ldu + offx, 1 );
475 #if !defined(NDEBUG) /* the zero fill of the discarded column is done only when NDEBUG is not defined */
476  memset( U + rank * ldu, 0, M1 * sizeof(float) );
477 #endif
478 
479  cblas_sswap( N2, v2 + offy * ldv, ldv, V + offy * ldv + rank, ldv );
480 
481 #if !defined(NDEBUG)
482  ret = LAPACKE_slaset_work( LAPACK_COL_MAJOR, 'A', 1, N1,
483  0., 0., V + rank, ldv );
484  assert( ret == 0 );
485 #endif
486  i--; /* step back so the column just swapped in is examined by the next iteration */
487  u2-= ldu;
488  v2--;
489  }
490 #if !defined(NDEBUG)
491  else {
492  memset( u2, 0, M1 * sizeof(float) );
493  ret = LAPACKE_slaset_work( LAPACK_COL_MAJOR, 'A', 1, N1,
494  0., 0., v2, ldv );
495  assert( ret == 0 );
496  }
497 #endif
498  continue; /* nothing to project for a null column */
499  }
500 
501  /* Compute W = u1^t u2, using only the M2 rows starting at offx */
502  cblas_sgemv( CblasColMajor, CblasTrans,
503  M2, i,
504  (sone), u1+offx, ldu,
505  u2+offx, 1,
506  (szero), W, 1 );
507  flops += FLOPS_SGEMM( M2, i, 1 );
508 
509  /* Compute u2 = u2 - u1 ( u1^t u2 ) = u2 - u1 * W, over all M1 rows */
510  cblas_sgemv( CblasColMajor, CblasNoTrans,
511  M1, i,
512  (msone), u1, ldu,
513  W, 1,
514  (sone), u2, 1 );
515  flops += FLOPS_SGEMM( M1, i, 1 );
516 
517  /* Update v1 = v1 + ( u1^t u2 ) v2 = v1 + W * v2 */
518  /* NOTE(review): W is passed with leading dimension i; confirm callers never pass r1 == 0, which would give i == 0 on the first iteration */
519  cblas_sgemm( CblasColMajor, CblasNoTrans, CblasNoTrans,
520  i, N1, 1,
521  (sone), W, i,
522  v2, ldv,
523  (sone), v1, ldv );
524  flops += FLOPS_SGEMM( i, N1, 1 );
525 
526  norm_before = cblas_snrm2( i, W, 1 ); /* size of the component that was removed */
527  norm = cblas_snrm2( M1, u2, 1 ); /* size of what is left of the column */
528 
529 #if !defined(PASTIX_LR_CGS1) /* optional second pass; compiled out when PASTIX_LR_CGS1 is defined */
530  if ( norm <= (alpha * norm_before) ){ /* the remainder is small relative to the removed part: project once more */
531  /* Compute W = u1^t u2 */
532  cblas_sgemv( CblasColMajor, CblasTrans,
533  M1, i,
534  (sone), u1, ldu,
535  u2, 1,
536  (szero), W, 1 );
537  flops += FLOPS_SGEMM( M1, i, 1 );
538 
539  /* Compute u2 = u2 - u1 ( u1^t u2 ) = u2 - u1 * W */
540  cblas_sgemv( CblasColMajor, CblasNoTrans,
541  M1, i,
542  (msone), u1, ldu,
543  W, 1,
544  (sone), u2, 1 );
545  flops += FLOPS_SGEMM( M1, i, 1 );
546 
547  /* Update v1 = v1 + ( u1^t u2 ) v2 = v1 + W * v2 */
548  cblas_sgemm( CblasColMajor, CblasNoTrans, CblasNoTrans,
549  i, N1, 1,
550  (sone), W, i,
551  v2, ldv,
552  (sone), v1, ldv );
553  flops += FLOPS_SGEMM( i, N1, 1 );
554 
555  norm = cblas_snrm2( M1, u2, 1 );
556  }
557 #endif
558 
559  if ( norm > M1 * eps ) { /* keep the column: normalize it and move the norm into the row of V */
560  cblas_sscal( M1, 1. / norm, u2, 1 );
561  cblas_sscal( N1, norm, v2, ldv );
562  }
563  else { /* the column vanished after projection: remove it (zero fill done unconditionally here) */
564  rank--; r2--;
565  if ( i < rank ) {
566  cblas_sswap( M1, u2, 1, U + rank * ldu, 1 );
567  memset( U + rank * ldu, 0, M1 * sizeof(float) );
568 
569  cblas_sswap( N1, v2, ldv, V + rank, ldv );
570  ret = LAPACKE_slaset_work( LAPACK_COL_MAJOR, 'A', 1, N1,
571  0., 0., V + rank, ldv );
572  assert( ret == 0 );
573  i--; /* step back so the column just swapped in is examined by the next iteration */
574  u2-= ldu;
575  v2--;
576  }
577  else {
578  memset( u2, 0, M1 * sizeof(float) );
579  ret = LAPACKE_slaset_work( LAPACK_COL_MAJOR, 'A', 1, N1,
580  0., 0., v2, ldv );
581  assert( ret == 0 );
582  }
583  }
584  }
585  free(W);
586 
587 #if defined(PASTIX_DEBUG_LR)
588  {
589  u2 = U + r1 * ldu;
590  if ( core_slrdbg_check_orthogonality_AB( M1, r1, r2, u1, ldu, u2, ldu ) != 0 ) {
591  fprintf(stderr, "cgs: Final u2 not orthogonal to u1\n" );
592  }
593  }
594 #endif
595 
596  *r2ptr = r2; /* report the reduced rank to the caller */
597 
598  (void)offy;
599  (void)N2;
600  (void)ret;
601  return flops;
602 }
BEGIN_C_DECLS typedef int pastix_int_t
Definition: datatypes.h:51
double pastix_fixdbl_t
Definition: datatypes.h:65
pastix_fixdbl_t core_slrorthu_fullqr(pastix_int_t M, pastix_int_t N, pastix_int_t rank, float *U, pastix_int_t ldu, float *V, pastix_int_t ldv)
Try to orthogonalize the u part of the low-rank form, and update the v part accordingly using full QR.
Definition: core_slrothu.c:83
pastix_fixdbl_t core_slrorthu_cgs(pastix_int_t M1, pastix_int_t N1, pastix_int_t M2, pastix_int_t N2, pastix_int_t r1, pastix_int_t *r2ptr, pastix_int_t offx, pastix_int_t offy, float *U, pastix_int_t ldu, float *V, pastix_int_t ldv)
Try to orthogonalize the U part of the low-rank form, and update the V part accordingly using CGS.
Definition: core_slrothu.c:430
pastix_fixdbl_t core_slrorthu_partialqr(pastix_int_t M, pastix_int_t N, pastix_int_t r1, pastix_int_t *r2ptr, pastix_int_t offx, pastix_int_t offy, float *U, pastix_int_t ldu, float *V, pastix_int_t ldv)
Try to orthogonalize the U part of the low-rank form, and update the V part accordingly using partial ...
Definition: core_slrothu.c:197
int core_slrdbg_check_orthogonality_AB(pastix_int_t M, pastix_int_t NA, pastix_int_t NB, const float *A, pastix_int_t lda, const float *B, pastix_int_t ldb)
Check the orthogonality of the matrix A relatively to the matrix B.
Definition: core_slrdbg.c:186
@ PastixLeft
Definition: api.h:495
Manage nancheck for lowrank kernels. This header describes all the LAPACKE functions used for low-ran...