PaStiX Handbook 6.4.0
Loading...
Searching...
No Matches
starpu_ztrsm.c
Go to the documentation of this file.
1/**
2 *
3 * @file starpu_ztrsm.c
4 *
5 * PaStiX ztrsm StarPU wrapper.
6 *
7 * @copyright 2016-2024 Bordeaux INP, CNRS (LaBRI UMR 5800), Inria,
8 * Univ. Bordeaux. All rights reserved.
9 *
10 * @version 6.4.0
11 * @author Vincent Bridonneau
12 * @author Mathieu Faverge
13 * @author Pierre Ramet
14 * @author Alycia Lisito
15 * @author Nolan Bredel
16 * @author Tom Moenne-Loccoz
17 * @date 2024-07-05
18 * @generated from /builds/2mk6rsew/0/solverstack/pastix/sopalin/starpu/starpu_ztrsm.c, normal z -> z, Tue Feb 25 14:35:18 2025
19 *
20 * @addtogroup starpu_trsm_solve
21 * @{
22 *
23 **/
24
25#include "common.h"
26#include "blend/solver.h"
27#include "sopalin/sopalin_data.h"
28#include "pastix_zcores.h"
29#include "pastix_starpu.h"
30#include "pastix_zstarpu.h"
31
32/**
33 *******************************************************************************
34 *
35 * @brief Apply a forward solve related to one cblk to all the right hand side.
36 * (StarPU version)
37 *
38 ********************************************************************************
39 *
40 * @param[in] enums
41 * Enums needed for the solve.
42 *
43 * @param[in] sopalin_data
44 * The data that provide the SolverMatrix structure from PaStiX, and
45 * descriptor of b (providing nrhs, b and ldb).
46 *
47 * @param[inout] rhsb
48 * The pointer to the rhs data structure that holds the vectors of the
49 * right hand side.
50 *
51 * @param[in] cblk
52 * The cblk structure to which the block belongs. The A and B pointers
53 * must be the coeftab of this column block.
54 * Next column blok must be accessible through cblk[1].
55 *
56 * @param[in] prio
57 * The priority of the task in the DAG.
58 *
59 *******************************************************************************/
/*
 * Forward solve for one column block: submits one triangular-solve task on
 * the diagonal block, then one gemm update task per off-diagonal block.
 *
 * NOTE(review): listing line 61 (function name and first parameter) is absent
 * from this extract. The reference index at the end of this page gives the
 * signature as
 *   starpu_cblk_ztrsmsp_forward( const args_solve_t *enums, sopalin_data_t *,
 *                                pastix_rhs_t, const SolverCblk *, pastix_int_t )
 */
