PaStiX Handbook  6.3.2
schur.c
Go to the documentation of this file.
1 /**
2  * @file example/schur.c
3  *
4  * @brief Schur usage example.
5  *
6  * @copyright 2015-2023 Bordeaux INP, CNRS (LaBRI UMR 5800), Inria,
7  * Univ. Bordeaux. All rights reserved.
8  *
9  * @version 6.3.2
10  * @author Pierre Ramet
11  * @author Mathieu Faverge
12  * @author Matias Hastaran
13  * @author Tony Delarue
14  * @author Alycia Lisito
15  * @date 2023-07-21
16  *
17  * @ingroup pastix_examples
18  * @code
19  *
20  */
21 #include <pastix.h>
22 #include <spm.h>
23 #include <lapacke.h>
24 
25 void
26 schurFactorize( pastix_coeftype_t flttype,
27  pastix_factotype_t factotype,
28  pastix_int_t N,
29  void *S,
30  pastix_int_t lds,
31  int **ipiv )
32 {
33  int info = 0;
34 
35  assert( ipiv != NULL );
36  if ( factotype == PastixFactGETRF ) {
37  *ipiv = malloc( N * sizeof(int) );
38  }
39 
40  switch (flttype) {
41  case PastixFloat:
42  switch (factotype) {
43  case PastixFactPOTRF:
44  info = LAPACKE_spotrf_work( LAPACK_COL_MAJOR, 'L', N, S, lds );
45  break;
46  case PastixFactGETRF:
47  info = LAPACKE_sgetrf_work( LAPACK_COL_MAJOR, N, N, S, lds, *ipiv );
48  break;
49  default:
50  fprintf(stderr, "Factorization type not handled by Schur example\n");
51  }
52  break;
53  case PastixComplex32:
54  switch (factotype) {
55  case PastixFactPOTRF:
56  info = LAPACKE_cpotrf_work( LAPACK_COL_MAJOR, 'L', N, S, lds );
57  break;
58  case PastixFactGETRF:
59  info = LAPACKE_cgetrf_work( LAPACK_COL_MAJOR, N, N, S, lds, *ipiv );
60  break;
61  default:
62  fprintf(stderr, "Factorization type not handled by Schur example\n");
63  }
64  break;
65  case PastixComplex64:
66  switch (factotype) {
67  case PastixFactPOTRF:
68  info = LAPACKE_zpotrf_work( LAPACK_COL_MAJOR, 'L', N, S, lds );
69  break;
70  case PastixFactGETRF:
71  info = LAPACKE_zgetrf_work( LAPACK_COL_MAJOR, N, N, S, lds, *ipiv );
72  break;
73  default:
74  fprintf(stderr, "Factorization type not handled by Schur example\n");
75  }
76  break;
77  case PastixDouble:
78  switch (factotype) {
79  case PastixFactPOTRF:
80  info = LAPACKE_dpotrf_work( LAPACK_COL_MAJOR, 'L', N, S, lds );
81  break;
82  case PastixFactGETRF:
83  info = LAPACKE_dgetrf_work( LAPACK_COL_MAJOR, N, N, S, lds, *ipiv );
84  break;
85  default:
86  fprintf(stderr, "Factorization type not handled by Schur example\n");
87  }
88  break;
89  default:
90  fprintf(stderr, "Incorrect arithmetic type\n");
91  }
92  if (info != 0) {
93  fprintf(stderr, "Error in schurFactorize with info =%d\n", info );
94  }
95  return;
96 }
97 
98 void
99 schurSolve( pastix_coeftype_t flttype,
100  pastix_factotype_t factotype,
101  pastix_int_t Nschur,
102  pastix_int_t NRHS,
103  void *S,
104  pastix_int_t lds,
105  void *bptr,
106  pastix_int_t ldb,
107  int **ipiv )
108 {
109  int info = 0;
110 
111  assert(ipiv != NULL);
112 
113  switch (flttype) {
114  case PastixFloat:
115  {
116  float *b = (float *)bptr;
117 
118  switch (factotype) {
119  case PastixFactPOTRF:
120  info = LAPACKE_spotrs_work( LAPACK_COL_MAJOR, 'L', Nschur, NRHS, S, lds, b, ldb );
121  break;
122  case PastixFactGETRF:
123  info = LAPACKE_sgetrs_work( LAPACK_COL_MAJOR, 'N', Nschur, NRHS, S, lds, *ipiv, b, ldb );
124  break;
125  default:
126  fprintf(stderr, "Factorization type not handled by Schur example\n");
127  }
128  }
129  break;
130  case PastixComplex32:
131  {
133 
134  switch (factotype) {
135  case PastixFactPOTRF:
136  info = LAPACKE_cpotrs_work( LAPACK_COL_MAJOR, 'L', Nschur, NRHS, S, lds, b, ldb );
137  break;
138  case PastixFactGETRF:
139  info = LAPACKE_cgetrs_work( LAPACK_COL_MAJOR, 'N', Nschur, NRHS, S, lds, *ipiv, b, ldb );
140  break;
141  default:
142  fprintf(stderr, "Factorization type not handled by Schur example\n");
143  }
144  }
145  break;
146  case PastixComplex64:
147  {
148  pastix_complex64_t *b = (pastix_complex64_t *)bptr;
149 
150  switch (factotype) {
151  case PastixFactPOTRF:
152  info = LAPACKE_zpotrs_work( LAPACK_COL_MAJOR, 'L', Nschur, NRHS, S, lds, b, ldb );
153  break;
154  case PastixFactGETRF:
155  info = LAPACKE_zgetrs_work( LAPACK_COL_MAJOR, 'N', Nschur, NRHS, S, lds, *ipiv, b, ldb );
156  break;
157  default:
158  fprintf(stderr, "Factorization type not handled by Schur example\n");
159  }
160  }
161  break;
162  case PastixDouble:
163  {
164  double *b = (double *)bptr;
165 
166  switch (factotype) {
167  case PastixFactPOTRF:
168  info = LAPACKE_dpotrs_work( LAPACK_COL_MAJOR, 'L', Nschur, NRHS, S, lds, b, ldb );
169  break;
170  case PastixFactGETRF:
171  info = LAPACKE_dgetrs_work( LAPACK_COL_MAJOR, 'N', Nschur, NRHS, S, lds, *ipiv, b, ldb );
172  break;
173  default:
174  fprintf(stderr, "Factorization type not handled by Schur example\n");
175  }
176  }
177  break;
178  default:
179  fprintf(stderr, "Incorrect arithmetic type\n");
180  }
181 
182  if (*ipiv != NULL) {
183  free( *ipiv );
184  *ipiv = NULL;
185  }
186 
187  if (info != 0) {
188  fprintf(stderr, "Error in schurSolve with info =%d\n", info );
189  }
190 
191  return;
192 }
193 
194 int main (int argc, char **argv)
195 {
196  pastix_data_t *pastix_data = NULL; /*< Pointer to the storage structure required by pastix */
197  pastix_int_t iparm[IPARM_SIZE]; /*< Integer in/out parameters for pastix */
198  double dparm[DPARM_SIZE]; /*< Floating in/out parameters for pastix */
199  spm_driver_t driver;
200  char *filename = NULL;
201  spmatrix_t *spm, spm2;
202  void *x, *b, *S, *x0 = NULL;
203  size_t size;
204  int scatter = 0;
205  int check = 1;
206  int nrhs = 1;
207  int rc = 0;
208  pastix_int_t nschur, lds, ldb;
209  int *ipiv = NULL;
211  pastix_rhs_t Xp;
212 
213  /**
214  * Initialize parameters to default values
215  */
216  pastixInitParam( iparm, dparm );
217 
218  /**
219  * Get options from command line
220  */
221  pastixGetOptions( argc, argv,
222  iparm, dparm,
223  &check, &scatter, &driver, &filename );
224 
225 
226  if ( (iparm[IPARM_FACTORIZATION] == PastixFactLDLT) ||
227  (iparm[IPARM_FACTORIZATION] == PastixFactLDLH) )
228  {
229  fprintf(stderr, "This types of factorization (LDL^t and LDL^h) are not supported by this example.\n");
230  return EXIT_FAILURE;
231  }
232 
233  /**
234  * Startup PaStiX
235  */
236  pastixInit( &pastix_data, MPI_COMM_WORLD, iparm, dparm );
237 
238  /**
239  * Read the sparse matrix with the driver
240  */
241  spm = malloc( sizeof( spmatrix_t ) );
242  if ( scatter ) {
243  rc = spmReadDriverDist( driver, filename, spm, MPI_COMM_WORLD );
244  }
245  else {
246  rc = spmReadDriver( driver, filename, spm );
247  }
248  free( filename );
249  if ( rc != SPM_SUCCESS ) {
250  pastixFinalize( &pastix_data );
251  return rc;
252  }
253 
254  spmPrintInfo( spm, stdout );
255 
256  rc = spmCheckAndCorrect( spm, &spm2 );
257  if ( rc != 0 ) {
258  spmExit( spm );
259  *spm = spm2;
260  }
261 
262  /**
263  * Generate a Fake values array if needed for the numerical part
264  */
265  if ( spm->flttype == SpmPattern ) {
266  spmGenFakeValues( spm );
267  }
268 
269  /**
270  * Initialize the schur list with the first third of the unknowns
271  */
272  {
273  nschur = spm->gN / 3;
274  /* Set to a maximum to avoid memory problem with the test */
275  nschur = (nschur > 5000) ? 5000 : nschur;
276 
277  if ( nschur > 0 ) {
278  pastix_int_t i;
279  pastix_int_t baseval = spmFindBase(spm);
280  pastix_int_t *list = (pastix_int_t*)malloc(nschur * sizeof(pastix_int_t));
281 
282  for (i=0; i<nschur; i++) {
283  list[i] = i+baseval;
284  }
285  pastixSetSchurUnknownList( pastix_data, nschur, list );
286  free( list );
287  }
288  iparm[IPARM_SCHUR_SOLV_MODE] = PastixSolvModeInterface;
289  }
290 
291  /**
292  * Perform ordering, symbolic factorization, and analyze steps
293  */
294  pastix_task_analyze( pastix_data, spm );
295 
296  /**
297  * Normalize A matrix (optional, but recommended for low-rank functionality)
298  */
299  double normA = spmNorm( SpmFrobeniusNorm, spm );
300  spmScal( 1./normA, spm );
301 
302  /**
303  * Perform the numerical factorization
304  */
305  pastix_task_numfact( pastix_data, spm );
306 
307  /**
308  * Get the Schur complement back
309  */
310  lds = nschur;
311  S = malloc( pastix_size_of( spm->flttype ) * nschur * lds );
312 
313  rc = pastixGetSchur( pastix_data, S, lds );
314 
315  if( rc == -1 ){
316  spmExit( spm );
317  free( spm );
318  free( S );
319  pastixFinalize( &pastix_data );
320  exit(0);
321  }
322 
323  /**
324  * Factorize the Schur complement
325  */
326  schurFactorize( spm->flttype, iparm[IPARM_FACTORIZATION],
327  nschur, S, lds, &ipiv );
328 
329  /**
330  * Generates the b and x vector such that A * x = b
331  * Compute the norms of the initial vectors if checking purpose.
332  */
333  size = pastix_size_of( spm->flttype ) * spm->nexp * nrhs;
334  x = malloc( size );
335  b = malloc( size );
336  ldb = spm->nexp;
337 
338  if ( check )
339  {
340  if ( check > 1 ) {
341  x0 = malloc( size );
342  }
343  spmGenRHS( SpmRhsRndX, nrhs, spm, x0, spm->nexp, b, spm->nexp );
344  memcpy( x, b, size );
345  }
346  else {
347  spmGenRHS( SpmRhsRndB, nrhs, spm, NULL, spm->nexp, x, spm->nexp );
348 
349  /* Apply also normalization to b vectors */
350  spmScalMat( 1./normA, spm, nrhs, b, spm->nexp );
351  }
352 
353  /**
354  * Solve the linear system Ax = (P^tLUP)x = b
355  */
356  /* 1- Apply P to b */
357  pastixRhsInit( &Xp );
359  spm->nexp, nrhs, x, ldb, Xp );
360 
361  /* 2- Forward solve on the non Schur complement part of the system */
362  if ( iparm[IPARM_FACTORIZATION] == PastixFactPOTRF ) {
363  diag = PastixNonUnit;
364  }
365  else if( iparm[IPARM_FACTORIZATION] == PastixFactGETRF ) {
366  diag = PastixUnit;
367  }
368 
369  pastix_subtask_trsm( pastix_data, PastixLeft, PastixLower, PastixNoTrans, diag, Xp );
370 
371  /* 3- Solve the Schur complement part */
372  {
373  void *schur_x;
374 
375  size = pastix_size_of( spm->flttype ) * nschur * nrhs;
376  schur_x = malloc( size );
377 
378  pastixRhsSchurGet( pastix_data, nschur, nrhs, Xp, schur_x, nschur );
379 
380  schurSolve( spm->flttype, iparm[IPARM_FACTORIZATION],
381  nschur, nrhs, S, lds, schur_x, nschur, &ipiv );
382 
383  pastixRhsSchurSet( pastix_data, nschur, nrhs, schur_x, nschur, Xp );
384  free( schur_x );
385  }
386 
387  /* 4- Backward solve on the non Schur complement part of the system */
388  if ( iparm[IPARM_FACTORIZATION] == PastixFactPOTRF ) {
389  pastix_subtask_trsm( pastix_data,
391  Xp );
392  }
393  else if( iparm[IPARM_FACTORIZATION] == PastixFactGETRF ) {
394  pastix_subtask_trsm( pastix_data,
396  Xp );
397  }
398 
399  /* 5- Apply P^t to x */
401  spm->nexp, nrhs, x, ldb, Xp );
402  pastixRhsFinalize( Xp );
403 
404  if ( check )
405  {
406  rc = spmCheckAxb( dparm[DPARM_EPSILON_REFINEMENT], nrhs, spm, x0, spm->nexp, b, spm->nexp, x, spm->nexp );
407 
408  if ( x0 ) {
409  free( x0 );
410  }
411  }
412 
413  spmExit( spm );
414  free( spm );
415  free( S );
416  free( x );
417  free( b );
418  pastixFinalize( &pastix_data );
419 
420  return rc;
421 }
422 
423 /**
424  * @endcode
425  */
BEGIN_C_DECLS typedef int pastix_int_t
Definition: datatypes.h:51
float _Complex pastix_complex32_t
Definition: datatypes.h:76
int pastixRhsInit(pastix_rhs_t *rhs)
Initialize an RHS data structure.
Definition: pastix_rhs.c:41
int pastixRhsFinalize(pastix_rhs_t rhs)
Cleanup an RHS data structure.
Definition: pastix_rhs.c:86
spm_coeftype_t pastix_coeftype_t
Arithmetic types.
Definition: api.h:294
void pastixFinalize(pastix_data_t **pastix_data)
Finalize the solver instance.
Definition: api.c:919
enum pastix_diag_e pastix_diag_t
Diagonal.
void pastixInitParam(pastix_int_t *iparm, double *dparm)
Initialize the iparm and dparm arrays to their default values.
Definition: api.c:411
void pastixInit(pastix_data_t **pastix_data, PASTIX_Comm pastix_comm, pastix_int_t *iparm, double *dparm)
Initialize the solver instance.
Definition: api.c:896
enum pastix_factotype_e pastix_factotype_t
Factorization algorithms available for IPARM_FACTORIZATION parameter.
@ PastixDirForward
Definition: api.h:513
@ PastixDirBackward
Definition: api.h:514
@ PastixFactLDLH
Definition: api.h:319
@ PastixFactLDLT
Definition: api.h:316
@ PastixFactPOTRF
Definition: api.h:310
@ PastixFactGETRF
Definition: api.h:312
@ DPARM_EPSILON_REFINEMENT
Definition: api.h:161
@ IPARM_FACTORIZATION
Definition: api.h:99
@ IPARM_SCHUR_SOLV_MODE
Definition: api.h:107
@ PastixUpper
Definition: api.h:466
@ PastixLower
Definition: api.h:467
@ PastixLeft
Definition: api.h:495
@ PastixUnit
Definition: api.h:488
@ PastixNonUnit
Definition: api.h:487
@ PastixConjTrans
Definition: api.h:447
@ PastixNoTrans
Definition: api.h:445
void pastixGetOptions(int argc, char **argv, pastix_int_t *iparm, double *dparm, int *check, int *scatter, spm_driver_t *driver, char **filename)
PaStiX helper function to read command line options in examples.
Definition: get_options.c:149
void pastixSetSchurUnknownList(pastix_data_t *pastix_data, pastix_int_t n, const pastix_int_t *list)
Set the list of unknowns that belongs to the schur complement.
Definition: schur.c:50
int pastixGetSchur(const pastix_data_t *pastix_data, void *S, pastix_int_t lds)
Return the Schur complement.
Definition: schur.c:90
int pastixRhsSchurSet(const pastix_data_t *pastix_data, pastix_int_t m, pastix_int_t n, void *B, pastix_int_t ldb, pastix_rhs_t rhsB)
Set the vector in an RHS data structure.
Definition: schur.c:284
int pastix_subtask_applyorder(pastix_data_t *pastix_data, pastix_dir_t dir, pastix_int_t m, pastix_int_t n, void *B, pastix_int_t ldb, pastix_rhs_t Bp)
Apply a permutation to the right-hand-side vectors, before (forward) or after (backward) the solve step.
int pastix_subtask_trsm(pastix_data_t *pastix_data, pastix_side_t side, pastix_uplo_t uplo, pastix_trans_t trans, pastix_diag_t diag, pastix_rhs_t b)
Apply a triangular solve to the right-hand-side vectors.
int pastixRhsSchurGet(const pastix_data_t *pastix_data, pastix_int_t m, pastix_int_t n, pastix_rhs_t rhsB, void *B, pastix_int_t ldb)
Get the vector in an RHS data structure.
Definition: schur.c:182
int pastix_task_analyze(pastix_data_t *pastix_data, const spmatrix_t *spm)
Perform all the preprocessing steps: ordering, symbolic factorization, reordering,...
int pastix_task_numfact(pastix_data_t *pastix_data, spmatrix_t *spm)
Perform all the numerical factorization steps: fill the internal block CSC and the solver matrix stru...
Main PaStiX data structure.
Definition: pastixdata.h:67
Main PaStiX RHS structure.
Definition: pastixdata.h:150