PaStiX Handbook  6.3.2
starpu_ztrsm.c
Go to the documentation of this file.
1 /**
2  *
3  * @file starpu_ztrsm.c
4  *
5  * PaStiX ztrsm StarPU wrapper.
6  *
7  * @copyright 2016-2023 Bordeaux INP, CNRS (LaBRI UMR 5800), Inria,
8  * Univ. Bordeaux. All rights reserved.
9  *
10  * @version 6.3.2
11  * @author Vincent Bridonneau
12  * @author Mathieu Faverge
13  * @author Pierre Ramet
14  * @date 2023-07-21
15  * @generated from /builds/solverstack/pastix/sopalin/starpu/starpu_ztrsm.c, normal z -> z, Wed Dec 13 12:09:25 2023
16  *
17  * @addtogroup starpu_trsm_solve
18  * @{
19  *
20  **/
21 
22 #include "common.h"
23 #include "blend/solver.h"
24 #include "sopalin/sopalin_data.h"
25 #include "pastix_zcores.h"
26 #include "pastix_starpu.h"
27 #include "pastix_zstarpu.h"
28 
29 /**
30  *******************************************************************************
31  *
32  * @brief Apply a forward solve related to one cblk to all the right-hand sides.
33  * (StarPU version)
34  *
 * Does nothing for a cblk in the Schur complement unless the solve mode is
 * PastixSolvModeSchur. Otherwise, selects the coefficient side (cs) and the
 * operation applied to the stored block (tA) from enums, requests the solve of
 * the diagonal block, then submits one starpu_stask_blok_zgemm() task per
 * off-diagonal block of cblk to update the facing column block (fcbk).
 * Only the Left/Lower/NoTrans and Left/Upper/(Conj)Trans combinations are
 * supported; the others trigger assert(0).
 *
35  ********************************************************************************
36  *
37  * @param[in] enums
38  * Enums needed for the solve (side, uplo, trans, diag and mode are read).
39  *
40  * @param[in] sopalin_data
41  * The data that provide the SolverMatrix structure from PaStiX, and
42  * descriptor of b (providing nrhs, b and ldb).
43  *
44  * @param[in] cblk
45  * The column block to solve. The A and B pointers
46  * must be the coeftab of this column block.
47  * Next column block must be accessible through cblk[1].
48  *
49  * @param[in] prio
50  * The priority of the task in the DAG.
51  *
 * @note NOTE(review): this rendered listing omits source lines 54, 59 and 111,
 * presumably the function-name line, the declaration of `cs`
 * (pastix_coefside_t) and the head of the diagonal-solve call (whose
 * arguments on line 112 match starpu_stask_blok_ztrsm) — confirm against
 * the source file.
 *
52  *******************************************************************************/
53 void
55  sopalin_data_t *sopalin_data,
56  const SolverCblk *cblk,
57  pastix_int_t prio )
58 {
60  SolverMatrix *datacode = sopalin_data->solvmtx;
61  SolverCblk *fcbk;
62  SolverBlok *blok;
63  pastix_trans_t tA;
64  pastix_side_t side = enums->side;
65  pastix_uplo_t uplo = enums->uplo;
66  pastix_trans_t trans = enums->trans;
67  pastix_diag_t diag = enums->diag;
68  pastix_solv_mode_t mode = enums->mode;
69 
70  if ( (cblk->cblktype & CBLK_IN_SCHUR) && (mode != PastixSolvModeSchur) ) { /* Schur cblks are solved only in Schur mode */
71  return;
72  }
73 
74  if ( (side == PastixRight) && (uplo == PastixUpper) && (trans == PastixNoTrans) ) {
75  /* We store U^t, so we swap uplo and trans */
76  tA = PastixTrans;
77  cs = PastixUCoef;
78 
79  /* Right is not handled yet */
80  assert( 0 ); /* NOTE(review): no return; with NDEBUG this unsupported case still submits the tasks below */
81  }
82  else if ( (side == PastixRight) && (uplo == PastixLower) && (trans != PastixNoTrans) ) {
83  tA = trans;
84  cs = PastixLCoef;
85 
86  /* Right is not handled yet */
87  assert( 0 ); /* NOTE(review): no return; with NDEBUG this unsupported case still submits the tasks below */
88  }
89  else if ( (side == PastixLeft) && (uplo == PastixUpper) && (trans != PastixNoTrans) ) {
90  /* We store U^t, so we swap uplo and trans */
91  tA = PastixNoTrans;
92  cs = PastixUCoef;
93 
94  /* We do not handle conjtrans in complex as we store U^t */
95  assert( trans != PastixConjTrans ); /* NOTE(review): with NDEBUG, ConjTrans is silently processed like Trans */
96  }
97  else if ( (side == PastixLeft) && (uplo == PastixLower) && (trans == PastixNoTrans) ) {
98  tA = trans;
99  cs = PastixLCoef;
100  }
101  else {
102  /* This corresponds to the cases treated in the backward TRSM */
103  assert(0);
104  return;
105  }
106 
107  /* In sequential */
108  assert( cblk->fcolnum == cblk->lcolidx );
109 
110  /* Solve the diagonal block */
112  sopalin_data, cs, side, PastixLower, tA, diag, cblk, prio );
113 
114  /* Apply the update */
115  for (blok = cblk[0].fblokptr+1; blok < cblk[1].fblokptr; blok++ ) { /* off-diagonal blocks of cblk */
116  fcbk = datacode->cblktab + blok->fcblknm;
117 
118  if ( (fcbk->cblktype & CBLK_IN_SCHUR) && (mode == PastixSolvModeLocal) ) {
119  return; /* NOTE(review): stops all remaining updates of cblk, not only this one; presumably Schur cblks are numbered last — confirm */
120  }
121 
122  starpu_stask_blok_zgemm( sopalin_data, cs, PastixLeft, tA,
123  cblk, blok, fcbk, prio );
124  }
125 }
126 
127 /**
128  *******************************************************************************
129  *
130  * @brief Apply a backward solve related to one cblk to all the right-hand sides.
131  * (StarPU version)
132  *
 * Selects the coefficient side (cs) and the operation applied to the stored
 * block (tA) from enums. Unless cblk is in the Schur complement (and the mode
 * is not PastixSolvModeSchur), requests the solve of the diagonal block. Then
 * submits one starpu_stask_blok_zgemm() task per block listed in browtab for
 * this cblk, in reverse order. Updates are submitted even when the diagonal
 * solve was skipped. Only the Left/Upper/NoTrans and Left/Lower/(Conj)Trans
 * combinations are supported; the others trigger assert(0).
 *
133  *******************************************************************************
134  *
135  * @param[in] enums
136  * Enums needed for the solve (side, uplo, trans, diag and mode are read).
137  *
138  * @param[in] sopalin_data
139  * The data that provide the SolverMatrix structure from PaStiX, and
140  * descriptor of b (providing nrhs, b and ldb).
141  *
142  * @param[in] cblk
143  * The column block to solve. The A and B pointers
144  * must be the coeftab of this column block.
145  * Next column block must be accessible through cblk[1].
146  *
147  * @param[in] prio
148  * The priority of the task in the DAG.
149  *
 * @note NOTE(review): this rendered listing omits source lines 152, 157 and
 * 207, presumably the function-name line, the declaration of `cs`
 * (pastix_coefside_t) and the head of the diagonal-solve call (whose
 * arguments on line 208 match starpu_stask_blok_ztrsm) — confirm against
 * the source file.
 *
150  *******************************************************************************/
151 void
153  sopalin_data_t *sopalin_data,
154  const SolverCblk *cblk,
155  pastix_int_t prio )
156 {
158  SolverMatrix *datacode = sopalin_data->solvmtx;
159  SolverCblk *fcbk;
160  SolverBlok *blok;
161  pastix_int_t j;
162  pastix_trans_t tA;
163  pastix_side_t side = enums->side;
164  pastix_uplo_t uplo = enums->uplo;
165  pastix_trans_t trans = enums->trans;
166  pastix_diag_t diag = enums->diag;
167  pastix_solv_mode_t mode = enums->mode;
168 
169  /*
170  * Left / Upper / NoTrans (Backward)
171  */
172  if ( (side == PastixLeft) && (uplo == PastixUpper) && (trans == PastixNoTrans) ) {
173  /* We store U^t, so we swap uplo and trans */
174  tA = PastixTrans;
175  cs = PastixUCoef;
176  }
177  else if ( (side == PastixLeft) && (uplo == PastixLower) && (trans != PastixNoTrans) ) {
178  tA = trans;
179  cs = PastixLCoef;
180  }
181  else if ( (side == PastixRight) && (uplo == PastixUpper) && (trans != PastixNoTrans) ) {
182  /* We store U^t, so we swap uplo and trans */
183  tA = PastixNoTrans;
184  cs = PastixUCoef;
185 
186  /* Right is not handled yet */
187  assert( 0 ); /* NOTE(review): no return; with NDEBUG this unsupported case still submits the tasks below */
188 
189  /* We do not handle conjtrans in complex as we store U^t */
190  assert( trans != PastixConjTrans );
191  }
192  else if ( (side == PastixRight) && (uplo == PastixLower) && (trans == PastixNoTrans) ) {
193  tA = trans;
194  cs = PastixLCoef;
195 
196  /* Right is not handled yet */
197  assert( 0 ); /* NOTE(review): no return; with NDEBUG this unsupported case still submits the tasks below */
198  }
199  else {
200  /* This corresponds to the cases treated in the forward TRSM */
201  assert(0);
202  return;
203  }
204 
205  if ( !(cblk->cblktype & CBLK_IN_SCHUR) || (mode == PastixSolvModeSchur) ) {
206  /* Solve the diagonal block */
208  sopalin_data, cs, side, PastixLower, tA, diag, cblk, prio );
209  }
210 
211  /* Apply the update */
212  for (j = cblk[1].brownum-1; j>=cblk[0].brownum; j-- ) { /* browtab[cblk[0].brownum .. cblk[1].brownum-1], last to first */
213  blok = datacode->bloktab + datacode->browtab[j];
214  fcbk = datacode->cblktab + blok->lcblknm; /* cblk that owns blok */
215 
216  if ( (fcbk->cblktype & CBLK_IN_SCHUR) && (mode == PastixSolvModeInterface) ) {
217  continue;
218  }
219 
220  starpu_stask_blok_zgemm( sopalin_data, cs, PastixRight, tA,
221  cblk, blok, fcbk, prio );
222  }
223 }
227  *
228  * @brief Apply a TRSM on a problem with 1 dimension (StarPU version)
229  *
230  *******************************************************************************
231  *
232  * @param[in] pastix_data
233  * The data that provide the mode.
234  *
235  * @param[in] sopalin_data
236  * The data that provide the SolverMatrix structure from PaStiX., and
237  * descriptor of b (providing nrhs, b and ldb).
238  *
239  * @param[in] enums
240  * Enums needed for the solve.
241  *
242  *******************************************************************************/
243 void
245  sopalin_data_t *sopalin_data,
246  const args_solve_t *enums )
247 {
248  SolverMatrix *datacode = sopalin_data->solvmtx;
249  SolverCblk *cblk;
250  pastix_int_t i, cblknbr;
251 
252  /* Backward like */
253  if ( enums->solve_step == PastixSolveBackward ) {
254  cblknbr = (enums->mode == PastixSolvModeLocal) ? datacode->cblkschur : datacode->cblknbr;
255 
256  cblk = datacode->cblktab + cblknbr - 1;
257  for (i=0; i<cblknbr; i++, cblk--){
258  starpu_cblk_ztrsmsp_backward( enums, sopalin_data, cblk, cblknbr - i );
259  }
260  }
261  /* Forward like */
262  else {
263  cblknbr = (enums->mode == PastixSolvModeSchur) ? datacode->cblknbr : datacode->cblkschur;
264 
265  cblk = datacode->cblktab;
266  for (i=0; i<cblknbr; i++, cblk++){
267  starpu_cblk_ztrsmsp_forward( enums, sopalin_data, cblk, cblknbr - i );
268  }
269  }
270  (void)pastix_data;
271 }
272 
273 /**
274  *******************************************************************************
275  *
276  * @brief Apply the TRSM solve (StarPU version).
277  *
278  *******************************************************************************
279  *
280  * @param[in] pastix_data
281  * Provide informations about starpu and the schur solving mode.
282  *
283  * @param[in] enums
284  * Enums needed for the solve.
285  *
286  * @param[in] sopalin_data
287  * The data that provide the SolverMatrix structure from PaStiX, and
288  * descriptor of b (providing nrhs, b and ldb).
289  *
290  * @param[inout] rhsb
291  * The pointer to the rhs data structure that holds the vectors of the
292  * right hand side.
293  *
294  *******************************************************************************/
295 void
297  const args_solve_t *enums,
298  sopalin_data_t *sopalin_data,
299  pastix_rhs_t rhsb )
300 {
301  starpu_sparse_matrix_desc_t *sdesc = sopalin_data->solvmtx->starpu_desc;
303 
304  /*
305  * Start StarPU if not already started
306  */
307  if (pastix_data->starpu == NULL) {
308  int argc = 0;
309  pastix_starpu_init( pastix_data, &argc, NULL, NULL );
310  }
311 
312  if ( sdesc == NULL ) {
313  /* Create the sparse matrix descriptor */
314  starpu_sparse_matrix_init( sopalin_data->solvmtx,
316  pastix_data->inter_node_procnbr,
317  pastix_data->inter_node_procnum,
318  PastixComplex64 );
319  sdesc = sopalin_data->solvmtx->starpu_desc;
320  }
321 
322  /* Create the dense matrix descriptor */
323  starpu_dense_matrix_init( sopalin_data->solvmtx,
324  rhsb->n, rhsb->b, rhsb->ld,
325  sizeof(pastix_complex64_t),
326  pastix_data->inter_node_procnbr,
327  pastix_data->inter_node_procnum );
328  ddesc = sopalin_data->solvmtx->starpu_desc_rhs;
329 
330 #if defined(STARPU_USE_FXT)
331  if (pastix_data->iparm[IPARM_TRACE] & PastixTraceSolve) {
332  starpu_fxt_start_profiling();
333  }
334 #endif
335  starpu_resume();
336  starpu_ztrsm_sp1dplus( pastix_data, sopalin_data, enums );
337 
340  starpu_task_wait_for_all();
341 #if defined(PASTIX_WITH_MPI)
342  starpu_mpi_wait_for_all( pastix_data->pastix_comm );
343  starpu_mpi_barrier(pastix_data->inter_node_comm);
344 #endif
345  starpu_pause();
346 #if defined(STARPU_USE_FXT)
347  if (pastix_data->iparm[IPARM_TRACE] & PastixTraceSolve) {
348  starpu_fxt_stop_profiling();
349  }
350 #endif
351 
352  return;
353 }
354 
355 /**
356  *@}
357  */
BEGIN_C_DECLS typedef int pastix_int_t
Definition: datatypes.h:51
enum pastix_diag_e pastix_diag_t
Diagonal.
enum pastix_solv_mode_e pastix_solv_mode_t
Solve Schur modes.
enum pastix_uplo_e pastix_uplo_t
Upper/Lower part.
#define PastixHermitian
Definition: api.h:460
enum pastix_side_e pastix_side_t
Side of the operation.
enum pastix_trans_e pastix_trans_t
Transposition.
enum pastix_coefside_e pastix_coefside_t
Data blocks used in the kernel.
@ PastixLCoef
Definition: api.h:478
@ PastixUCoef
Definition: api.h:479
@ IPARM_TRACE
Definition: api.h:44
@ PastixUpper
Definition: api.h:466
@ PastixLower
Definition: api.h:467
@ PastixRight
Definition: api.h:496
@ PastixLeft
Definition: api.h:495
@ PastixConjTrans
Definition: api.h:447
@ PastixNoTrans
Definition: api.h:445
@ PastixTrans
Definition: api.h:446
@ PastixTraceSolve
Definition: api.h:212
void starpu_dense_matrix_init(SolverMatrix *solvmtx, pastix_int_t ncol, char *A, pastix_int_t lda, int typesze, int nodes, int myrank)
Generate the StarPU descriptor of the dense matrix.
void starpu_sparse_matrix_getoncpu(starpu_sparse_matrix_desc_t *desc)
Submit asynchronous calls to retrieve the data on main memory.
void starpu_sparse_matrix_init(SolverMatrix *solvmtx, pastix_mtxtype_t mtxtype, int nodes, int myrank, pastix_coeftype_t flttype)
Generate the StarPU descriptor of the sparse matrix.
void pastix_starpu_init(pastix_data_t *pastix, int *argc, char **argv[], const int *bindtab)
Startup the StarPU runtime system.
Definition: starpu.c:92
void starpu_stask_blok_zgemm(sopalin_data_t *sopalin_data, pastix_coefside_t coef, pastix_side_t side, pastix_trans_t trans, const SolverCblk *cblk, const SolverBlok *blok, SolverCblk *fcbk, pastix_int_t prio)
Submit a task to perform a gemm.
void starpu_dense_matrix_getoncpu(starpu_dense_matrix_desc_t *desc)
Submit asynchronous calls to retrieve the data on main memory.
StarPU descriptor for the vectors linked to a given sparse matrix.
StarPU descriptor structure for the sparse matrix.
PASTIX_Comm pastix_comm
Definition: pastixdata.h:75
int inter_node_procnum
Definition: pastixdata.h:83
int inter_node_procnbr
Definition: pastixdata.h:82
void * starpu
Definition: pastixdata.h:87
pastix_int_t * iparm
Definition: pastixdata.h:69
pastix_int_t ld
Definition: pastixdata.h:155
PASTIX_Comm inter_node_comm
Definition: pastixdata.h:77
pastix_int_t n
Definition: pastixdata.h:154
Main PaStiX data structure.
Definition: pastixdata.h:67
Main PaStiX RHS structure.
Definition: pastixdata.h:150
void starpu_ztrsm(pastix_data_t *pastix_data, const args_solve_t *enums, sopalin_data_t *sopalin_data, pastix_rhs_t rhsb)
Apply the TRSM solve (StarPU version).
Definition: starpu_ztrsm.c:296
void starpu_ztrsm_sp1dplus(pastix_data_t *pastix_data, sopalin_data_t *sopalin_data, const args_solve_t *enums)
Apply a TRSM on a problem with 1 dimension (StarPU version)
Definition: starpu_ztrsm.c:244
void starpu_cblk_ztrsmsp_forward(const args_solve_t *enums, sopalin_data_t *sopalin_data, const SolverCblk *cblk, pastix_int_t prio)
Apply a forward solve related to one cblk to all the right hand side. (StarPU version)
Definition: starpu_ztrsm.c:54
void starpu_cblk_ztrsmsp_backward(const args_solve_t *enums, sopalin_data_t *sopalin_data, const SolverCblk *cblk, pastix_int_t prio)
Apply a backward solve related to one cblk to all the right hand side. (StarPU version)
Definition: starpu_ztrsm.c:152
void starpu_stask_blok_ztrsm(sopalin_data_t *sopalin_data, pastix_coefside_t coef, pastix_side_t side, pastix_uplo_t uplo, pastix_trans_t trans, pastix_diag_t diag, const SolverCblk *cblk, pastix_int_t prio)
Submit a task to do a trsm related to a diagonal block of the matrix A.
pastix_int_t brownum
Definition: solver.h:166
pastix_int_t fcblknm
Definition: solver.h:140
pastix_int_t cblknbr
Definition: solver.h:208
SolverBlok *restrict bloktab
Definition: solver.h:223
SolverBlok * fblokptr
Definition: solver.h:163
pastix_int_t *restrict browtab
Definition: solver.h:224
pastix_int_t lcblknm
Definition: solver.h:139
pastix_int_t lcolidx
Definition: solver.h:165
SolverCblk *restrict cblktab
Definition: solver.h:222
int8_t cblktype
Definition: solver.h:159
pastix_int_t cblkschur
Definition: solver.h:217
pastix_int_t fcolnum
Definition: solver.h:161
Arguments for the solve.
Definition: solver.h:85
Solver block structure.
Definition: solver.h:137
Solver column block structure.
Definition: solver.h:156
Solver matrix structure.
Definition: solver.h:200