60void
62 sopalin_data_t *sopalin_data,
63 pastix_rhs_t rhsb,
64 const SolverCblk *cblk,
65 pastix_int_t prio )
66{
/*
 * NOTE(review): listing lines 67 and 71 are absent; `cs` and `tA`, which are
 * assigned below, are presumably declared there (pastix_coefside_t and
 * pastix_trans_t, judging by the parameters they are passed to) - confirm
 * against the original file.
 */
68 SolverMatrix *datacode = sopalin_data->solvmtx;
69 SolverCblk *fcbk;
70 SolverBlok *blok;
72 pastix_side_t side = enums->side;
73 pastix_uplo_t uplo = enums->uplo;
74 pastix_trans_t trans = enums->trans;
75 pastix_diag_t diag = enums->diag;
76 pastix_solv_mode_t mode = enums->mode;
77
/* A cblk flagged CBLK_IN_SCHUR is only processed in PastixSolvModeSchur */
78 if ( (cblk->cblktype & CBLK_IN_SCHUR) && (mode != PastixSolvModeSchur) ) {
79 return;
80 }
81
/*
 * Map the requested (side, uplo, trans) triple to the stored coefficients
 * (cs) and the transposition actually applied (tA). Only the four
 * combinations below are accepted; the others belong to the backward solve.
 */
82 if ( (side == PastixRight) && (uplo == PastixUpper) && (trans == PastixNoTrans) ) {
83 /* We store U^t, so we swap uplo and trans */
84 tA = PastixTrans;
85 cs = PastixUCoef;
86
87 /* Right is not handled yet */
88 assert( 0 );
89 }
90 else if ( (side == PastixRight) && (uplo == PastixLower) && (trans != PastixNoTrans) ) {
91 tA = trans;
92 cs = PastixLCoef;
93
94 /* Right is not handled yet */
95 assert( 0 );
96 }
97 else if ( (side == PastixLeft) && (uplo == PastixUpper) && (trans != PastixNoTrans) ) {
98 /* We store U^t, so we swap uplo and trans */
99 tA = PastixNoTrans;
100 cs = PastixUCoef;
101
102 /* We do not handle conjtrans in complex as we store U^t */
103 assert( trans != PastixConjTrans );
104 }
105 else if ( (side == PastixLeft) && (uplo == PastixLower) && (trans == PastixNoTrans) ) {
106 tA = trans;
107 cs = PastixLCoef;
108 }
109 else {
110 /* This corresponds to the case treated in the backward TRSM */
111 assert(0);
112 return;
113 }
114
115 /* Solve the diagonal block */
/*
 * uplo is always passed as PastixLower here, whatever enums->uplo is; the
 * upper case is expressed through cs/tA instead (U is stored as U^t, per the
 * comments above).
 */
116 starpu_stask_blok_ztrsm( sopalin_data, rhsb, cs, side, PastixLower,
117 tA, diag, cblk, prio );
118
119 /* Apply the update */
/*
 * Walk every blok of this cblk except the first one, stopping at the first
 * blok of the next cblk (cblk[1].fblokptr).
 */
120 for (blok = cblk[0].fblokptr+1; blok < cblk[1].fblokptr; blok++ ) {
/* Column block designated by this blok's fcblknm index */
121 fcbk = datacode->cblktab + blok->fcblknm;
122
/*
 * In PastixSolvModeLocal, the first blok whose fcbk is flagged
 * CBLK_IN_SCHUR ends the whole function (return, not continue).
 * NOTE(review): presumably all remaining bloks also point to Schur
 * cblks - confirm the ordering guarantee.
 */
123 if ( (fcbk->cblktype & CBLK_IN_SCHUR) && (mode == PastixSolvModeLocal) ) {
124 return;
125 }
126
/* The update is always submitted with PastixLeft, independently of `side` */
127 starpu_stask_blok_zgemm( sopalin_data, rhsb, cs, PastixLeft, tA,
128 cblk, blok, fcbk, prio );
129 }
130}
131
132/**
133 *******************************************************************************
134 *
135 * @brief Apply a backward solve related to one cblk to all the right hand side.
136 * (StarPU version)
137 *
138 *******************************************************************************
139 *
140 * @param[in] enums
141 * Enums needed for the solve.
142 *
143 * @param[in] sopalin_data
144 * The data that provide the SolverMatrix structure from PaStiX, and
145 * descriptor of b (providing nrhs, b and ldb).
146 *
147 * @param[inout] rhsb
148 * The pointer to the rhs data structure that holds the vectors of the
149 * right hand side.
150 *
151 * @param[in] cblk
152 * The cblk structure to which block belongs to. The A and B pointers
153 * must be the coeftab of this column block.
154 * Next column blok must be accessible through cblk[1].
155 *
156 * @param[in] prio
157 * The priority of the task in the DAG.
158 *
159 *******************************************************************************/
/*
 * Backward solve for one column block: optionally submits the triangular
 * solve on the diagonal block, then one gemm update task per blok listed in
 * this cblk's browtab range.
 *
 * NOTE(review): listing line 161 (function name and first parameter) is
 * absent from this extract. The reference index at the end of this page gives
 * the signature as
 *   starpu_cblk_ztrsmsp_backward( const args_solve_t *enums, sopalin_data_t *,
 *                                 pastix_rhs_t, const SolverCblk *, pastix_int_t )
 */
