PaStiX Handbook 6.4.0
Loading...
Searching...
No Matches
core_sxx2lr.c
Go to the documentation of this file.
1/**
2 *
3 * @file core_sxx2lr.c
4 *
5 * PaStiX low-rank kernel routines that form the product of two matrices A and B
6 * into a low-rank form for an update on a null or low-rank matrix.
7 *
8 * @copyright 2016-2024 Bordeaux INP, CNRS (LaBRI UMR 5800), Inria,
9 * Univ. Bordeaux. All rights reserved.
10 *
11 * @version 6.4.0
12 * @author Mathieu Faverge
13 * @author Gregoire Pichon
14 * @author Pierre Ramet
15 * @date 2024-07-05
16 * @generated from /builds/2mk6rsew/0/solverstack/pastix/kernels/core_zxx2lr.c, normal z -> s, Tue Feb 25 14:34:57 2025
17 *
18 **/
19#include "common.h"
20#include <cblas.h>
21#include "flops.h"
22#include "blend/solver.h"
23#include "pastix_slrcores.h"
24
25#ifndef DOXYGEN_SHOULD_SKIP_THIS
26static float sone = 1.0;
27static float szero = 0.0;
28#endif /* DOXYGEN_SHOULD_SKIP_THIS */
29
30/**
31 *******************************************************************************
32 *
33 * @brief Perform the operation AB = op(A) * op(B), with A and B full-rank and AB
34 * low-rank.
35 *
36 *******************************************************************************
37 *
38 * @param[inout] params
39 * The LRMM structure that stores all the parameters used in the LRMM
40 * functions family.
41 * On exit, the C matrix contains the product AB aligned with its own
42 * dimensions.
43 * @sa core_slrmm_t
44 *
45 * @param[inout] AB
46 * The low-rank structure of the AB matrix in which to store the AB product.
47 *
48 * @param[inout] infomask
49 * The mask of information returned by the core_sxx2lr() functions.
50 * - If AB.u is orthogonal on exit, then PASTIX_LRM3_ORTHOU is set.
51 * - If AB.u is allocated, then PASTIX_LRM3_ALLOCU is set.
52 * - If AB.v is allocated, then PASTIX_LRM3_ALLOCV is set.
53 * - If AB.v is initialized as one of the given pointers and op(B) is not
54 * applied, then PASTIX_LRM3_TRANSB is set.
55 *
56 * @param[in] Kmax
57 * The maximum K value for which the AB product is constructed as AB.u =
58 * A, and AB.v = B
59 *
60 *******************************************************************************
61 *
62 * @return The number of flops required to perform the operation.
63 *
64 *******************************************************************************/
68 int *infomask,
69 pastix_int_t Kmax )
70{
72 pastix_int_t ldau, ldbu;
73 pastix_fixdbl_t flops = 0.0;
74
75 ldau = (transA == PastixNoTrans) ? M : K;
76 ldbu = (transB == PastixNoTrans) ? K : N;
77
78 /*
79 * Everything is full rank
80 */
81 if ( K < Kmax ) {
82 /*
83 * Let's build a low-rank matrix of rank K
84 */
85 AB->rk = K;
86 AB->rkmax = K;
87 AB->u = A->u;
88 AB->v = B->u;
89 *infomask |= PASTIX_LRM3_TRANSB;
90 }
91 else {
92 /*
93 * Let's compute the product to form a full-rank matrix of rank
94 * pastix_imin( M, N )
95 */
96 if ( (work = core_slrmm_getws( params, M * N )) == NULL ) {
97 work = malloc( M * N * sizeof(float) );
98 *infomask |= PASTIX_LRM3_ALLOCU;
99 }
100 AB->rk = -1;
101 AB->rkmax = M;
102 AB->u = work;
103 AB->v = NULL;
104
105 cblas_sgemm( CblasColMajor, CblasNoTrans, (CBLAS_TRANSPOSE)transB,
106 M, N, K,
107 (sone), A->u, ldau,
108 B->u, ldbu,
109 (szero), AB->u, M );
110 flops = FLOPS_SGEMM( M, N, K );
111 }
112
114 return flops;
115}
116
117/**
118 *******************************************************************************
119 *
120 * @brief Perform the operation AB = op(A) * op(B), with A full-rank and B and AB
121 * low-rank.
122 *
123 *******************************************************************************
124 *
125 * @param[inout] params
126 * The LRMM structure that stores all the parameters used in the LRMM
127 * functions family.
128 * On exit, the C matrix contains the product AB aligned with its own
129 * dimensions.
130 * @sa core_slrmm_t
131 *
132 * @param[inout] AB
133 * The low-rank structure of the AB matrix in which to store the AB product.
134 *
135 * @param[inout] infomask
136 * The mask of information returned by the core_sxx2lr() functions.
137 * - If AB.u is orthogonal on exit, then PASTIX_LRM3_ORTHOU is set.
138 * - If AB.u is allocated, then PASTIX_LRM3_ALLOCU is set.
139 * - If AB.v is allocated, then PASTIX_LRM3_ALLOCV is set.
140 * - If AB.v is initialized as one of the given pointers and op(B) is not
141 * applied, then PASTIX_LRM3_TRANSB is set.
142 *
143 * @param[in] Brkmin
144 * Threshold for which B->rk is considered as the final rank of AB
145 *
146 *******************************************************************************
147 *
148 * @return The number of flops required to perform the operation.
149 *
150 *******************************************************************************/
154 int *infomask,
155 pastix_int_t Brkmin )
156{
157 PASTE_CORE_SLRMM_PARAMS( params );
158 pastix_int_t ldau, ldbu, ldbv;
159 pastix_fixdbl_t flops;
160
161 ldau = (transA == PastixNoTrans) ? M : K;
162 ldbu = (transB == PastixNoTrans) ? K : N;
163 ldbv = ( B->rk == -1 ) ? -1 : B->rkmax;
164
165 /*
166 * A(M-by-K) * B( N-by-rb x rb-by-K )^t
167 */
168 if ( B->rk > Brkmin ) {
169 /*
170 * We are in a similar case to the _Cfr function, and we
171 * choose the optimal number of flops.
172 */
173 pastix_fixdbl_t flops1 = FLOPS_SGEMM( M, B->rk, K ) + FLOPS_SGEMM( M, N, B->rk );
174 pastix_fixdbl_t flops2 = FLOPS_SGEMM( K, N, B->rk ) + FLOPS_SGEMM( M, N, K );
175 float *tmp;
176
177 AB->rk = -1;
178 AB->rkmax = M;
179 AB->v = NULL;
180
181 if ( flops1 <= flops2 ) {
182 if ( (work = core_slrmm_getws( params, M * B->rk + M * N )) == NULL ) {
183 work = malloc( (M * B->rk + M * N) * sizeof(float) );
184 *infomask |= PASTIX_LRM3_ALLOCU;
185 }
186
187 /* AB->u will be destroyed later */
188 AB->u = work;
189 tmp = work + M * N;
190
191 /*
192 * (A * Bv) * Bu^t
193 */
194 cblas_sgemm( CblasColMajor, (CBLAS_TRANSPOSE)transA, (CBLAS_TRANSPOSE)transB,
195 M, B->rk, K,
196 (sone), A->u, ldau,
197 B->v, ldbv,
198 (szero), tmp, M );
199
200 cblas_sgemm( CblasColMajor, CblasNoTrans, (CBLAS_TRANSPOSE)transB,
201 M, N, B->rk,
202 (sone), tmp, M,
203 B->u, ldbu,
204 (szero), AB->u, M );
205
206 flops = flops1;
207 }
208 else {
209 if ( (work = core_slrmm_getws( params, K * N + M * N )) == NULL ) {
210 work = malloc( (K * N + M * N) * sizeof(float) );
211 *infomask |= PASTIX_LRM3_ALLOCU;
212 }
213
214 /* AB->u will be destroyed later */
215 AB->u = work;
216 tmp = work + M * N;
217
218 /*
219 * A * (Bu * Bv^t)^t
220 */
221 cblas_sgemm( CblasColMajor, CblasNoTrans, CblasNoTrans,
222 K, N, B->rk,
223 (sone), B->u, ldbu,
224 B->v, ldbv,
225 (szero), tmp, K );
226
227 cblas_sgemm( CblasColMajor, (CBLAS_TRANSPOSE)transA, (CBLAS_TRANSPOSE)transB,
228 M, N, K,
229 (sone), A->u, ldau,
230 tmp, K,
231 (szero), AB->u, M );
232
233 flops = flops2;
234 }
235 }
236 else {
237 /*
238 * B->rk is the smallest rank
239 */
240 AB->rk = B->rk;
241 AB->rkmax = B->rkmax;
242 AB->v = B->u;
243 *infomask |= PASTIX_LRM3_TRANSB;
244
245 if ( (work = core_slrmm_getws( params, M * B->rk )) == NULL ) {
246 work = malloc( M * B->rk * sizeof(float) );
247 *infomask |= PASTIX_LRM3_ALLOCU;
248 }
249 AB->u = work;
250
251 cblas_sgemm( CblasColMajor, CblasNoTrans, (CBLAS_TRANSPOSE)transB,
252 M, B->rk, K,
253 (sone), A->u, ldau,
254 B->v, ldbv,
255 (szero), AB->u, M );
256 flops = FLOPS_SGEMM( M, B->rk, K );
257 }
258
260 return flops;
261}
262
263/**
264 *******************************************************************************
265 *
266 * @brief Perform the operation AB = op(A) * op(B), with B full-rank and A and AB
267 * low-rank.
268 *
269 *******************************************************************************
270 *
271 * @param[inout] params
272 * The LRMM structure that stores all the parameters used in the LRMM
273 * functions family.
274 * On exit, the C matrix contains the product AB aligned with its own
275 * dimensions.
276 * @sa core_slrmm_t
277 *
278 * @param[inout] AB
279 * The low-rank structure of the AB matrix in which to store the AB product.
280 *
281 * @param[inout] infomask
282 * The mask of information returned by the core_sxx2lr() functions.
283 * - If AB.u is orthogonal on exit, then PASTIX_LRM3_ORTHOU is set.
284 * - If AB.u is allocated, then PASTIX_LRM3_ALLOCU is set.
285 * - If AB.v is allocated, then PASTIX_LRM3_ALLOCV is set.
286 * - If AB.v is initialized as one of the given pointers and op(B) is not
287 * applied, then PASTIX_LRM3_TRANSB is set.
288 *
289 * @param[in] Arkmin
290 * Threshold for which A->rk is considered as the final rank of AB
291 *
292 *******************************************************************************
293 *
294 * @return The number of flops required to perform the operation.
295 *
296 *******************************************************************************/
300 int *infomask,
301 pastix_int_t Arkmin )
302{
303 PASTE_CORE_SLRMM_PARAMS( params );
304 pastix_int_t ldau, ldav, ldbu;
305 pastix_fixdbl_t flops;
306
307 ldau = (transA == PastixNoTrans) ? M : K;
308 ldav = ( A->rk == -1 ) ? -1 : A->rkmax;
309 ldbu = (transB == PastixNoTrans) ? K : N;
310
311 /*
312 * A( M-by-ra x ra-by-K ) * B(N-by-K)^t
313 */
314 if ( A->rk > Arkmin ) {
315 /*
316 * We are in a similar case to the _Cfr function, and we
317 * choose the optimal number of flops.
318 */
319 pastix_fixdbl_t flops1 = FLOPS_SGEMM( A->rk, N, K ) + FLOPS_SGEMM( M, N, A->rk );
320 pastix_fixdbl_t flops2 = FLOPS_SGEMM( M, K, A->rk ) + FLOPS_SGEMM( M, N, K );
321 float *tmp;
322
323 AB->rk = -1;
324 AB->rkmax = M;
325 AB->v = NULL;
326
327 if ( flops1 <= flops2 ) {
328 if ( (work = core_slrmm_getws( params, A->rk * N + M * N )) == NULL ) {
329 work = malloc( (A->rk * N + M * N) * sizeof(float) );
330 *infomask |= PASTIX_LRM3_ALLOCU;
331 }
332
333 /* AB->u will be destroyed later */
334 AB->u = work;
335 tmp = work + M * N;
336
337 /*
338 * Au * (Av^t * B^t)
339 */
340 cblas_sgemm( CblasColMajor, CblasNoTrans, (CBLAS_TRANSPOSE)transB,
341 A->rk, N, K,
342 (sone), A->v, ldav,
343 B->u, ldbu,
344 (szero), tmp, A->rk );
345
346 cblas_sgemm( CblasColMajor, CblasNoTrans, CblasNoTrans,
347 M, N, A->rk,
348 (sone), A->u, ldau,
349 tmp, A->rk,
350 (szero), AB->u, M );
351
352 flops = flops1;
353 }
354 else {
355 if ( (work = core_slrmm_getws( params, M * K + M * N )) == NULL ) {
356 work = malloc( (M * K + M * N) * sizeof(float) );
357 *infomask |= PASTIX_LRM3_ALLOCU;
358 }
359
360 /* AB->u will be destroyed later */
361 AB->u = work;
362 tmp = work + M * N;
363
364 /*
365 * (Au * Av^t) * B^t
366 */
367 cblas_sgemm( CblasColMajor, CblasNoTrans, CblasNoTrans,
368 M, K, A->rk,
369 (sone), A->u, ldau,
370 A->v, ldav,
371 (szero), tmp, M );
372
373 cblas_sgemm( CblasColMajor, (CBLAS_TRANSPOSE)transA, (CBLAS_TRANSPOSE)transB,
374 M, N, K,
375 (sone), tmp, M,
376 B->u, ldbu,
377 (szero), AB->u, M );
378
379 flops = flops2;
380 }
381 }
382 else {
383 /*
384 * A->rk is the smallest rank
385 */
386 AB->rk = A->rk;
387 AB->rkmax = A->rk;
388 AB->u = A->u;
389 *infomask |= PASTIX_LRM3_ORTHOU;
390
391 if ( (work = core_slrmm_getws( params, A->rk * N )) == NULL ) {
392 work = malloc( A->rk * N * sizeof(float) );
393 *infomask |= PASTIX_LRM3_ALLOCV;
394 }
395 AB->v = work;
396
397 cblas_sgemm( CblasColMajor, CblasNoTrans, (CBLAS_TRANSPOSE)transB,
398 A->rk, N, K,
399 (sone), A->v, ldav,
400 B->u, ldbu,
401 (szero), AB->v, AB->rkmax );
402
403 flops = FLOPS_SGEMM( A->rk, N, K );
404 }
405
407 (void)infomask;
408 return flops;
409}
410
411/**
412 *******************************************************************************
413 *
414 * @brief Perform the operation AB = op(A) * op(B), with A, B, and AB low-rank.
415 *
416 *******************************************************************************
417 *
418 * @param[inout] params
419 * The LRMM structure that stores all the parameters used in the LRMM
420 * functions family.
421 * On exit, the C matrix contains the product AB aligned with its own
422 * dimensions.
423 * @sa core_slrmm_t
424 *
425 * @param[inout] AB
426 * The low-rank structure of the AB matrix in which to store the AB product.
427 *
428 * @param[inout] infomask
429 * The mask of information returned by the core_sxx2lr() functions.
430 * - If AB.u is orthogonal on exit, then PASTIX_LRM3_ORTHOU is set.
431 * - If AB.u is allocated, then PASTIX_LRM3_ALLOCU is set.
432 * - If AB.v is allocated, then PASTIX_LRM3_ALLOCV is set.
433 * - If AB.v is initialized as one of the given pointers and op(B) is not
434 * applied, then PASTIX_LRM3_TRANSB is set.
435 *
436 *******************************************************************************
437 *
438 * @return The number of flops required to perform the operation.
439 *
440 *******************************************************************************/
444 int *infomask )
445{
446 PASTE_CORE_SLRMM_PARAMS( params );
447 pastix_int_t ldau, ldav, ldbu, ldbv;
448 float *work2;
449 pastix_lrblock_t rArB;
450 pastix_fixdbl_t flops = 0.0;
451 int allocated = 0;
452
453 assert( A->rk <= A->rkmax && A->rk > 0 );
454 assert( B->rk <= B->rkmax && B->rk > 0 );
455 assert( transA == PastixNoTrans );
456 assert( transB != PastixNoTrans );
457
458 *infomask = 0;
459 ldau = (A->rk == -1) ? A->rkmax : M;
460 ldav = A->rkmax;
461 ldbu = (B->rk == -1) ? B->rkmax : N;
462 ldbv = B->rkmax;
463
464 if ( (work2 = core_slrmm_getws( params, A->rk * B->rk )) == NULL ) {
465 work2 = malloc( A->rk * B->rk * sizeof(float) );
466 allocated = 1;
467 }
468
469 /*
470 * Let's compute A * B' = Au Av^h (Bu Bv^h)' with the smallest ws
471 */
472 cblas_sgemm( CblasColMajor, CblasNoTrans, (CBLAS_TRANSPOSE)transB,
473 A->rk, B->rk, K,
474 (sone), A->v, ldav,
475 B->v, ldbv,
476 (szero), work2, A->rk );
477 flops = FLOPS_SGEMM( A->rk, B->rk, K );
478
479 /*
480 * Try to compress (Av^h Bv^h')
481 */
482 flops += lowrank->core_ge2lr( lowrank->use_reltol, lowrank->tolerance, -1, A->rk, B->rk, work2, A->rk, &rArB );
483
484 /*
485 * The rank of AB is not smaller than min(rankA, rankB)
486 */
487 if ( rArB.rk == -1 ) {
488 if ( A->rk <= B->rk ) {
489 /*
490 * ABu = Au
491 * ABv = (Av^h Bv^h') * Bu'
492 */
493 if ( (work = core_slrmm_getws( params, A->rk * N )) == NULL ) {
494 work = malloc( A->rk * N * sizeof(float) );
495 *infomask |= PASTIX_LRM3_ALLOCV;
496 }
497
498 AB->rk = A->rk;
499 AB->rkmax = A->rk;
500 AB->u = A->u;
501 AB->v = work;
502 *infomask |= PASTIX_LRM3_ORTHOU;
503
504 cblas_sgemm( CblasColMajor, CblasNoTrans, (CBLAS_TRANSPOSE)transB,
505 A->rk, N, B->rk,
506 (sone), work2, A->rk,
507 B->u, ldbu,
508 (szero), AB->v, AB->rkmax );
509 flops += FLOPS_SGEMM( A->rk, N, B->rk );
510 }
511 else {
512 /*
513 * ABu = Au * (Av^h Bv^h')
514 * ABv = Bu'
515 */
516 if ( (work = core_slrmm_getws( params, B->rk * M )) == NULL ) {
517 work = malloc( B->rk * M * sizeof(float) );
518 *infomask |= PASTIX_LRM3_ALLOCU;
519 }
520
521 AB->rk = B->rk;
522 AB->rkmax = B->rk;
523 AB->u = work;
524 AB->v = B->u;
525
526 cblas_sgemm( CblasColMajor, CblasNoTrans, CblasNoTrans,
527 M, B->rk, A->rk,
528 (sone), A->u, ldau,
529 work2, A->rk,
530 (szero), AB->u, M );
531 flops += FLOPS_SGEMM( M, B->rk, A->rk );
532
533 *infomask |= PASTIX_LRM3_TRANSB;
534 }
535 }
536 else if ( rArB.rk == 0 ) {
537 AB->rk = 0;
538 AB->rkmax = 0;
539 AB->u = NULL;
540 AB->v = NULL;
541 *infomask |= PASTIX_LRM3_ORTHOU;
542 }
543 /**
544 * The rank of AB is smaller than min(rankA, rankB)
545 */
546 else {
547 if ( (work = core_slrmm_getws( params, (M + N) * rArB.rk )) == NULL ) {
548 work = malloc( (M + N) * rArB.rk * sizeof(float) );
549 *infomask |= PASTIX_LRM3_ALLOCU;
550 }
551
552 AB->rk = rArB.rk;
553 AB->rkmax = rArB.rk;
554 AB->u = work;
555 AB->v = work + M * rArB.rk;
556 *infomask |= PASTIX_LRM3_ORTHOU;
557
558 cblas_sgemm( CblasColMajor, CblasNoTrans, CblasNoTrans,
559 M, rArB.rk, A->rk,
560 (sone), A->u, ldau,
561 rArB.u, A->rk,
562 (szero), AB->u, M );
563
564 cblas_sgemm( CblasColMajor, CblasNoTrans, (CBLAS_TRANSPOSE)transB,
565 rArB.rk, N, B->rk,
566 (sone), rArB.v, rArB.rkmax,
567 B->u, ldbu,
568 (szero), AB->v, rArB.rk );
569
570 flops += FLOPS_SGEMM( M, rArB.rk, A->rk ) + FLOPS_SGEMM( rArB.rk, N, B->rk );
571 }
572 core_slrfree(&rArB);
573
574 if ( allocated ) {
575 free( work2 );
576 }
578 return flops;
579}
BEGIN_C_DECLS typedef int pastix_int_t
Definition datatypes.h:51
double pastix_fixdbl_t
Definition datatypes.h:65
static float * core_slrmm_getws(core_slrmm_t *params, ssize_t newsize)
Function to get a workspace pointer if space is available in the one provided.
#define PASTE_CORE_SLRMM_PARAMS(_a_)
Initialize all the parameters of the core_slrmm family functions to ease the access.
pastix_fixdbl_t core_slrlr2lr(core_slrmm_t *params, pastix_lrblock_t *AB, int *infomask)
Perform the operation AB = op(A) * op(B), with A, B, and AB low-rank.
#define PASTE_CORE_SLRMM_VOID
Void all the parameters of the core_slrmm family functions to silent warnings.
pastix_fixdbl_t core_sfrlr2lr(core_slrmm_t *params, pastix_lrblock_t *AB, int *infomask, pastix_int_t Brkmin)
Perform the operation AB = op(A) * op(B), with A full-rank and B and AB low-rank.
pastix_fixdbl_t core_slrfr2lr(core_slrmm_t *params, pastix_lrblock_t *AB, int *infomask, pastix_int_t Arkmin)
Perform the operation AB = op(A) * op(B), with B full-rank and A and AB low-rank.
pastix_fixdbl_t core_sfrfr2lr(core_slrmm_t *params, pastix_lrblock_t *AB, int *infomask, pastix_int_t Kmax)
Perform the operation AB = op(A) * op(B), with A and B full-rank and AB low-rank.
Definition core_sxx2lr.c:66
Structure to store all the parameters of the core_slrmm family functions.
#define PASTIX_LRM3_ALLOCV
Macro to specify if the V part of a low-rank matrix has been allocated and needs to be freed or not (U...
#define PASTIX_LRM3_TRANSB
Macro to specify if the operator on B still needs to be applied to the V part of the low-rank ma...
#define PASTIX_LRM3_ALLOCU
Macro to specify if the U part of a low-rank matrix has been allocated and needs to be freed or not (U...
#define PASTIX_LRM3_ORTHOU
Macro to specify if the U part of a low-rank matrix is orthogonal or not (Used in LRMM functions).
The block low-rank structure to hold a matrix in low-rank form.
void core_slrfree(pastix_lrblock_t *A)
Free a low-rank matrix.
@ PastixNoTrans
Definition api.h:445