PaStiX Handbook  6.4.0
starpu_ztrsm.c
Go to the documentation of this file.
1 /**
2  *
3  * @file starpu_ztrsm.c
4  *
5  * PaStiX ztrsm StarPU wrapper.
6  *
7  * @copyright 2016-2024 Bordeaux INP, CNRS (LaBRI UMR 5800), Inria,
8  * Univ. Bordeaux. All rights reserved.
9  *
10  * @version 6.4.0
11  * @author Vincent Bridonneau
12  * @author Mathieu Faverge
13  * @author Pierre Ramet
14  * @author Alycia Lisito
15  * @author Nolan Bredel
16  * @author Tom Moenne-Loccoz
17  * @date 2024-07-05
18  * @generated from /builds/solverstack/pastix/sopalin/starpu/starpu_ztrsm.c, normal z -> z, Tue Oct 8 14:17:32 2024
19  *
20  * @addtogroup starpu_trsm_solve
21  * @{
22  *
23  **/
24 
25 #include "common.h"
26 #include "blend/solver.h"
27 #include "sopalin/sopalin_data.h"
28 #include "pastix_zcores.h"
29 #include "pastix_starpu.h"
30 #include "pastix_zstarpu.h"
31 
32 /**
33  *******************************************************************************
34  *
35  * @brief Apply the forward solve related to one cblk to all the right-hand sides.
36  * (StarPU version)
37  *
38  ********************************************************************************
39  *
40  * @param[in] enums
41  * Enums needed for the solve (side, uplo, trans, diag and Schur mode).
42  *
43  * @param[in] sopalin_data
44  * The data that provides the SolverMatrix structure from PaStiX, and the
45  * descriptor of b (providing nrhs, b and ldb).
46  *
47  * @param[inout] rhsb
48  * The pointer to the rhs data structure that holds the vectors of the
49  * right-hand sides.
50  *
51  * @param[in] cblk
52  * The column block to solve. Its off-diagonal blocks are used to update
53  * the column blocks they face (fcblknm).
54  * The next column block must be accessible through cblk[1].
55  *
56  * @param[in] prio
57  * The priority of the task in the DAG.
58  *
59  *******************************************************************************/
60 void
62  sopalin_data_t *sopalin_data,
63  pastix_rhs_t rhsb,
64  const SolverCblk *cblk,
65  pastix_int_t prio )
66 {
68  SolverMatrix *datacode = sopalin_data->solvmtx;
69  SolverCblk *fcbk;
70  SolverBlok *blok;
71  pastix_trans_t tA;
72  pastix_side_t side = enums->side;
73  pastix_uplo_t uplo = enums->uplo;
74  pastix_trans_t trans = enums->trans;
75  pastix_diag_t diag = enums->diag;
76  pastix_solv_mode_t mode = enums->mode;
77 
78  if ( (cblk->cblktype & CBLK_IN_SCHUR) && (mode != PastixSolvModeSchur) ) {
79  return;
80  }
81 
82  if ( (side == PastixRight) && (uplo == PastixUpper) && (trans == PastixNoTrans) ) {
83  /* We store U^t, so we swap uplo and trans */
84  tA = PastixTrans;
85  cs = PastixUCoef;
86 
87  /* Right is not handled yet */
88  assert( 0 );
89  }
90  else if ( (side == PastixRight) && (uplo == PastixLower) && (trans != PastixNoTrans) ) {
91  tA = trans;
92  cs = PastixLCoef;
93 
94  /* Right is not handled yet */
95  assert( 0 );
96  }
97  else if ( (side == PastixLeft) && (uplo == PastixUpper) && (trans != PastixNoTrans) ) {
98  /* We store U^t, so we swap uplo and trans */
99  tA = PastixNoTrans;
100  cs = PastixUCoef;
101 
102  /* We do not handle conjtrans in complex as we store U^t */
103  assert( trans != PastixConjTrans );
104  }
105  else if ( (side == PastixLeft) && (uplo == PastixLower) && (trans == PastixNoTrans) ) {
106  tA = trans;
107  cs = PastixLCoef;
108  }
109  else {
110  /* This corresponds to the cases treated by the backward TRSM */
111  assert(0);
112  return;
113  }
114 
115  /* Solve the diagonal block */
116  starpu_stask_blok_ztrsm( sopalin_data, rhsb, cs, side, PastixLower,
117  tA, diag, cblk, prio );
118 
119  /* Apply the updates of the off-diagonal blocks to the cblks they face */
120  for (blok = cblk[0].fblokptr+1; blok < cblk[1].fblokptr; blok++ ) {
121  fcbk = datacode->cblktab + blok->fcblknm;
122 
123  if ( (fcbk->cblktype & CBLK_IN_SCHUR) && (mode == PastixSolvModeLocal) ) {
124  return; /* NOTE(review): return, not continue: assumes all remaining blocks also face Schur cblks - confirm */
125  }
126 
127  starpu_stask_blok_zgemm( sopalin_data, rhsb, cs, PastixLeft, tA,
128  cblk, blok, fcbk, prio );
129  }
130 }
131 
132 /**
133  *******************************************************************************
134  *
135  * @brief Apply the backward solve related to one cblk to all the right-hand sides.
136  * (StarPU version)
137  *
138  *******************************************************************************
139  *
140  * @param[in] enums
141  * Enums needed for the solve (side, uplo, trans, diag and Schur mode).
142  *
143  * @param[in] sopalin_data
144  * The data that provides the SolverMatrix structure from PaStiX, and the
145  * descriptor of b (providing nrhs, b and ldb).
146  *
147  * @param[inout] rhsb
148  * The pointer to the rhs data structure that holds the vectors of the
149  * right-hand sides.
150  *
151  * @param[in] cblk
152  * The column block to solve. The blocks facing it (its browtab range) are
153  * used to update the column blocks that own them (lcblknm).
154  * The next column block must be accessible through cblk[1].
155  *
156  * @param[in] prio
157  * The priority of the task in the DAG.
158  *
159  *******************************************************************************/
160 void
162  sopalin_data_t *sopalin_data,
163  pastix_rhs_t rhsb,
164  const SolverCblk *cblk,
165  pastix_int_t prio )
166 {
168  SolverMatrix *datacode = sopalin_data->solvmtx;
169  SolverCblk *fcbk;
170  SolverBlok *blok;
171  pastix_int_t j;
172  pastix_trans_t tA;
173  pastix_side_t side = enums->side;
174  pastix_uplo_t uplo = enums->uplo;
175  pastix_trans_t trans = enums->trans;
176  pastix_diag_t diag = enums->diag;
177  pastix_solv_mode_t mode = enums->mode;
178 
179  /*
180  * Left / Upper / NoTrans (Backward)
181  */
182  if ( (side == PastixLeft) && (uplo == PastixUpper) && (trans == PastixNoTrans) ) {
183  /* We store U^t, so we swap uplo and trans */
184  tA = PastixTrans;
185  cs = PastixUCoef;
186  }
187  else if ( (side == PastixLeft) && (uplo == PastixLower) && (trans != PastixNoTrans) ) {
188  tA = trans;
189  cs = PastixLCoef;
190  }
191  else if ( (side == PastixRight) && (uplo == PastixUpper) && (trans != PastixNoTrans) ) {
192  /* We store U^t, so we swap uplo and trans */
193  tA = PastixNoTrans;
194  cs = PastixUCoef;
195 
196  /* Right is not handled yet */
197  assert( 0 );
198 
199  /* We do not handle conjtrans in complex as we store U^t */
200  assert( trans != PastixConjTrans );
201  }
202  else if ( (side == PastixRight) && (uplo == PastixLower) && (trans == PastixNoTrans) ) {
203  tA = trans;
204  cs = PastixLCoef;
205 
206  /* Right is not handled yet */
207  assert( 0 );
208  }
209  else {
210  /* This corresponds to the cases treated by the forward TRSM */
211  assert(0);
212  return;
213  }
214  assert( !(cblk->cblktype & CBLK_RECV) ); /* RECV cblks are expected to be filtered out by the caller */
215 
216  if ( ( !(cblk->cblktype & CBLK_IN_SCHUR) || (mode == PastixSolvModeSchur) ) &&
217  ( !(cblk->cblktype & CBLK_FANIN) ) )
218  {
219  /* Solve the diagonal block */
220  starpu_stask_blok_ztrsm( sopalin_data, rhsb, cs, side, PastixLower,
221  tA, diag, cblk, prio );
222  }
223 
224  /* Apply the updates of the blocks facing this cblk (browtab), last to first */
225  for (j = cblk[1].brownum-1; j>=cblk[0].brownum; j-- ) {
226  blok = datacode->bloktab + datacode->browtab[j];
227  fcbk = datacode->cblktab + blok->lcblknm;
228 
229  if ( (fcbk->cblktype & CBLK_IN_SCHUR) && (mode == PastixSolvModeInterface) ) {
230  continue;
231  }
232  if ( fcbk->cblktype & CBLK_RECV ) {
233  continue;
234  }
235 
236  starpu_stask_blok_zgemm( sopalin_data, rhsb, cs, PastixRight, tA,
237  cblk, blok, fcbk, prio );
238  }
239 }
241 /**
242  *******************************************************************************
243  *
244  * @brief Apply a TRSM on a problem with 1 dimension (StarPU version)
245  *
246  *******************************************************************************
247  *
248  * @param[in] pastix_data
249  * The data that provide the mode.
250  *
251  * @param[in] sopalin_data
252  * The data that provide the SolverMatrix structure from PaStiX., and
253  * descriptor of b (providing nrhs, b and ldb).
254  *
255  * @param[inout] rhsb
256  * The pointer to the rhs data structure that holds the vectors of the
257  * right hand side.
258  *
259  * @param[in] enums
260  * Enums needed for the solve.
261  *
262  *******************************************************************************/
263 void
265  sopalin_data_t *sopalin_data,
266  pastix_rhs_t rhsb,
267  const args_solve_t *enums )
268 {
269  SolverMatrix *datacode = sopalin_data->solvmtx;
270  SolverCblk *cblk, *fcblk;
271  pastix_int_t i, cblknbr, prio;
272 
273  /* Backward like */
274  if ( enums->solve_step == PastixSolveBackward ) {
275  cblknbr = (enums->mode == PastixSolvModeLocal) ? datacode->cblkschur : datacode->cblknbr;
276  cblk = datacode->cblktab + cblknbr - 1;
277 
278  for ( i = cblknbr-1; i >= 0; i--, cblk-- ) {
279  prio = i;
280 
281 #if defined(PASTIX_WITH_MPI)
282  /* If this is a recv, let's locally copy and send the accumulation */
283  if ( cblk->cblktype & CBLK_RECV ) {
284  fcblk = datacode->cblktab + cblk->fblokptr->fcblknm;
285  starpu_stask_blok_zcpy_bwd_recv( sopalin_data, rhsb, cblk, fcblk, prio );
286  starpu_mpi_data_migrate( datacode->solv_comm,
287  rhsb->starpu_desc->handletab[i],
288  cblk->ownerid );
289  continue;
290  }
291 
292  /* If this is a fanin, let's submit the receive */
293  if ( cblk->cblktype & CBLK_FANIN ) {
294  starpu_mpi_data_migrate( datacode->solv_comm,
295  rhsb->starpu_desc->handletab[i],
296  datacode->clustnum );
297  }
298 #endif
299  starpu_cblk_ztrsmsp_backward( enums, sopalin_data, rhsb, cblk, prio );
300  }
301  }
302  /* Forward like */
303  else {
304  cblknbr = (enums->mode == PastixSolvModeSchur) ? datacode->cblknbr : datacode->cblkschur;
305  cblk = datacode->cblktab;
306 
307  for ( i = 0; i < cblknbr; i++, cblk++ ) {
308  prio = cblknbr - i;
309 
310 #if defined(PASTIX_WITH_MPI)
311  /* If this is a fanin, let's submit the send */
312  if ( cblk->cblktype & CBLK_FANIN ) {
313  starpu_mpi_data_migrate( datacode->solv_comm,
314  rhsb->starpu_desc->handletab[i],
315  cblk->ownerid );
316  continue;
317  }
318 
319  /* If this is a recv, let's locally sum the accumulation received */
320  if ( cblk->cblktype & CBLK_RECV ) {
321  starpu_mpi_data_migrate( datacode->solv_comm,
322  rhsb->starpu_desc->handletab[i],
323  datacode->clustnum );
324  fcblk = datacode->cblktab + cblk->fblokptr->fcblknm;
325  starpu_stask_blok_zadd_fwd_recv( sopalin_data, rhsb, cblk, fcblk, prio );
326  continue;
327  }
328 #endif
329  starpu_cblk_ztrsmsp_forward( enums, sopalin_data, rhsb, cblk, prio );
330  }
331  }
332 
333  for ( i = 0; i < cblknbr; i++ ) {
334  starpu_data_wont_use( rhsb->starpu_desc->handletab[i] );
335  }
336  (void)pastix_data;
337 }
338 
339 /**
340  *******************************************************************************
341  *
342  * @brief Apply the TRSM solve (StarPU version).
343  *
344  *******************************************************************************
345  *
346  * @param[in] pastix_data
347  * Provide informations about starpu and the schur solving mode.
348  *
349  * @param[in] enums
350  * Enums needed for the solve.
351  *
352  * @param[in] sopalin_data
353  * The data that provide the SolverMatrix structure from PaStiX, and
354  * descriptor of b (providing nrhs, b and ldb).
355  *
356  * @param[inout] rhsb
357  * The pointer to the rhs data structure that holds the vectors of the
358  * right hand side.
359  *
360  *******************************************************************************/
361 void
363  const args_solve_t *enums,
364  sopalin_data_t *sopalin_data,
365  pastix_rhs_t rhsb )
366 {
367  starpu_sparse_matrix_desc_t *sdesc = sopalin_data->solvmtx->starpu_desc;
368  starpu_rhs_desc_t *ddesc = rhsb->starpu_desc;
369 
370  /*
371  * Start StarPU if not already started
372  */
373  if (pastix_data->starpu == NULL) {
374  int argc = 0;
375  pastix_starpu_init( pastix_data, &argc, NULL, NULL );
376  }
377 
378  if ( sdesc == NULL ) {
379  /* Create the sparse matrix descriptor */
380  starpu_sparse_matrix_init( sopalin_data->solvmtx,
382  pastix_data->inter_node_procnbr,
383  pastix_data->inter_node_procnum,
384  PastixComplex64 );
385  sdesc = sopalin_data->solvmtx->starpu_desc;
386  }
387 
388  if ( ddesc == NULL ) {
389  /* Create the dense matrix descriptor */
390  starpu_rhs_init( pastix_data->solvmatr, rhsb,
391  PastixComplex64,
392  pastix_data->inter_node_procnbr,
393  pastix_data->inter_node_procnum );
394  ddesc = rhsb->starpu_desc;
395  }
396 
397 #if defined(STARPU_USE_FXT)
398  if (pastix_data->iparm[IPARM_TRACE] & PastixTraceSolve) {
399  starpu_fxt_start_profiling();
400  }
401 #endif
402  starpu_resume();
403  starpu_ztrsm_sp1dplus( pastix_data, sopalin_data, rhsb, enums );
404 
406  starpu_rhs_getoncpu( ddesc );
407  starpu_task_wait_for_all();
408 #if defined(PASTIX_WITH_MPI)
409  starpu_mpi_wait_for_all( pastix_data->pastix_comm );
410  starpu_mpi_barrier(pastix_data->inter_node_comm);
411 #endif
412  starpu_pause();
413 #if defined(STARPU_USE_FXT)
414  if (pastix_data->iparm[IPARM_TRACE] & PastixTraceSolve) {
415  starpu_fxt_stop_profiling();
416  }
417 #endif
418 
419  return;
420 }
421 
422 /**
423  *@}
424  */
BEGIN_C_DECLS typedef int pastix_int_t
Definition: datatypes.h:51
enum pastix_diag_e pastix_diag_t
Diagonal.
enum pastix_solv_mode_e pastix_solv_mode_t
Solve Schur modes.
enum pastix_uplo_e pastix_uplo_t
Upper/Lower part.
#define PastixHermitian
Definition: api.h:460
enum pastix_side_e pastix_side_t
Side of the operation.
enum pastix_trans_e pastix_trans_t
Transposition.
enum pastix_coefside_e pastix_coefside_t
Data blocks used in the kernel.
@ PastixLCoef
Definition: api.h:478
@ PastixUCoef
Definition: api.h:479
@ IPARM_TRACE
Definition: api.h:44
@ PastixUpper
Definition: api.h:466
@ PastixLower
Definition: api.h:467
@ PastixRight
Definition: api.h:496
@ PastixLeft
Definition: api.h:495
@ PastixConjTrans
Definition: api.h:447
@ PastixNoTrans
Definition: api.h:445
@ PastixTrans
Definition: api.h:446
@ PastixTraceSolve
Definition: api.h:212
void starpu_sparse_matrix_getoncpu(starpu_sparse_matrix_desc_t *desc)
Submit asynchronous calls to retrieve the data on main memory.
void starpu_rhs_getoncpu(starpu_rhs_desc_t *desc)
Submit asynchronous calls to retrieve the data on main memory.
Definition: starpu_rhs.c:216
void starpu_rhs_init(SolverMatrix *solvmtx, pastix_rhs_t rhsb, int typesze, int nodes, int myrank)
Generate the StarPU descriptor of the dense matrix.
Definition: starpu_rhs.c:152
void starpu_sparse_matrix_init(SolverMatrix *solvmtx, pastix_mtxtype_t mtxtype, int nodes, int myrank, pastix_coeftype_t flttype)
Generate the StarPU descriptor of the sparse matrix.
void pastix_starpu_init(pastix_data_t *pastix, int *argc, char **argv[], const int *bindtab)
Startup the StarPU runtime system.
Definition: starpu.c:92
void starpu_stask_blok_zgemm(sopalin_data_t *sopalin_data, pastix_rhs_t rhsb, pastix_coefside_t coef, pastix_side_t side, pastix_trans_t trans, const SolverCblk *cblk, const SolverBlok *blok, SolverCblk *fcbk, pastix_int_t prio)
Submit a task to perform a gemm.
void starpu_stask_blok_zcpy_bwd_recv(sopalin_data_t *sopalin_data, pastix_rhs_t rhsb, SolverCblk *cblk, const SolverCblk *fcblk, int prio)
Insert the task to add a fanin cblk on the receiver side (The fanin is seen on this side as the RECV ...
void starpu_stask_blok_zadd_fwd_recv(sopalin_data_t *sopalin_data, pastix_rhs_t rhsb, const SolverCblk *cblk, SolverCblk *fcblk, int prio)
Insert the task to add a fanin cblk on the receiver side (The fanin is seen on this side as the RECV ...
StarPU descriptor for the vectors linked to a given sparse matrix.
StarPU descriptor structure for the sparse matrix.
PASTIX_Comm pastix_comm
Definition: pastixdata.h:76
int inter_node_procnum
Definition: pastixdata.h:84
SolverMatrix * solvmatr
Definition: pastixdata.h:103
int inter_node_procnbr
Definition: pastixdata.h:83
void * starpu
Definition: pastixdata.h:88
pastix_int_t * iparm
Definition: pastixdata.h:70
PASTIX_Comm inter_node_comm
Definition: pastixdata.h:78
Main PaStiX data structure.
Definition: pastixdata.h:68
Main PaStiX RHS structure.
Definition: pastixdata.h:155
void starpu_cblk_ztrsmsp_forward(const args_solve_t *enums, sopalin_data_t *sopalin_data, pastix_rhs_t rhsb, const SolverCblk *cblk, pastix_int_t prio)
Apply the forward solve related to one cblk to all the right-hand sides. (StarPU version)
Definition: starpu_ztrsm.c:61
void starpu_stask_blok_ztrsm(sopalin_data_t *sopalin_data, pastix_rhs_t rhsb, pastix_coefside_t coef, pastix_side_t side, pastix_uplo_t uplo, pastix_trans_t trans, pastix_diag_t diag, const SolverCblk *cblk, pastix_int_t prio)
Submit a task to do a trsm related to a diagonal block of the matrix A.
void starpu_ztrsm(pastix_data_t *pastix_data, const args_solve_t *enums, sopalin_data_t *sopalin_data, pastix_rhs_t rhsb)
Apply the TRSM solve (StarPU version).
Definition: starpu_ztrsm.c:362
void starpu_ztrsm_sp1dplus(pastix_data_t *pastix_data, sopalin_data_t *sopalin_data, pastix_rhs_t rhsb, const args_solve_t *enums)
Apply a TRSM on a problem with 1 dimension (StarPU version)
Definition: starpu_ztrsm.c:264
void starpu_cblk_ztrsmsp_backward(const args_solve_t *enums, sopalin_data_t *sopalin_data, pastix_rhs_t rhsb, const SolverCblk *cblk, pastix_int_t prio)
Apply the backward solve related to one cblk to all the right-hand sides. (StarPU version)
Definition: starpu_ztrsm.c:161
pastix_int_t brownum
Definition: solver.h:171
pastix_int_t fcblknm
Definition: solver.h:144
pastix_int_t cblknbr
Definition: solver.h:211
SolverBlok *restrict bloktab
Definition: solver.h:229
SolverBlok * fblokptr
Definition: solver.h:168
pastix_int_t *restrict browtab
Definition: solver.h:230
pastix_int_t lcblknm
Definition: solver.h:143
SolverCblk *restrict cblktab
Definition: solver.h:228
int8_t cblktype
Definition: solver.h:164
int ownerid
Definition: solver.h:181
pastix_int_t cblkschur
Definition: solver.h:221
Arguments for the solve.
Definition: solver.h:88
Solver block structure.
Definition: solver.h:141
Solver column block structure.
Definition: solver.h:161
Solver matrix structure.
Definition: solver.h:203