160void
162 sopalin_data_t *sopalin_data,
163 pastix_rhs_t rhsb,
164 const SolverCblk *cblk,
165 pastix_int_t prio )
166{
/*
 * NOTE(review): listing lines 167 and 172 are absent; `cs` and `tA`, which
 * are assigned below, are presumably declared there - confirm against the
 * original file.
 */
168 SolverMatrix *datacode = sopalin_data->solvmtx;
169 SolverCblk *fcbk;
170 SolverBlok *blok;
171 pastix_int_t j;
173 pastix_side_t side = enums->side;
174 pastix_uplo_t uplo = enums->uplo;
175 pastix_trans_t trans = enums->trans;
176 pastix_diag_t diag = enums->diag;
177 pastix_solv_mode_t mode = enums->mode;
178
179 /*
180 * Left / Upper / NoTrans (Backward)
181 */
182 if ( (side == PastixLeft) && (uplo == PastixUpper) && (trans == PastixNoTrans) ) {
183 /* We store U^t, so we swap uplo and trans */
184 tA = PastixTrans;
185 cs = PastixUCoef;
186 }
187 else if ( (side == PastixLeft) && (uplo == PastixLower) && (trans != PastixNoTrans) ) {
188 tA = trans;
189 cs = PastixLCoef;
190 }
191 else if ( (side == PastixRight) && (uplo == PastixUpper) && (trans != PastixNoTrans) ) {
192 /* We store U^t, so we swap uplo and trans */
193 tA = PastixNoTrans;
194 cs = PastixUCoef;
195
196 /* Right is not handled yet */
197 assert( 0 );
198
199 /* We do not handle conjtrans in complex as we store U^t */
/*
 * With assertions enabled, the assert( 0 ) just above aborts first, so this
 * check can never be the one that fires.
 */
200 assert( trans != PastixConjTrans );
201 }
202 else if ( (side == PastixRight) && (uplo == PastixLower) && (trans == PastixNoTrans) ) {
203 tA = trans;
204 cs = PastixLCoef;
205
206 /* Right is not handled yet */
207 assert( 0 );
208 }
209 else {
210 /* This corresponds to the case treated in the forward TRSM */
211 assert(0);
212 return;
213 }
/* Callers must not pass a RECV cblk to this function */
214 assert( !(cblk->cblktype & CBLK_RECV) );
215
/*
 * The diagonal solve is skipped for FANIN cblks, and for cblks flagged
 * CBLK_IN_SCHUR unless the mode is PastixSolvModeSchur. Unlike the forward
 * variant, the function does not return here: the update loop below still
 * runs.
 */
216 if ( ( !(cblk->cblktype & CBLK_IN_SCHUR) || (mode == PastixSolvModeSchur) ) &&
217 ( !(cblk->cblktype & CBLK_FANIN) ) )
218 {
219 /* Solve the diagonal block */
/* As in the forward solve, uplo is always passed as PastixLower */
220 starpu_stask_blok_ztrsm( sopalin_data, rhsb, cs, side, PastixLower,
221 tA, diag, cblk, prio );
222 }
223
224 /* Apply the update */
/*
 * Walk this cblk's browtab entries in decreasing index order, from
 * cblk[1].brownum-1 down to cblk[0].brownum.
 */
225 for (j = cblk[1].brownum-1; j>=cblk[0].brownum; j-- ) {
/* browtab[j] is an index into bloktab; lcblknm indexes cblktab */
226 blok = datacode->bloktab + datacode->browtab[j];
227 fcbk = datacode->cblktab + blok->lcblknm;
228
/* Skipped (not aborted) for Schur cblks in PastixSolvModeInterface */
229 if ( (fcbk->cblktype & CBLK_IN_SCHUR) && (mode == PastixSolvModeInterface) ) {
230 continue;
231 }
/* No update is submitted towards RECV cblks */
232 if ( fcbk->cblktype & CBLK_RECV ) {
233 continue;
234 }
235
/* The update is always submitted with PastixRight, independently of `side` */
236 starpu_stask_blok_zgemm( sopalin_data, rhsb, cs, PastixRight, tA,
237 cblk, blok, fcbk, prio );
238 }
239}
240
241/**
242 *******************************************************************************
243 *
244 * @brief Apply a TRSM on a problem with 1 dimension (StarPU version)
245 *
246 *******************************************************************************
247 *
248 * @param[in] pastix_data
249 * The data that provide the mode.
250 *
251 * @param[in] sopalin_data
252 * The data that provide the SolverMatrix structure from PaStiX, and
253 * descriptor of b (providing nrhs, b and ldb).
254 *
255 * @param[inout] rhsb
256 * The pointer to the rhs data structure that holds the vectors of the
257 * right hand side.
258 *
259 * @param[in] enums
260 * Enums needed for the solve.
261 *
262 *******************************************************************************/
/*
 * Submits the solve tasks for every column block, in decreasing index order
 * for the backward step and increasing index order otherwise.
 *
 * NOTE(review): listing line 264 (function name and first parameter) is
 * absent from this extract. The reference index at the end of this page gives
 * the signature as
 *   starpu_ztrsm_sp1dplus( pastix_data_t *pastix_data, sopalin_data_t *,
 *                          pastix_rhs_t, const args_solve_t *enums )
 */
