PaStiX Handbook  6.3.0
core_sscalo.c
Go to the documentation of this file.
1 /**
2  *
3  * @file core_sscalo.c
4  *
5  * PaStiX kernel routines
6  *
7  * @copyright 2010-2015 Univ. of Tennessee, Univ. of California Berkeley and
8  * Univ. of Colorado Denver. All rights reserved.
9  * @copyright 2012-2023 Bordeaux INP, CNRS (LaBRI UMR 5800), Inria,
10  * Univ. Bordeaux. All rights reserved.
11  *
12  * @version 6.3.0
13  * @author Mathieu Faverge
14  * @date 2023-01-13
15  * @generated from /builds/solverstack/pastix/kernels/core_zscalo.c, normal z -> s, Mon Aug 28 13:40:35 2023
16  *
17  **/
18 #include "common.h"
19 #include "blend/solver.h"
20 #include "pastix_scores.h"
21 #include "cblas.h"
22 #include "kernels_trace.h"
23 
24 /**
25  ******************************************************************************
26  *
27  * @brief Scale a matrix by a diagonal out of place
28  *
29  * Perform the operation: B <- op(A) * D, where A is a general matrix, and D a
30  * diagonal matrix.
31  *
32  *******************************************************************************
33  *
34  * @param[in] trans
35  * @arg PastixNoTrans: No transpose, op( A ) = A;
36  * @arg PastixTrans: Transpose, op( A ) = A;
37  * @arg PastixTrans: Conjugate Transpose (same as Transpose in real precision), op( A ) = A.
38  *
39  * @param[in] M
40  * Number of rows of the matrix B.
41  * Number of rows of the matrix A.
42  *
43  * @param[in] N
44  * Number of columns of the matrix B.
45  * Number of columns of the matrix A.
46  *
47  * @param[in] A
48  * Matrix of size lda-by-N.
49  *
50  * @param[in] lda
51  * Leading dimension of the array A. lda >= max(1,M).
52  *
53  * @param[in] D
54  * Diagonal matrix of size ldd-by-N.
55  *
56  * @param[in] ldd
57  * Leading dimension of the array D. ldd >= 1.
58  *
59  * @param[inout] B
60  * Matrix of size ldb-by-N.
61  *
62  * @param[in] ldb
63  * Leading dimension of the array B. ldb >= max(1,M)
64  *
65  *******************************************************************************
66  *
67  * @retval PASTIX_SUCCESS successful exit
68  * @retval <0 if -i, the i-th argument had an illegal value
69  * @retval 1, not yet implemented
70  *
71  ******************************************************************************/
72 int /* NOTE(review): listing line 73, holding the function name (core_sscalo) and the trans parameter, is absent from this listing */
74  pastix_int_t M,
75  pastix_int_t N,
76  const float *A,
77  pastix_int_t lda,
78  const float *D,
79  pastix_int_t ldd,
80  float *B,
81  pastix_int_t ldb )
82 {
83  float alpha;
84  pastix_int_t i, j;
85 
86 #if !defined(NDEBUG) /* Argument checks are compiled only when NDEBUG is not defined */
87  if ((trans < PastixNoTrans) ||
88  (trans > PastixTrans))
89  {
90  return -1;
91  }
92 
93  if (M < 0) {
94  return -2;
95  }
96  if (N < 0) {
97  return -3;
98  }
99  if ( lda < pastix_imax(1,M) )
100  {
101  return -5; /* lda is the 5th argument; the pointer arguments A, D and B are not checked */
102  }
103  if ( ldd < 1 )
104  {
105  return -7;
106  }
107  if ( ldb < pastix_imax(1,M) ) {
108  return -9;
109  }
110 #endif
111 
112 #if defined(PRECISION_z) || defined(PRECISION_c) /* Compiled only when one of these macros is defined; the loop is the same as the generic one below */
113  if (trans == PastixTrans) {
114  for( j=0; j<N; j++, D += ldd ) {
115  alpha = *D;
116  for( i=0; i<M; i++, B++, A++ ) {
117  *B = (*A) * alpha;
118  }
119  A += lda - M;
120  B += ldb - M;
121  }
122  }
123  else
124 #endif
125  {
126  for( j=0; j<N; j++, D += ldd ) { /* One scaling factor per column j; D advances by ldd elements per column */
127  alpha = *D;
128  for( i=0; i<M; i++, B++, A++ ) { /* Scale the M entries of column j: B(i,j) = A(i,j) * alpha */
129  *B = (*A) * alpha;
130  }
131  A += lda - M; /* M elements were consumed above, so this lands on the start of the next column of A */
132  B += ldb - M; /* Same for B, using its own leading dimension */
133  }
134  }
135 
136  (void)trans; /* trans is otherwise unused when NDEBUG is defined and the PRECISION_* branch is compiled out */
137  return PASTIX_SUCCESS;
138 }
139 
140 /**
141  *******************************************************************************
142  *
143  * @brief Copy the L term with scaling for the two-terms algorithm
144  *
145  * Performs LD = op(L) * D
146  *
147  *******************************************************************************
148  *
149  * @param[in] trans
150  * @arg PastixNoTrans: No transpose, op( L ) = L;
151  * @arg PastixTrans: Transpose, op( L ) = L;
152  * @arg PastixTrans: Conjugate Transpose (same as Transpose in real precision), op( L ) = L.
153  *
154  * @param[in] cblk
155  * Pointer to the structure representing the panel to factorize in the
156  * cblktab array. Next column block must be accessible through cblk[1].
157  *
158  * @param[inout] dataL
159  * The pointer to the correct representation of lower part of the data.
160  * - coeftab if the block is in full rank. Must be of size cblk.stride -by- cblk.width.
161  * - pastix_lr_block if the block is compressed.
162  *
163  * @param[inout] dataLD
164  * The pointer to the correct representation of LD.
165  * - coeftab if the block is in full rank. Must be of size cblk.stride -by- cblk.width.
166  * - pastix_lr_block if the block is compressed.
167  *
168  *******************************************************************************/
169 void /* NOTE(review): listing line 170, holding the function name (cpucblk_sscalo) and the trans parameter, is absent from this listing */
171  SolverCblk *cblk,
172  void *dataL,
173  void *dataLD )
174 {
175  const SolverBlok *blok, *lblk;
176  pastix_int_t M, N;
177  pastix_lrblock_t *lrL, *lrLD;
178  pastix_fixdbl_t time;
179  float *LD;
180 
181  time = kernel_trace_start( PastixKernelSCALOCblk );
182 
183  N = cblk_colnbr( cblk ); /* Number of columns in the column block */
184 
185  blok = cblk->fblokptr + 1; /* First off-diagonal block */
186  lblk = cblk[1].fblokptr; /* Next diagonal block */
187 
188  /* if there are off-diagonal supernodes in the column */
189  if ( blok < lblk )
190  {
191  const float *L;
192  const float *D;
193  pastix_int_t ldl, ldd, ldld;
194 
195  if ( cblk->cblktype & CBLK_COMPRESSED ) {
196  lrL = (pastix_lrblock_t *)dataL;
197  lrLD = (pastix_lrblock_t *)dataLD;
198  D = lrL->u; /* D is read from the first entry of dataL, before the pointers are advanced below; presumably the diagonal block - confirm */
199  ldd = N+1; /* Presumably walks the diagonal of an N-by-N array - confirm */
200 
201  lrL++; lrLD++;
202  for(; blok < lblk; blok++, lrL++, lrLD++) {
203  M = blok_rownbr( blok );
204 
205  assert( lrLD->rk == -1 );
206 
207  /* Copy L in LD */
208  lrLD->rk = lrL->rk;
209  lrLD->rkmax = lrL->rkmax;
210 
211  if ( lrL->rk == -1 ) { /* rk == -1 path: u is copied and scaled as an M-by-N array, v is unused */
212  assert( M == lrL->rkmax );
213 
214  /* Initialize the workspace */
215  memcpy( lrLD->u, lrL->u, lrL->rkmax * N * sizeof(float) );
216  lrLD->v = NULL;
217 
218  L = lrL->u;
219  LD = lrLD->u;
220  }
221  else { /* Otherwise u (M * rk values) is copied unchanged and only v is scaled */
222  /*
223  * Initialize the workspace
224  */
225  memcpy( lrLD->u, lrL->u, M * lrL->rk * sizeof(float) );
226  lrLD->v = ((float *)lrLD->u) + M * lrL->rk; /* v is placed right after u in the same buffer */
227  memcpy( lrLD->v, lrL->v, N * lrL->rkmax * sizeof(float) );
228 
229  L = lrL->v;
230  LD = lrLD->v;
231  M = lrLD->rkmax; /* From here on M is the row count of v used for the scaling call */
232  }
233 
234  ldl = M;
235  ldld = M;
236 
237  /* Compute LD = L * D */
238  core_sscalo( trans, M, N,
239  L, ldl, D, ldd,
240  LD, ldld );
241  }
242  }
243  else if ( cblk->cblktype & CBLK_LAYOUT_2D ) { /* Each block is addressed through its coefind offset, with leading dimension M */
244  L = D = (float *)dataL; /* D starts at the beginning of dataL; presumably the diagonal block - confirm */
245  LD = (float *)dataLD;
246  ldd = N+1; /* Presumably walks the diagonal of an N-by-N array - confirm */
247 
248  for(; blok < lblk; blok++) {
249  M = blok_rownbr( blok );
250 
251  /* Compute LD = L * D */
252  core_sscalo( trans, M, N,
253  L + blok->coefind, M, D, ldd,
254  LD + blok->coefind, M );
255  }
256  }
257  else { /* Remaining layout: a single call covers all rows below the diagonal block */
258  L = D = (float *)dataL;
259  LD = (float *)dataLD;
260  ldl = cblk->stride;
261  ldd = cblk->stride+1; /* Presumably walks the diagonal of an array with leading dimension stride - confirm */
262 
263  M = cblk->stride - N; /* Rows left once the N rows of the diagonal block are excluded */
264  LD = LD + blok->coefind;
265  ldld = cblk->stride;
266 
267  core_sscalo( trans, M, N, L + blok->coefind, ldl, D, ldd, LD, ldld );
268  }
269  }
270 
271  M = cblk->stride - N; /* Dimensions reported to the trace, whichever path ran above */
272  kernel_trace_stop( cblk->fblokptr->inlast, PastixKernelSCALOCblk, M, N, 0, (pastix_fixdbl_t)(M*N), time );
273 }
274 
275 /**
276  *******************************************************************************
277  *
278  * @brief Copy the lower terms of the block with scaling for the two-terms
279  * algorithm.
280  *
281  * Performs B = op(A) * D
282  *
283  *******************************************************************************
284  *
285  * @param[in] trans
286  * @arg PastixNoTrans: No transpose, op( A ) = A;
287  * @arg PastixTrans: Transpose, op( A ) = A;
288  * @arg PastixTrans: Conjugate Transpose (same as Transpose in real precision), op( A ) = A.
289  *
290  * @param[in] cblk
291  * Pointer to the structure representing the panel to factorize in the
292  * cblktab array. Next column block must be accessible through cblk[1].
293  *
294  * @param[in] blok_m
295  * Index of the first off-diagonal block to be scaled in the cblk. All blocks
296  * facing the same cblk, in the current column block will be scaled.
297  *
298  * @param[in] dataA
299  * The pointer to the correct representation of data of A.
300  * - coeftab if the block is in full rank. Must be of size cblk.stride -by- cblk.width.
301  * - pastix_lr_block if the block is compressed.
302  *
303  * @param[in] dataD
304  * The pointer to the correct representation of data of D.
305  * - coeftab if the block is in full rank. Must be of size cblk.stride -by- cblk.width.
306  * - pastix_lr_block if the block is compressed.
307  *
308  * @param[inout] dataB
309  * The pointer to the correct representation of data of B.
310  * - coeftab if the block is in full rank. Must be of size cblk.stride -by- cblk.width.
311  * - pastix_lr_block if the block is compressed.
312  *
313  *******************************************************************************/
314 void /* NOTE(review): listing line 315, holding the function name (cpublok_sscalo) and the trans parameter, is absent from this listing */
316  SolverCblk *cblk,
317  pastix_int_t blok_m,
318  const void *dataA,
319  const void *dataD,
320  void *dataB )
321 {
322  const SolverBlok *fblok, *lblok, *blok;
323  pastix_int_t M, N, ldd, offset, cblk_m;
324  const float *lA;
325  pastix_lrblock_t *lrD, *lrB, *lrA;
326  float *D, *B, *A;
327  float *lB;
328 
329  N = cblk_colnbr( cblk );
330  fblok = cblk[0].fblokptr; /* The diagonal block */
331  lblok = cblk[1].fblokptr; /* The diagonal block of the next cblk */
332  ldd = blok_rownbr( fblok ) + 1; /* Equals N+1 given the assert below; presumably walks the diagonal of D - confirm */
333 
334  assert( blok_rownbr(fblok) == N );
335  assert( cblk->cblktype & CBLK_LAYOUT_2D );
336 
337  blok = fblok + blok_m; /* First block to process */
338  offset = blok->coefind; /* Used in the full-rank path to rebase A and B on this first block */
339  cblk_m = blok->fcblknm; /* Only consecutive blocks sharing this fcblknm are processed */
340 
341  if ( cblk->cblktype & CBLK_COMPRESSED ) {
342  lrA = (pastix_lrblock_t *)dataA; /* Not offset by blok_m: dataA and dataB are taken to start at the first processed block */
343  lrD = (pastix_lrblock_t *)dataD;
344  lrB = (pastix_lrblock_t *)dataB;
345  D = lrD->u;
346  for (; (blok < lblok) && (blok->fcblknm == cblk_m); blok++, lrA++, lrB++) {
347  M = blok_rownbr( blok );
348 
349  /* Copy A in B */
350  lrB->rk = lrA->rk;
351  lrB->rkmax = lrA->rkmax;
352 
353  if ( lrB->rk == -1 ) { /* rk == -1 path: u is copied and scaled as an M-by-N array, v is unused */
354  assert( M == lrA->rkmax );
355  assert( NULL == lrA->v );
356 
357  /* Initialize the workspace */
358  memcpy( lrB->u, lrA->u, lrA->rkmax * N * sizeof(float) );
359  lrB->v = NULL;
360 
361  lA = lrA->u;
362  lB = lrB->u;
363  }
364  else { /* Otherwise u (M * rk values) is copied unchanged and only v is scaled */
365  /*
366  * Initialize the workspace
367  */
368  memcpy( lrB->u, lrA->u, M * lrA->rk * sizeof(float) );
369  lrB->v = ((float *)lrB->u) + M * lrA->rk; /* v is placed right after u in the same buffer */
370  memcpy( lrB->v, lrA->v, N * lrA->rkmax * sizeof(float) );
371 
372  lA = lrA->v;
373  lB = lrB->v;
374  M = lrA->rkmax; /* From here on M is the row count of v used for the scaling call */
375  }
376 
377  /* Compute B = op(A) * D */
378  core_sscalo( trans, M, N,
379  lA, M, D, ldd, lB, M );
380  }
381  }
382  else {
383  A = (float *)dataA;
384  D = (float *)dataD;
385  B = (float *)dataB;
386 
387  for (; (blok < lblok) && (blok->fcblknm == cblk_m); blok++) {
388  lA = A + blok->coefind - offset; /* Position relative to the first processed block, where A is taken to start */
389  lB = B + blok->coefind - offset;
390  M = blok_rownbr(blok);
391 
392  /* Compute B = op(A) * D */
393  core_sscalo( trans, M, N,
394  lA, M, D, ldd, lB, M );
395  }
396  }
397 }
int core_sscalo(pastix_trans_t trans, pastix_int_t M, pastix_int_t N, const float *A, pastix_int_t lda, const float *D, pastix_int_t ldd, float *B, pastix_int_t ldb)
Scale a matrix by a diagonal out of place.
Definition: core_sscalo.c:73
void cpucblk_sscalo(pastix_trans_t trans, SolverCblk *cblk, void *dataL, void *dataLD)
Copy the L term with scaling for the two-terms algorithm.
Definition: core_sscalo.c:170
void cpublok_sscalo(pastix_trans_t trans, SolverCblk *cblk, pastix_int_t blok_m, const void *dataA, const void *dataD, void *dataB)
Copy the lower terms of the block with scaling for the two-terms algorithm.
Definition: core_sscalo.c:315
The block low-rank structure to hold a matrix in low-rank form.
enum pastix_trans_e pastix_trans_t
Transposition.
@ PastixNoTrans
Definition: api.h:447
@ PastixTrans
Definition: api.h:448
@ PASTIX_SUCCESS
Definition: api.h:369
static pastix_int_t blok_rownbr(const SolverBlok *blok)
Compute the number of rows of a block.
Definition: solver.h:389
static pastix_int_t cblk_colnbr(const SolverCblk *cblk)
Compute the number of columns in a column block.
Definition: solver.h:323
pastix_int_t fcblknm
Definition: solver.h:140
pastix_int_t coefind
Definition: solver.h:144
SolverBlok * fblokptr
Definition: solver.h:163
int8_t inlast
Definition: solver.h:146
pastix_int_t stride
Definition: solver.h:164
int8_t cblktype
Definition: solver.h:159
Solver block structure.
Definition: solver.h:137
Solver column block structure.
Definition: solver.h:156