PaStiX Handbook  6.4.0
solve_ctrsmsp.c
Go to the documentation of this file.
1 /**
2  *
3  * @file solve_ctrsmsp.c
4  *
5  * PaStiX solve kernels routines
6  *
7  * @copyright 2012-2024 Bordeaux INP, CNRS (LaBRI UMR 5800), Inria,
8  * Univ. Bordeaux. All rights reserved.
9  *
10  * @version 6.4.0
11  * @author Mathieu Faverge
12  * @author Pierre Ramet
13  * @author Xavier Lacoste
14  * @author Tony Delarue
15  * @author Vincent Bridonneau
16  * @author Alycia Lisito
17  * @author Nolan Bredel
18  * @date 2024-07-05
19  * @generated from /builds/solverstack/pastix/kernels/solve_ztrsmsp.c, normal z -> c, Tue Oct 8 14:17:24 2024
20  *
21  **/
22 #include "common.h"
23 #include "cblas.h"
24 #include "blend/solver.h"
25 #include "kernels_trace.h"
26 #include "pastix_ccores.h"
27 #include "pastix_clrcores.h"
28 
29 #ifndef DOXYGEN_SHOULD_SKIP_THIS
30 static pastix_complex32_t czero = 0.0;
31 static pastix_complex32_t cone = 1.0;
32 static pastix_complex32_t mcone = -1.0;
33 #endif /* DOXYGEN_SHOULD_SKIP_THIS */
34 
35 /**
36  *******************************************************************************
37  *
38  * @brief Apply a solve trsm update related to a diagonal block of the matrix A.
39  *
40  *******************************************************************************
41  *
42  * @param[in] side
43  * Specify the side parameter of the TRSM.
44  *
45  * @param[in] uplo
46  * Specify the uplo parameter of the TRSM.
47  *
48  * @param[in] trans
49  * Specify the transposition used for the matrix A in the
50  * computation. It has to be either PastixTrans or PastixConjTrans.
51  *
52  * @param[in] diag
53  * Specify if the off-diagonal blocks are unit triangular. It has to be
54  * either PastixUnit or PastixNonUnit.
55  *
56  * @param[in] cblk
57  * The cblk structure that corresponds to the A and B matrix.
58  *
59  * @param[in] nrhs
60  * The number of right hand side.
61  *
62  * @param[in] dataA
63  * The pointer to the correct representation of the data of A.
64  * - coeftab if the block is in full rank. Must be of size cblk.stride -by- cblk.width.
65  * - pastix_lr_block if the block is compressed.
66  *
67  * @param[inout] b
68  * The pointer to the matrix B, that is a portion of the right hand
69  * side to solve.
70  *
71  * @param[in] ldb
72  * The leading dimension of B.
73  *
74  *******************************************************************************/
75 void
77  pastix_uplo_t uplo,
78  pastix_trans_t trans,
79  pastix_diag_t diag,
80  const SolverCblk *cblk,
81  int nrhs,
82  const void *dataA,
84  int ldb )
85 {
86  const pastix_complex32_t *A;
87  pastix_int_t n;
88  pastix_int_t lda;
89 
90  n = cblk_colnbr( cblk );
91 
92  if ( cblk->cblktype & CBLK_COMPRESSED ) {
93  const pastix_lrblock_t *lrA = (const pastix_lrblock_t *)dataA;
94  assert( lrA->rk == -1 );
95  A = lrA->u;
96  lda = n;
97  }
98  else {
99  A = (const pastix_complex32_t *)dataA;
100  lda = (cblk->cblktype & CBLK_LAYOUT_2D) ? n : cblk->stride;
101  }
102 
103  cblas_ctrsm(
104  CblasColMajor, (CBLAS_SIDE)side, (CBLAS_UPLO)uplo,
105  (CBLAS_TRANSPOSE)trans, (CBLAS_DIAG)diag,
106  n, nrhs,
107  CBLAS_SADDR(cone), A, lda,
108  b, ldb );
109 }
110 
111 /**
112  *******************************************************************************
113  *
114  * @brief Apply a solve gemm update related to a single block of the matrix A.
115  *
116  *******************************************************************************
117  *
118  * @param[in] side
119  * Specify whether the blok parameter belongs to cblk (PastixLeft), or
120  * to fcbk (PastixRight).
121  *
122  * @param[in] trans
123  * Specify the transposition used for the matrix A in the
124  * computation. It has to be either PastixTrans or PastixConjTrans.
125  *
126  * @param[in] nrhs
127  * The number of right hand side.
128  *
129  * @param[in] cblk
130  * The cblk structure that corresponds to the B matrix.
131  *
132  * @param[in] blok
133  * The blok structure that corresponds to the A matrix, and that
134  * belongs either to cblk or fcbk depending on the side parameter.
135  *
136  * @param[inout] fcbk
137  * The cblk structure that corresponds to the C matrix.
138  *
139  * @param[in] dataA
140  * The pointer to the correct representation of the data of A.
141  * - coeftab if the block is in full rank. Must be of size cblk.stride -by- cblk.width.
142  * - pastix_lr_block if the block is compressed.
143  *
144  * @param[in] B
145  * The pointer to the matrix B, that is a portion of the right hand
146  * side.
147  *
148  * @param[in] ldb
149  * The leading dimension of B.
150  *
151  * @param[inout] C
152  * The pointer to the matrix C, that is the updated portion of the
153  * right hand side.
154  *
155  * @param[in] ldc
156  * The leading dimension of C.
157  *
158  *******************************************************************************/
159 void
161  pastix_trans_t trans,
162  pastix_int_t nrhs,
163  const SolverCblk *cblk,
164  const SolverBlok *blok,
165  SolverCblk *fcbk,
166  const void *dataA,
167  const pastix_complex32_t *B,
168  pastix_int_t ldb,
170  pastix_int_t ldc )
171 {
172  pastix_int_t m, n, lda;
173  pastix_int_t offB, offC;
174  const SolverCblk *bowner;
175 
176  if ( side == PastixLeft ) {
177  /*
178  * Blok should belong to cblk
179  */
180  bowner = cblk;
181 
182  m = blok_rownbr( blok );
183  n = cblk_colnbr( cblk );
184  lda = m;
185 
186  offB = 0;
187  offC = blok->frownum - fcbk->fcolnum;
188  }
189  else {
190  /*
191  * Blok should belong to fcbk
192  */
193  bowner = fcbk;
194 
195  m = cblk_colnbr( fcbk );
196  n = blok_rownbr( blok );
197  lda = n;
198 
199  offB = blok->frownum - cblk->fcolnum;
200  offC = 0;
201  }
202 
203  assert( (blok > bowner[0].fblokptr) &&
204  (blok < bowner[1].fblokptr) );
205 
206  if ( bowner->cblktype & CBLK_COMPRESSED ) {
207  const pastix_lrblock_t *lrA = dataA;
208  pastix_complex32_t *tmp;
209 
210  switch (lrA->rk){
211  case 0:
212  break;
213  case -1:
214  pastix_cblk_lock( fcbk );
215  cblas_cgemm(
216  CblasColMajor, (CBLAS_TRANSPOSE)trans, CblasNoTrans,
217  m, nrhs, n,
218  CBLAS_SADDR(mcone), lrA->u, lda,
219  B + offB, ldb,
220  CBLAS_SADDR(cone), C + offC, ldc );
221  pastix_cblk_unlock( fcbk );
222  break;
223  default:
224  MALLOC_INTERN( tmp, lrA->rk * nrhs, pastix_complex32_t);
225  if (trans == PastixNoTrans) {
226  cblas_cgemm(
227  CblasColMajor, (CBLAS_TRANSPOSE)trans, CblasNoTrans,
228  lrA->rk, nrhs, n,
229  CBLAS_SADDR(cone), lrA->v, lrA->rkmax,
230  B + offB, ldb,
231  CBLAS_SADDR(czero), tmp, lrA->rk );
232 
233  pastix_cblk_lock( fcbk );
234  cblas_cgemm(
235  CblasColMajor, (CBLAS_TRANSPOSE)trans, CblasNoTrans,
236  m, nrhs, lrA->rk,
237  CBLAS_SADDR(mcone), lrA->u, lda,
238  tmp, lrA->rk,
239  CBLAS_SADDR(cone), C + offC, ldc );
240  pastix_cblk_unlock( fcbk );
241  }
242  else {
243  cblas_cgemm(
244  CblasColMajor, (CBLAS_TRANSPOSE)trans, CblasNoTrans,
245  lrA->rk, nrhs, n,
246  CBLAS_SADDR(cone), lrA->u, lda,
247  B + offB, ldb,
248  CBLAS_SADDR(czero), tmp, lrA->rk );
249 
250  pastix_cblk_lock( fcbk );
251  cblas_cgemm(
252  CblasColMajor, (CBLAS_TRANSPOSE)trans, CblasNoTrans,
253  m, nrhs, lrA->rk,
254  CBLAS_SADDR(mcone), lrA->v, lrA->rkmax,
255  tmp, lrA->rk,
256  CBLAS_SADDR(cone), C + offC, ldc );
257  pastix_cblk_unlock( fcbk );
258  }
259  memFree_null(tmp);
260  break;
261  }
262  }
263  else{
264  const pastix_complex32_t *A = dataA;
265  lda = (bowner->cblktype & CBLK_LAYOUT_2D) ? lda : bowner->stride;
266 
267  pastix_cblk_lock( fcbk );
268  cblas_cgemm(
269  CblasColMajor, (CBLAS_TRANSPOSE)trans, CblasNoTrans,
270  m, nrhs, n,
271  CBLAS_SADDR(mcone), A, lda,
272  B + offB, ldb,
273  CBLAS_SADDR(cone), C + offC, ldc );
274  pastix_cblk_unlock( fcbk );
275  }
276 }
277 
278 /**
279  *******************************************************************************
280  *
281  * @brief Apply a forward solve related to one cblk to all the right hand side.
282  *
283  *******************************************************************************
284  *
285  * @param[in] enums
286  * Enums needed for the solve.
287  *
288  * @param[in] datacode
289  * The SolverMatrix structure from PaStiX.
290  *
291  * @param[in] cblk
292  * The cblk structure to which block belongs to. The A and B pointers
293  * must be the coeftab of this column block.
294  * Next column blok must be accessible through cblk[1].
295  *
296  * @param[inout] rhsb
297  * The pointer to the rhs data structure that holds the vectors of the
298  * right hand side.
299  *
300  *******************************************************************************/
301 void
303  SolverMatrix *datacode,
304  const SolverCblk *cblk,
305  pastix_rhs_t rhsb )
306 {
307  SolverCblk *fcbk;
308  const SolverBlok *blok;
309  pastix_trans_t tA;
311  const void *dataA = NULL;
312  const pastix_lrblock_t *lrA;
313  const pastix_complex32_t *A;
314  pastix_complex32_t *B, *C;
315  pastix_int_t ldb, ldc, k;
316  pastix_fixdbl_t time;
317  pastix_fixdbl_t flops_lvl1 = 0;
318  pastix_fixdbl_t flops_lvl2 = 0;
319  pastix_side_t side = enums->side;
320  pastix_uplo_t uplo = enums->uplo;
321  pastix_trans_t trans = enums->trans;
322  pastix_diag_t diag = enums->diag;
323  pastix_solv_mode_t mode = enums->mode;
324 
326 
327  if ( (side == PastixRight) && (uplo == PastixUpper) && (trans == PastixNoTrans) ) {
328  /* We store U^t, so we swap uplo and trans */
329  tA = PastixTrans;
330  cs = PastixUCoef;
331 
332  /* Right is not handled yet */
333  assert( 0 );
334  }
335  else if ( (side == PastixRight) && (uplo == PastixLower) && (trans != PastixNoTrans) ) {
336  tA = trans;
337  cs = PastixLCoef;
338 
339  /* Right is not handled yet */
340  assert( 0 );
341  }
342  else if ( (side == PastixLeft) && (uplo == PastixUpper) && (trans != PastixNoTrans) ) {
343  /* We store U^t, so we swap uplo and trans */
344  tA = PastixNoTrans;
345  cs = PastixUCoef;
346 
347  /* We do not handle conjtrans in complex as we store U^t */
348 #if defined(PRECISION_z) || defined(PRECISION_c)
349  assert( trans != PastixConjTrans );
350 #endif
351  }
352  else if ( (side == PastixLeft) && (uplo == PastixLower) && (trans == PastixNoTrans) ) {
353  tA = trans;
354  cs = PastixLCoef;
355  }
356  else {
357  /* This correspond to case treated in backward trsm */
358  assert(0);
359  return;
360  }
361 
362  assert( !( cblk->cblktype & (CBLK_FANIN|CBLK_RECV) ) );
363 
364  if ( (cblk->cblktype & CBLK_IN_SCHUR) && (mode != PastixSolvModeSchur) ) {
365  return;
366  }
367 
368  B = rhsb->b;
369  B = B + cblk->lcolidx;
370  ldb = rhsb->ld;
371  k = cblk_colnbr( cblk );
372 
373  /* Solve the diagonal block */
374  flops_lvl2 = FLOPS_CTRSM( side, k, rhsb->n );
375  kernel_trace_start_lvl2( PastixKernelLvl2_FR_TRSM );
377  tA, diag, cblk, rhsb->n,
378  cblk_getdata( cblk, cs ),
379  B, ldb );
380  kernel_trace_stop_lvl2( flops_lvl2 );
381  flops_lvl1 += flops_lvl2;
382 
383  /* Apply the update */
384  for (blok = cblk[0].fblokptr+1; blok < cblk[1].fblokptr; blok++ ) {
385  fcbk = datacode->cblktab + blok->fcblknm;
386 
387  if ( (fcbk->cblktype & CBLK_IN_SCHUR) && (mode == PastixSolvModeLocal) ) {
388  return;
389  }
390  assert( !(fcbk->cblktype & CBLK_RECV) );
391 
392  /*
393  * Make sure we get the correct pointer to the lrA, or to the right position in [lu]coeftab
394  */
395  dataA = cblk_getdata( cblk, cs );
396  if ( cblk->cblktype & CBLK_COMPRESSED ) {
397  lrA = dataA;
398  lrA += (blok - cblk->fblokptr);
399  dataA = lrA;
400  }
401  else {
402  A = dataA;
403  A += blok->coefind;
404  dataA = A;
405  }
406 
407  /*
408  * Make sure we get the correct pointer for the C matrix.
409  */
410  if ( fcbk->cblktype & CBLK_FANIN ) {
411  C = rhsb->cblkb[ - fcbk->bcscnum - 1 ];
412  ldc = cblk_colnbr( fcbk );
413  if ( C == NULL ) {
414  C = calloc( ldc * rhsb->n, sizeof( pastix_complex32_t ) );
415  if ( !pastix_atomic_cas_xxb( &(rhsb->cblkb[ - fcbk->bcscnum - 1 ]),
416  (uint64_t)NULL, (uint64_t)C, sizeof(void*) ) )
417  {
418  free( C );
419  C = rhsb->cblkb[ - fcbk->bcscnum - 1 ];
420  }
421  }
422  }
423  else {
424  C = rhsb->b;
425  C = C + fcbk->lcolidx;
426  ldc = rhsb->ld;
427  }
428 
429  flops_lvl2 = FLOPS_CGEMM( blok_rownbr( blok ), rhsb->n, k );
430  kernel_trace_start_lvl2( PastixKernelLvl2_FR_GEMM );
431  solve_blok_cgemm( PastixLeft, tA, rhsb->n,
432  cblk, blok, fcbk,
433  dataA, B, ldb, C, ldc );
434  kernel_trace_stop_lvl2( flops_lvl2 );
435  flops_lvl1 += flops_lvl2;
436 
437  cpucblk_crelease_rhs_fwd_deps( enums, datacode,
438  rhsb, cblk, fcbk );
439  }
441  cblk_rownbr(cblk), rhsb->n, k, flops_lvl1, time );
442 }
443 
444 /**
445  *******************************************************************************
446  *
447  * @brief Apply a backward solve related to one cblk to all the right hand side.
448  *
449  *******************************************************************************
450  *
451  * @param[in] enums
452  * Enums needed for the solve.
453  *
454  * @param[in] datacode
455  * The SolverMatrix structure from PaStiX.
456  *
457  * @param[in] cblk
458  * The cblk structure to which block belongs to. The A and B pointers
459  * must be the coeftab of this column block.
460  * Next column blok must be accessible through cblk[1].
461  *
462  * @param[inout] rhsb
463  * The pointer to the rhs data structure that holds the vectors of the
464  * right hand side.
465  *
466  *******************************************************************************/
467 void
469  SolverMatrix *datacode,
470  SolverCblk *cblk,
471  pastix_rhs_t rhsb )
472 {
473  SolverCblk *fcbk;
474  const SolverBlok *blok;
475  pastix_int_t j;
476  pastix_trans_t tA;
478  const void *dataA = NULL;
479  const pastix_lrblock_t *lrA;
480  const pastix_complex32_t *A;
481  pastix_complex32_t *B, *C;
482  pastix_int_t ldb, ldc, k;
483  pastix_fixdbl_t time;
484  pastix_fixdbl_t flops_lvl1 = 0;
485  pastix_fixdbl_t flops_lvl2 = 0;
486  pastix_side_t side = enums->side;
487  pastix_uplo_t uplo = enums->uplo;
488  pastix_trans_t trans = enums->trans;
489  pastix_diag_t diag = enums->diag;
490  pastix_solv_mode_t mode = enums->mode;
491 
493  /*
494  * Left / Upper / NoTrans (Backward)
495  */
496  if ( (side == PastixLeft) && (uplo == PastixUpper) && (trans == PastixNoTrans) ) {
497  /* We store U^t, so we swap uplo and trans */
498  tA = PastixTrans;
499  cs = PastixUCoef;
500  }
501  else if ( (side == PastixLeft) && (uplo == PastixLower) && (trans != PastixNoTrans) ) {
502  tA = trans;
503  cs = PastixLCoef;
504  }
505  else if ( (side == PastixRight) && (uplo == PastixUpper) && (trans != PastixNoTrans) ) {
506  /* We store U^t, so we swap uplo and trans */
507  tA = PastixNoTrans;
508  cs = PastixUCoef;
509 
510  /* Right is not handled yet */
511  assert( 0 );
512 
513  /* We do not handle conjtrans in complex as we store U^t */
514  assert( trans != PastixConjTrans );
515  }
516  else if ( (side == PastixRight) && (uplo == PastixLower) && (trans == PastixNoTrans) ) {
517  tA = trans;
518  cs = PastixLCoef;
519 
520  /* Right is not handled yet */
521  assert( 0 );
522  }
523  else {
524  /* This correspond to case treated in forward trsm */
525  assert(0);
526  return;
527  }
528 
529  /*
530  * If cblk is in the schur complement, all brow blocks are in
531  * the interface. Thus, it doesn't generate any update in local
532  * mode, and we know that we are at least in interface mode
533  * after this test.
534  */
535  if ( (cblk->cblktype & CBLK_IN_SCHUR) && (mode == PastixSolvModeLocal) ) {
536  for (j = cblk[0].brownum; j < cblk[1].brownum; j++ ) {
537  blok = datacode->bloktab + datacode->browtab[j];
538  fcbk = datacode->cblktab + blok->lcblknm;
539 
540  if ( fcbk->cblktype & CBLK_IN_SCHUR ) {
541  break;
542  }
543  cpucblk_crelease_rhs_bwd_deps( enums, datacode,
544  rhsb, cblk, fcbk );
545  }
546  return;
547  }
548 
549  /*
550  * Make sure we get the correct pointer for the B matrix.
551  */
552  assert( !(cblk->cblktype & CBLK_RECV) );
553  if ( cblk->cblktype & CBLK_FANIN ) {
554  B = rhsb->cblkb[ - cblk->bcscnum - 1 ];
555  ldb = cblk_colnbr( cblk );
556  }
557  else {
558  B = rhsb->b;
559  B = B + cblk->lcolidx;
560  ldb = rhsb->ld;
561  }
562  k = cblk_colnbr( cblk );
563 
564  if ( !(cblk->cblktype & (CBLK_FANIN|CBLK_RECV) ) &&
565  (!(cblk->cblktype & CBLK_IN_SCHUR) || (mode == PastixSolvModeSchur)) )
566  {
567  /* Solve the diagonal block */
568  flops_lvl2 = FLOPS_CTRSM( side, k, rhsb->n );
569  kernel_trace_start_lvl2( PastixKernelLvl2_FR_TRSM );
570  solve_blok_ctrsm( side, PastixLower, tA, diag,
571  cblk, rhsb->n,
572  cblk_getdata( cblk, cs ),
573  B, ldb );
574  kernel_trace_stop_lvl2( flops_lvl2 );
575  flops_lvl1 += flops_lvl2;
576  }
577 
578  /* Apply the update */
579  for (j = cblk[1].brownum-1; j>=cblk[0].brownum; j-- ) {
580  blok = datacode->bloktab + datacode->browtab[j];
581  fcbk = datacode->cblktab + blok->lcblknm;
582 
583  if ( (fcbk->cblktype & CBLK_IN_SCHUR) && (mode == PastixSolvModeInterface) ) {
584  continue;
585  }
586 
587  if ( fcbk->cblktype & CBLK_RECV ) {
588 #if defined( PASTIX_WITH_MPI )
589  /* If PastixSchedSequential, then the communications are done syncrhonously */
590  if( enums->sched != PastixSchedSequential ) {
591  assert( datacode->reqtab != NULL );
592  cpucblk_cisend_rhs_bwd( datacode, rhsb, fcbk );
593  }
594 #endif
595  continue;
596  }
597  assert( !(fcbk->cblktype & CBLK_FANIN) );
598 
599  /*
600  * Make sure we get the correct pointer to the lrA, or to the right position in [lu]coeftab
601  */
602  dataA = cblk_getdata( fcbk, cs );
603  if ( fcbk->cblktype & CBLK_COMPRESSED ) {
604  lrA = dataA;
605  lrA += (blok - fcbk->fblokptr);
606  dataA = lrA;
607  }
608  else {
609  A = dataA;
610  A += blok->coefind;
611  dataA = A;
612  }
613 
614  /*
615  * Make sure we get the correct pointer for the C matrix.
616  */
617  C = rhsb->b;
618  C = C + fcbk->lcolidx;
619  ldc = rhsb->ld;
620 
621  flops_lvl2 = FLOPS_CGEMM( blok_rownbr( blok ), rhsb->n, k );
622  kernel_trace_start_lvl2( PastixKernelLvl2_FR_GEMM );
623  solve_blok_cgemm( PastixRight, tA, rhsb->n,
624  cblk, blok, fcbk,
625  dataA, B, ldb, C, ldc );
626  kernel_trace_stop_lvl2( flops_lvl2 );
627  flops_lvl1 += flops_lvl2;
628 
629  cpucblk_crelease_rhs_bwd_deps( enums, datacode,
630  rhsb, cblk, fcbk );
631  }
632 
633  if ( cblk->cblktype & CBLK_FANIN ) {
634  memFree_null( rhsb->cblkb[ - cblk->bcscnum - 1 ] );
635  }
636  kernel_trace_stop( cblk->fblokptr->inlast, PastixKernelTRSMBack, cblk_rownbr( cblk ), rhsb->n, k, flops_lvl1, time );
637 }
638 
639 /**
640  *******************************************************************************
641  *
642  * @brief Apply the diagonal solve related to one cblk to all the right hand side.
643  *
644  *******************************************************************************
645  *
646  * @param[in] cblk
647  * The cblk structure to which diagonal block belongs to.
648  *
649  * @param[in] nrhs
650  * The number of right hand side
651  *
652  * @param[inout] b
653  * The pointer to vectors of the right hand side
654  *
655  * @param[in] ldb
656  * The leading dimension of b
657  *
658  * @param[inout] work
659  * Workspace to temporarily store the diagonal when multiple RHS are
660  * involved. Might be set to NULL for internal allocation on need.
661  *
662  *******************************************************************************/
663 void
665  const void *dataA,
666  int nrhs,
668  int ldb,
669  pastix_complex32_t *work )
670 {
671  const pastix_complex32_t *A;
672  pastix_complex32_t *tmp;
673  pastix_int_t k, j, tempn, lda;
674 
675  tempn = cblk->lcolnum - cblk->fcolnum + 1;
676  lda = (cblk->cblktype & CBLK_LAYOUT_2D) ? tempn : cblk->stride;
677  assert( blok_rownbr( cblk->fblokptr ) == tempn );
678 
679  if ( cblk->cblktype & CBLK_COMPRESSED ) {
680  const pastix_lrblock_t *lrA = (const pastix_lrblock_t*)dataA;
681  A = lrA->u;
682  assert( lrA->rkmax == lda );
683  }
684  else {
685  A = (const pastix_complex32_t*)dataA;
686  }
687 
688  /* Add shift for diagonal elements */
689  lda++;
690 
691  if( nrhs == 1 ) {
692  for (j=0; j<tempn; j++, b++, A+=lda) {
693  *b = (*b) / (*A);
694  }
695  }
696  else {
697  /* Copy the diagonal to a temporary buffer */
698  tmp = work;
699  if ( work == NULL ) {
700  MALLOC_INTERN( tmp, tempn, pastix_complex32_t );
701  }
702  cblas_ccopy( tempn, A, lda, tmp, 1 );
703 
704  /* Compute */
705  for (k=0; k<nrhs; k++, b+=ldb)
706  {
707  for (j=0; j<tempn; j++) {
708  b[j] /= tmp[j];
709  }
710  }
711 
712  if ( work == NULL ) {
713  memFree_null(tmp);
714  }
715  }
716 }
BEGIN_C_DECLS typedef int pastix_int_t
Definition: datatypes.h:51
float _Complex pastix_complex32_t
Definition: datatypes.h:76
double pastix_fixdbl_t
Definition: datatypes.h:65
static void kernel_trace_stop(int8_t inlast, pastix_ktype_t ktype, int m, int n, int k, double flops, double starttime)
Stop the trace of a single kernel.
static double kernel_trace_start(pastix_ktype_t ktype)
Start the trace of a single kernel.
Definition: kernels_trace.h:87
@ PastixKernelTRSMBack
Definition: kernels_enums.h:56
@ PastixKernelTRSMForw
Definition: kernels_enums.h:55
void cpucblk_crelease_rhs_fwd_deps(const args_solve_t *enums, SolverMatrix *solvmtx, pastix_rhs_t rhsb, const SolverCblk *cblk, SolverCblk *fcbk)
Release the dependencies of the given cblk after an update.
void cpucblk_crelease_rhs_bwd_deps(const args_solve_t *enums, SolverMatrix *solvmtx, pastix_rhs_t rhsb, const SolverCblk *cblk, SolverCblk *fcbk)
Release the dependencies of the given cblk after an update.
The block low-rank structure to hold a matrix in low-rank form.
void solve_cblk_ctrsmsp_backward(const args_solve_t *enums, SolverMatrix *datacode, SolverCblk *cblk, pastix_rhs_t rhsb)
Apply a backward solve related to one cblk to all the right hand side.
void solve_cblk_cdiag(const SolverCblk *cblk, const void *dataA, int nrhs, pastix_complex32_t *b, int ldb, pastix_complex32_t *work)
Apply the diagonal solve related to one cblk to all the right hand side.
void solve_blok_ctrsm(pastix_side_t side, pastix_uplo_t uplo, pastix_trans_t trans, pastix_diag_t diag, const SolverCblk *cblk, int nrhs, const void *dataA, pastix_complex32_t *b, int ldb)
Apply a solve trsm update related to a diagonal block of the matrix A.
Definition: solve_ctrsmsp.c:76
void solve_cblk_ctrsmsp_forward(const args_solve_t *enums, SolverMatrix *datacode, const SolverCblk *cblk, pastix_rhs_t rhsb)
Apply a forward solve related to one cblk to all the right hand side.
void solve_blok_cgemm(pastix_side_t side, pastix_trans_t trans, pastix_int_t nrhs, const SolverCblk *cblk, const SolverBlok *blok, SolverCblk *fcbk, const void *dataA, const pastix_complex32_t *B, pastix_int_t ldb, pastix_complex32_t *C, pastix_int_t ldc)
Apply a solve gemm update related to a single block of the matrix A.
enum pastix_diag_e pastix_diag_t
Diagonal.
enum pastix_solv_mode_e pastix_solv_mode_t
Solve Schur modes.
enum pastix_uplo_e pastix_uplo_t
Upper/Lower part.
enum pastix_side_e pastix_side_t
Side of the operation.
enum pastix_trans_e pastix_trans_t
Transposition.
enum pastix_coefside_e pastix_coefside_t
Data blocks used in the kernel.
@ PastixLCoef
Definition: api.h:478
@ PastixUCoef
Definition: api.h:479
@ PastixUpper
Definition: api.h:466
@ PastixLower
Definition: api.h:467
@ PastixRight
Definition: api.h:496
@ PastixLeft
Definition: api.h:495
@ PastixConjTrans
Definition: api.h:447
@ PastixNoTrans
Definition: api.h:445
@ PastixTrans
Definition: api.h:446
@ PastixSchedSequential
Definition: api.h:334
void ** cblkb
Definition: pastixdata.h:162
pastix_int_t ld
Definition: pastixdata.h:160
pastix_int_t n
Definition: pastixdata.h:159
Main PaStiX RHS structure.
Definition: pastixdata.h:155
static pastix_int_t blok_rownbr(const SolverBlok *blok)
Compute the number of rows of a block.
Definition: solver.h:395
pastix_int_t brownum
Definition: solver.h:171
static pastix_int_t cblk_colnbr(const SolverCblk *cblk)
Compute the number of columns in a column block.
Definition: solver.h:329
pastix_int_t fcblknm
Definition: solver.h:144
MPI_Request * reqtab
Definition: solver.h:269
SolverBlok *restrict bloktab
Definition: solver.h:229
pastix_int_t frownum
Definition: solver.h:147
static pastix_int_t cblk_rownbr(const SolverCblk *cblk)
Compute the number of rows of a column block.
Definition: solver.h:449
static void * cblk_getdata(const SolverCblk *cblk, pastix_coefside_t side)
Get the pointer to the data associated to the side part of the cblk.
Definition: solver.h:369
pastix_int_t coefind
Definition: solver.h:149
SolverBlok * fblokptr
Definition: solver.h:168
pastix_int_t *restrict browtab
Definition: solver.h:230
pastix_int_t lcblknm
Definition: solver.h:143
pastix_int_t lcolidx
Definition: solver.h:170
int8_t inlast
Definition: solver.h:151
pastix_int_t bcscnum
Definition: solver.h:175
SolverCblk *restrict cblktab
Definition: solver.h:228
pastix_int_t stride
Definition: solver.h:169
int8_t cblktype
Definition: solver.h:164
pastix_int_t lcolnum
Definition: solver.h:167
pastix_int_t fcolnum
Definition: solver.h:166
Arguments for the solve.
Definition: solver.h:88
Solver block structure.
Definition: solver.h:141
Solver column block structure.
Definition: solver.h:161
Solver column block structure.
Definition: solver.h:203