263void
265 sopalin_data_t *sopalin_data,
266 pastix_rhs_t rhsb,
267 const args_solve_t *enums )
268{
269 SolverMatrix *datacode = sopalin_data->solvmtx;
270 SolverCblk *cblk;
271#if defined(PASTIX_WITH_MPI)
272 SolverCblk *fcblk;
273#endif
274 pastix_int_t i, cblknbr, prio;
275
276 /* Backward like */
277 if ( enums->solve_step == PastixSolveBackward ) {
/* In Local mode only the first cblkschur cblks are visited, otherwise all of them */
278 cblknbr = (enums->mode == PastixSolvModeLocal) ? datacode->cblkschur : datacode->cblknbr;
/* Start from the last cblk of the selected range and walk down to index 0 */
279 cblk = datacode->cblktab + cblknbr - 1;
280
281 for ( i = cblknbr-1; i >= 0; i--, cblk-- ) {
/* The priority value is the cblk index, so it decreases as the loop proceeds */
282 prio = i;
283
284#if defined(PASTIX_WITH_MPI)
285 /* If this is a recv, let's locally copy and send the accumulation */
286 if ( cblk->cblktype & CBLK_RECV ) {
/* fcblk is the cblk designated by the first blok of this RECV cblk */
287 fcblk = datacode->cblktab + cblk->fblokptr->fcblknm;
288 starpu_stask_blok_zcpy_bwd_recv( sopalin_data, rhsb, cblk, fcblk, prio );
/* Migrate the handle towards cblk->ownerid; no solve is submitted for it */
289 starpu_mpi_data_migrate( datacode->solv_comm,
290 rhsb->starpu_desc->handletab[i],
291 cblk->ownerid );
292 continue;
293 }
294
295 /* If this is a fanin, let's submit the receive */
296 if ( cblk->cblktype & CBLK_FANIN ) {
/* Migrate towards the local node (clustnum), then fall through to the solve */
297 starpu_mpi_data_migrate( datacode->solv_comm,
298 rhsb->starpu_desc->handletab[i],
299 datacode->clustnum );
300 }
301#endif
302 starpu_cblk_ztrsmsp_backward( enums, sopalin_data, rhsb, cblk, prio );
303 }
304 }
305 /* Forward like */
306 else {
/* Only Schur mode visits every cblk; any other mode stops at cblkschur */
307 cblknbr = (enums->mode == PastixSolvModeSchur) ? datacode->cblknbr : datacode->cblkschur;
308 cblk = datacode->cblktab;
309
310 for ( i = 0; i < cblknbr; i++, cblk++ ) {
/* The priority value decreases from cblknbr down to 1 as the loop proceeds */
311 prio = cblknbr - i;
312
313#if defined(PASTIX_WITH_MPI)
314 /* If this is a fanin, let's submit the send */
315 if ( cblk->cblktype & CBLK_FANIN ) {
316 starpu_mpi_data_migrate( datacode->solv_comm,
317 rhsb->starpu_desc->handletab[i],
318 cblk->ownerid );
319 continue;
320 }
321
322 /* If this is a recv, let's locally sum the accumulation received */
323 if ( cblk->cblktype & CBLK_RECV ) {
/* Here the migration is submitted before the local add task */
324 starpu_mpi_data_migrate( datacode->solv_comm,
325 rhsb->starpu_desc->handletab[i],
326 datacode->clustnum );
327 fcblk = datacode->cblktab + cblk->fblokptr->fcblknm;
328 starpu_stask_blok_zadd_fwd_recv( sopalin_data, rhsb, cblk, fcblk, prio );
329 continue;
330 }
331#endif
332 starpu_cblk_ztrsmsp_forward( enums, sopalin_data, rhsb, cblk, prio );
333 }
334 }
335
/*
 * Call starpu_data_wont_use on the rhs handle of every cblk in the range
 * visited above (cblknbr keeps the value chosen by the branch that ran).
 */
336 for ( i = 0; i < cblknbr; i++ ) {
337 starpu_data_wont_use( rhsb->starpu_desc->handletab[i] );
338 }
/* pastix_data is not otherwise used in the visible body; silence the warning */
339 (void)pastix_data;
340}
341
342/**
343 *******************************************************************************
344 *
345 * @brief Apply the TRSM solve (StarPU version).
346 *
347 *******************************************************************************
348 *
349 * @param[in] pastix_data
350 * Provides information about StarPU and the Schur solving mode.
351 *
352 * @param[in] enums
353 * Enums needed for the solve.
354 *
355 * @param[in] sopalin_data
356 * The data that provide the SolverMatrix structure from PaStiX, and
357 * descriptor of b (providing nrhs, b and ldb).
358 *
359 * @param[inout] rhsb
360 * The pointer to the rhs data structure that holds the vectors of the
361 * right hand side.
362 *
363 *******************************************************************************/
/*
 * Entry point of the StarPU triangular solve: lazily initializes StarPU and
 * the matrix/rhs descriptors, submits the tasks through
 * starpu_ztrsm_sp1dplus(), then waits for their completion.
 *
 * NOTE(review): listing line 365 (function name and first parameter) is
 * absent from this extract. The reference index at the end of this page gives
 * the signature as
 *   starpu_ztrsm( pastix_data_t *pastix_data, const args_solve_t *enums,
 *                 sopalin_data_t *sopalin_data, pastix_rhs_t rhsb )
 */
