PaStiX Handbook  6.3.2
schur.c
Go to the documentation of this file.
1 /**
2  *
3  * @file sopalin/schur.c
4  *
5  * PaStiX schur interface functions
6  *
7  * @copyright 2017-2023 Bordeaux INP, CNRS (LaBRI UMR 5800), Inria,
8  * Univ. Bordeaux. All rights reserved.
9  *
10  * @version 6.3.2
11  * @author Mathieu Faverge
12  * @author Pierre Ramet
13  * @author Xavier Lacoste
14  * @date 2023-11-07
15  *
16  * @addtogroup pastix_schur
17  * @{
18  *
19  **/
20 #include "common.h"
21 #include <spm.h>
22 #include <lapacke.h>
23 #include "blend/solver.h"
24 #include "sopalin/coeftab_z.h"
25 #include "sopalin/coeftab_c.h"
26 #include "sopalin/coeftab_d.h"
27 #include "sopalin/coeftab_s.h"
28 
29 /**
30  *******************************************************************************
31  *
32  * @brief Set the list of unknowns that belongs to the schur complement.
33  *
34  *******************************************************************************
35  *
36  * @param[inout] pastix_data
37  * The pastix data structure of the solver to store the list of Schur
38  * unknowns.
39  *
40  * @param[in] n
41  * The number of unknowns in the Schur complement.
42  *
43  * @param[in] list
44  * Array of integer of size n.
45  * The list of unknowns belonging to the Schur complement with the same
46  * baseval as the associated spm.
47  *
48  *******************************************************************************/
49 void
51  pastix_int_t n,
52  const pastix_int_t *list)
53 {
54  if ( n > 0 ) {
55  pastix_data->schur_n = n;
56  pastix_data->schur_list = (pastix_int_t*)malloc(n * sizeof(pastix_int_t));
57  memcpy( pastix_data->schur_list, list, n * sizeof(pastix_int_t) );
58  }
59 }
60 
61 /**
62  *******************************************************************************
63  *
64  * @brief Return the Schur complement.
65  *
66  * The Schur complement is returned in the column major layout used by the
67  * classic linear algebra libraries such as Blas or Lapack.
68  *
69  *******************************************************************************
70  *
71  * @param[in] pastix_data
72  * The pastix data structure of the problem solved.
73  *
74  * @param[inout] S
75  * Array of size spm->n -by- lds of arithmetic spm->flttype, where spm
76  * is the spm of the original problem.
77  * On exit, the array contains the Schur complement of the factorized
78  * matrix.
79  *
80  * @param[in] lds
81  * The leading dimension of the S array.
82  *
83  ********************************************************************************
84  *
85  * @retval PASTIX_SUCCESS on successful exit,
86  * @retval PASTIX_ERR_BADPARAMETER if one parameter is incorrect.
87  *
88  *******************************************************************************/
89 int
90 pastixGetSchur( const pastix_data_t *pastix_data,
91  void *S,
92  pastix_int_t lds )
93 {
94  pastix_int_t *iparm;
95 
96  /*
97  * Check parameters
98  */
99  if (pastix_data == NULL) {
100  pastix_print_error( "pastix_getSchur: wrong pastix_data parameter" );
102  }
103  if (S == NULL) {
104  pastix_print_error( "pastix_getSchur: S parameter" );
106  }
107  if (lds <= 0) {
108  pastix_print_error( "pastix_getSchur: lds parameter" );
110  }
111  if ( !(pastix_data->steps & STEP_NUMFACT) ) {
112  pastix_print_error( "pastix_getSchur: All steps from pastix_task_init() to pastix_task_numfact() have to be called before calling this function" );
114  }
115 #if defined(PASTIX_WITH_MPI)
116  if (pastix_data->inter_node_procnbr > 1) {
117  if ( pastix_data->inter_node_procnum == 0 ) {
118  pastix_print_error( "pastix_getSchur: Schur complement is not available yet with multiple MPI processes\n" );
119  }
120  return -1;
121  }
122 #endif
123 
124  iparm = pastix_data->iparm;
125  switch(iparm[IPARM_FLOAT])
126  {
127  case PastixPattern:
128  break;
129  case PastixFloat:
130  coeftab_sgetschur( pastix_data->solvmatr, S, lds );
131  break;
132  case PastixComplex32:
133  coeftab_cgetschur( pastix_data->solvmatr, S, lds );
134  break;
135  case PastixComplex64:
136  coeftab_zgetschur( pastix_data->solvmatr, S, lds );
137  break;
138  case PastixDouble:
139  default:
140  coeftab_dgetschur( pastix_data->solvmatr, S, lds );
141  }
142  return PASTIX_SUCCESS;
143 }
144 
145 /**
146  *******************************************************************************
147  *
148  * @ingroup pastix_solve
149  *
150  * @brief Get the vector in an RHS data structure.
151  *
152  *******************************************************************************
153  *
154  * @param[in] pastix_data
155  * TODO
156  *
157  * @param[in] m
158  * The number of rows of the vector b, must be equal to the number of
159  * unknowns in the Schur complement.
160  *
161  * @param[in] n
162  * The number of columns of the vector b.
163  *
164  * @param[in] rhsB
165  * The pastix_rhs_t data structure used to solve the system.
166  *
167  * @param[inout] B
168  * On entry, a vector of size ldb-by-n.
169  * On exit, the m-by-n leading part contains the right hand side
170  * related to the Schur part.
171  *
172  * @param[in] ldb
173  * The leading dimension of the vector b.
174  *
175  *******************************************************************************
176  *
177  * @retval PASTIX_SUCCESS on successful exit,
178  * @retval PASTIX_ERR_BADPARAMETER if one parameter is incorrect.
179  *
180  *******************************************************************************/
181 int
182 pastixRhsSchurGet( const pastix_data_t *pastix_data,
183  pastix_int_t m,
184  pastix_int_t n,
185  pastix_rhs_t rhsB,
186  void *B,
187  pastix_int_t ldb )
188 {
189  const SolverMatrix *solvmtx;
190  const SolverCblk *cblk;
191  pastix_int_t mschur;
192  void *bptr;
193  int rc;
194 
195  if ( pastix_data == NULL ) {
196  pastix_print_error( "pastixRhsSchurGet: wrong pastix_data parameter" );
198  }
199  if ( rhsB == NULL ) {
200  pastix_print_error( "pastixRhsSchurGet: wrong rhsB parameter" );
202  }
203  if ( B == NULL ) {
204  pastix_print_error( "pastixRhsSchurGet: wrong b parameter" );
206  }
207 
208  solvmtx = pastix_data->solvmatr;
209  cblk = solvmtx->cblktab + solvmtx->cblkschur;
210  mschur = solvmtx->nodenbr - cblk->fcolnum;
211 
212  if ( m != mschur ) {
213  pastix_print_error( "pastixRhsSchurGet: wrong m parameter expecting %ld but was %ld\n",
214  (long)mschur, (long)m );
216  }
217  if ( n != rhsB->n ) {
218  pastix_print_error( "pastixRhsSchurGet: wrong n parameter expecting %ld but was %ld\n",
219  (long)rhsB->n, (long)n );
221  }
222  if ( ldb < m ) {
223  pastix_print_error( "pastixRhsSchurGet: wrong ldb parameter\n" );
225  }
226 
227  bptr = ((char *)rhsB->b) + cblk->lcolidx * pastix_size_of( rhsB->flttype );
228 
229  switch( rhsB->flttype ) {
230  case SpmComplex64:
231  rc = LAPACKE_zlacpy_work( LAPACK_COL_MAJOR, 'A', mschur, n, (pastix_complex64_t *)bptr, rhsB->ld, B, ldb );
232  break;
233  case SpmComplex32:
234  rc = LAPACKE_clacpy_work( LAPACK_COL_MAJOR, 'A', mschur, n, (pastix_complex32_t *)bptr, rhsB->ld, B, ldb );
235  break;
236  case SpmDouble:
237  rc = LAPACKE_dlacpy_work( LAPACK_COL_MAJOR, 'A', mschur, n, (double *)bptr, rhsB->ld, B, ldb );
238  break;
239  case SpmFloat:
240  rc = LAPACKE_slacpy_work( LAPACK_COL_MAJOR, 'A', mschur, n, (float *)bptr, rhsB->ld, B, ldb );
241  break;
242  default:
243  pastix_print_error( "pastixRhsSchurGet: unknown flttype\n" );
245  }
246 
247  return rc;
248 }
249 
250 /**
251  *******************************************************************************
252  *
253  * @ingroup pastix_solve
254  *
255  * @brief Set the vector in an RHS data structure.
256  *
257  *******************************************************************************
258  *
259  * @param[in] pastix_data
260  * TODO
261  *
262  * @param[in] m
263  * The number of rows of the vector b.
264  *
265  * @param[in] n
266  * The number of columns of the vector b.
267  *
268  * @param[in] B
269  * The vector b.
270  *
271  * @param[in] ldb
272  * The leading dimension of the vector b.
273  *
274  * @param[out] rhsB
275  * The pastix_rhs_t data structure which contains the vector b.
276  *
277  *******************************************************************************
278  *
279  * @retval PASTIX_SUCCESS on successful exit,
280  * @retval PASTIX_ERR_BADPARAMETER if one parameter is incorrect.
281  *
282  *******************************************************************************/
283 int
284 pastixRhsSchurSet( const pastix_data_t *pastix_data,
285  pastix_int_t m,
286  pastix_int_t n,
287  void *B,
288  pastix_int_t ldb,
289  pastix_rhs_t rhsB )
290 {
291  const SolverMatrix *solvmtx;
292  const SolverCblk *cblk;
293  pastix_int_t mschur;
294  void *bptr;
295  int rc;
296 
297  if ( pastix_data == NULL ) {
298  pastix_print_error( "pastixRhsSchurSet: wrong pastix_data parameter" );
300  }
301  if ( rhsB == NULL ) {
302  pastix_print_error( "pastixRhsSchurSet: wrong rhsB parameter" );
304  }
305  if ( B == NULL ) {
306  pastix_print_error( "pastixRhsSchurSet: wrong b parameter" );
308  }
309 
310  solvmtx = pastix_data->solvmatr;
311  cblk = solvmtx->cblktab + solvmtx->cblkschur;
312  mschur = solvmtx->nodenbr - cblk->fcolnum;
313 
314  if ( m != mschur ) {
315  pastix_print_error( "pastixRhsSchurSet: wrong m parameter expecting %ld but was %ld\n",
316  (long)mschur, (long)m );
318  }
319  if ( n != rhsB->n ) {
320  pastix_print_error( "pastixRhsSchurSet: wrong n parameter expecting %ld but was %ld\n",
321  (long)rhsB->n, (long)n );
323  }
324  if ( ldb < m ) {
325  pastix_print_error( "pastixRhsSchurSet: wrong ldb parameter\n" );
327  }
328 
329  bptr = ((char *)rhsB->b) + cblk->lcolidx * pastix_size_of( rhsB->flttype );
330 
331  switch( rhsB->flttype ) {
332  case SpmComplex64:
333  rc = LAPACKE_zlacpy_work( LAPACK_COL_MAJOR, 'A', mschur, n, B, ldb, (pastix_complex64_t *)bptr, rhsB->ld );
334  break;
335  case SpmComplex32:
336  rc = LAPACKE_clacpy_work( LAPACK_COL_MAJOR, 'A', mschur, n, B, ldb, (pastix_complex32_t *)bptr, rhsB->ld );
337  break;
338  case SpmDouble:
339  rc = LAPACKE_dlacpy_work( LAPACK_COL_MAJOR, 'A', mschur, n, B, ldb, (double *)bptr, rhsB->ld );
340  break;
341  case SpmFloat:
342  rc = LAPACKE_slacpy_work( LAPACK_COL_MAJOR, 'A', mschur, n, B, ldb, (float *)bptr, rhsB->ld );
343  break;
344  default:
345  pastix_print_error( "pastixRhsSchurSet: unknown flttype\n" );
347  }
348 
349  return rc;
350 }
351 
352 /**
353  * @}
354  */
typedef int pastix_int_t
Definition: datatypes.h:51
float _Complex pastix_complex32_t
Definition: datatypes.h:76
void coeftab_sgetschur(const SolverMatrix *solvmtx, float *S, pastix_int_t lds)
Extract the Schur complement.
Definition: coeftab_s.c:607
void coeftab_dgetschur(const SolverMatrix *solvmtx, double *S, pastix_int_t lds)
Extract the Schur complement.
Definition: coeftab_d.c:607
void coeftab_zgetschur(const SolverMatrix *solvmtx, pastix_complex64_t *S, pastix_int_t lds)
Extract the Schur complement.
Definition: coeftab_z.c:607
void coeftab_cgetschur(const SolverMatrix *solvmtx, pastix_complex32_t *S, pastix_int_t lds)
Extract the Schur complement.
Definition: coeftab_c.c:607
@ IPARM_FLOAT
Definition: api.h:149
@ PASTIX_SUCCESS
Definition: api.h:367
@ PASTIX_ERR_BADPARAMETER
Definition: api.h:374
void pastixSetSchurUnknownList(pastix_data_t *pastix_data, pastix_int_t n, const pastix_int_t *list)
Set the list of unknowns that belongs to the schur complement.
Definition: schur.c:50
int pastixGetSchur(const pastix_data_t *pastix_data, void *S, pastix_int_t lds)
Return the Schur complement.
Definition: schur.c:90
int pastixRhsSchurSet(const pastix_data_t *pastix_data, pastix_int_t m, pastix_int_t n, void *B, pastix_int_t ldb, pastix_rhs_t rhsB)
Set the vector in an RHS data structure.
Definition: schur.c:284
int pastixRhsSchurGet(const pastix_data_t *pastix_data, pastix_int_t m, pastix_int_t n, pastix_rhs_t rhsB, void *B, pastix_int_t ldb)
Get the vector in an RHS data structure.
Definition: schur.c:182
int inter_node_procnum
Definition: pastixdata.h:83
SolverMatrix * solvmatr
Definition: pastixdata.h:102
int inter_node_procnbr
Definition: pastixdata.h:82
pastix_int_t * iparm
Definition: pastixdata.h:69
pastix_int_t ld
Definition: pastixdata.h:155
pastix_coeftype_t flttype
Definition: pastixdata.h:152
pastix_int_t schur_n
Definition: pastixdata.h:93
pastix_int_t * schur_list
Definition: pastixdata.h:94
pastix_int_t steps
Definition: pastixdata.h:72
pastix_int_t n
Definition: pastixdata.h:154
Main PaStiX data structure.
Definition: pastixdata.h:67
Main PaStiX RHS structure.
Definition: pastixdata.h:150
pastix_int_t nodenbr
Definition: solver.h:205
pastix_int_t lcolidx
Definition: solver.h:165
SolverCblk *restrict cblktab
Definition: solver.h:222
pastix_int_t cblkschur
Definition: solver.h:217
pastix_int_t fcolnum
Definition: solver.h:161
Solver column block structure.
Definition: solver.h:156
Solver matrix structure.
Definition: solver.h:200