364void
366 const args_solve_t *enums,
367 sopalin_data_t *sopalin_data,
368 pastix_rhs_t rhsb )
369{
370 starpu_sparse_matrix_desc_t *sdesc = sopalin_data->solvmtx->starpu_desc;
371 starpu_rhs_desc_t *ddesc = rhsb->starpu_desc;
372
373 /*
374 * Start StarPU if not already started
375 */
376 if (pastix_data->starpu == NULL) {
/* No command-line arguments and no binding table are forwarded */
377 int argc = 0;
378 pastix_starpu_init( pastix_data, &argc, NULL, NULL );
379 }
380
381 if ( sdesc == NULL ) {
382 /* Create the sparse matrix descriptor */
/*
 * NOTE(review): listing line 384 is absent, so the call below shows one
 * argument fewer than the five of starpu_sparse_matrix_init(). The missing
 * one is the mtxtype argument, presumably PastixHermitian since the page
 * index lists it - confirm.
 */
383 starpu_sparse_matrix_init( sopalin_data->solvmtx,
385 pastix_data->inter_node_procnbr,
386 pastix_data->inter_node_procnum,
387 PastixComplex64 );
388 sdesc = sopalin_data->solvmtx->starpu_desc;
389 }
390
391 if ( ddesc == NULL ) {
392 /* Create the dense matrix descriptor */
/*
 * NOTE(review): this uses pastix_data->solvmatr whereas the sparse
 * descriptor above uses sopalin_data->solvmtx - presumably the same
 * SolverMatrix; confirm.
 */
393 starpu_rhs_init( pastix_data->solvmatr, rhsb,
394 PastixComplex64,
395 pastix_data->inter_node_procnbr,
396 pastix_data->inter_node_procnum );
397 ddesc = rhsb->starpu_desc;
398 }
399
/* FxT profiling is only started when the PastixTraceSolve bit of IPARM_TRACE is set */
400#if defined(STARPU_USE_FXT)
401 if (pastix_data->iparm[IPARM_TRACE] & PastixTraceSolve) {
402 starpu_fxt_start_profiling();
403 }
404#endif
/* starpu_resume() here is paired with the starpu_pause() further down */
405 starpu_resume();
406 starpu_ztrsm_sp1dplus( pastix_data, sopalin_data, rhsb, enums );
407
/*
 * NOTE(review): listing line 408 is absent; it presumably holds
 * starpu_sparse_matrix_getoncpu( sdesc ), which the page index lists and
 * which would be the only use of sdesc - confirm.
 */
409 starpu_rhs_getoncpu( ddesc );
410 starpu_task_wait_for_all();
411#if defined(PASTIX_WITH_MPI)
/* NOTE(review): the wait uses pastix_comm but the barrier uses inter_node_comm - confirm this is intended */
412 starpu_mpi_wait_for_all( pastix_data->pastix_comm );
413 starpu_mpi_barrier(pastix_data->inter_node_comm);
414#endif
415 starpu_pause();
416#if defined(STARPU_USE_FXT)
417 if (pastix_data->iparm[IPARM_TRACE] & PastixTraceSolve) {
418 starpu_fxt_stop_profiling();
419 }
420#endif
421
422 return;
423}
424
425/**
426 *@}
427 */
BEGIN_C_DECLS typedef int pastix_int_t
Definition datatypes.h:51
enum pastix_diag_e pastix_diag_t
Diagonal.
enum pastix_solv_mode_e pastix_solv_mode_t
Solve Schur modes.
enum pastix_uplo_e pastix_uplo_t
Upper/Lower part.
#define PastixHermitian
Definition api.h:460
enum pastix_side_e pastix_side_t
Side of the operation.
enum pastix_trans_e pastix_trans_t
Transposition.
enum pastix_coefside_e pastix_coefside_t
Data blocks used in the kernel.
@ PastixLCoef
Definition api.h:478
@ PastixUCoef
Definition api.h:479
@ IPARM_TRACE
Definition api.h:44
@ PastixUpper
Definition api.h:466
@ PastixLower
Definition api.h:467
@ PastixRight
Definition api.h:496
@ PastixLeft
Definition api.h:495
@ PastixConjTrans
Definition api.h:447
@ PastixNoTrans
Definition api.h:445
@ PastixTrans
Definition api.h:446
@ PastixTraceSolve
Definition api.h:212
void starpu_sparse_matrix_getoncpu(starpu_sparse_matrix_desc_t *desc)
Submit asynchronous calls to retrieve the data on main memory.
void starpu_rhs_getoncpu(starpu_rhs_desc_t *desc)
Submit asynchronous calls to retrieve the data on main memory.
Definition starpu_rhs.c:217
void starpu_rhs_init(SolverMatrix *solvmtx, pastix_rhs_t rhsb, int typesze, int nodes, int myrank)
Generate the StarPU descriptor of the dense matrix.
Definition starpu_rhs.c:152
void starpu_sparse_matrix_init(SolverMatrix *solvmtx, pastix_mtxtype_t mtxtype, int nodes, int myrank, pastix_coeftype_t flttype)
Generate the StarPU descriptor of the sparse matrix.
void pastix_starpu_init(pastix_data_t *pastix, int *argc, char **argv[], const int *bindtab)
Startup the StarPU runtime system.
Definition starpu.c:92
void starpu_stask_blok_zgemm(sopalin_data_t *sopalin_data, pastix_rhs_t rhsb, pastix_coefside_t coef, pastix_side_t side, pastix_trans_t trans, const SolverCblk *cblk, const SolverBlok *blok, SolverCblk *fcbk, pastix_int_t prio)
Submit a task to perform a gemm.
void starpu_stask_blok_zcpy_bwd_recv(sopalin_data_t *sopalin_data, pastix_rhs_t rhsb, SolverCblk *cblk, const SolverCblk *fcblk, int prio)
Insert the task to add a fanin cblk on the receiver side (The fanin is seen on this side as the RECV ...
void starpu_stask_blok_zadd_fwd_recv(sopalin_data_t *sopalin_data, pastix_rhs_t rhsb, const SolverCblk *cblk, SolverCblk *fcblk, int prio)
Insert the task to add a fanin cblk on the receiver side (The fanin is seen on this side as the RECV ...
StarPU descriptor for the vectors linked to a given sparse matrix.
StarPU descriptor structure for the sparse matrix.
PASTIX_Comm pastix_comm
Definition pastixdata.h:76
int inter_node_procnum
Definition pastixdata.h:84
SolverMatrix * solvmatr
Definition pastixdata.h:103
int inter_node_procnbr
Definition pastixdata.h:83
void * starpu
Definition pastixdata.h:88
pastix_int_t * iparm
Definition pastixdata.h:70
PASTIX_Comm inter_node_comm
Definition pastixdata.h:78
Main PaStiX data structure.
Definition pastixdata.h:68
Main PaStiX RHS structure.
Definition pastixdata.h:155
void starpu_cblk_ztrsmsp_forward(const args_solve_t *enums, sopalin_data_t *sopalin_data, pastix_rhs_t rhsb, const SolverCblk *cblk, pastix_int_t prio)
Apply a forward solve related to one cblk to all the right hand side. (StarPU version)
void starpu_stask_blok_ztrsm(sopalin_data_t *sopalin_data, pastix_rhs_t rhsb, pastix_coefside_t coef, pastix_side_t side, pastix_uplo_t uplo, pastix_trans_t trans, pastix_diag_t diag, const SolverCblk *cblk, pastix_int_t prio)
Submit a task to do a trsm related to a diagonal block of the matrix A.
void starpu_ztrsm(pastix_data_t *pastix_data, const args_solve_t *enums, sopalin_data_t *sopalin_data, pastix_rhs_t rhsb)
Apply the TRSM solve (StarPU version).
void starpu_ztrsm_sp1dplus(pastix_data_t *pastix_data, sopalin_data_t *sopalin_data, pastix_rhs_t rhsb, const args_solve_t *enums)
Apply a TRSM on a problem with 1 dimension (StarPU version)
void starpu_cblk_ztrsmsp_backward(const args_solve_t *enums, sopalin_data_t *sopalin_data, pastix_rhs_t rhsb, const SolverCblk *cblk, pastix_int_t prio)
Apply a backward solve related to one cblk to all the right hand side. (StarPU version)
pastix_int_t brownum
Definition solver.h:171
pastix_int_t fcblknm
Definition solver.h:144
pastix_int_t cblknbr
Definition solver.h:211
SolverBlok *restrict bloktab
Definition solver.h:229
SolverBlok * fblokptr
Definition solver.h:168
pastix_int_t *restrict browtab
Definition solver.h:230
pastix_int_t lcblknm
Definition solver.h:143
SolverCblk *restrict cblktab
Definition solver.h:228
int8_t cblktype
Definition solver.h:164
pastix_int_t cblkschur
Definition solver.h:221
Arguments for the solve.
Definition solver.h:88
Solver block structure.
Definition solver.h:141
Solver column block structure.
Definition solver.h:161
Solver column block structure.
Definition solver.h:203