PaStiX Handbook  6.4.0
s_refine_gmres.c
Go to the documentation of this file.
1 /**
2  *
3  * @file s_refine_gmres.c
4  *
5  * PaStiX refinement functions implementations.
6  *
7  * @copyright 2015-2024 Bordeaux INP, CNRS (LaBRI UMR 5800), Inria,
8  * Univ. Bordeaux. All rights reserved.
9  *
10  * @version 6.4.0
11  * @author Mathieu Faverge
12  * @author Pierre Ramet
13  * @author Theophile Terraz
14  * @author Xavier Lacoste
15  * @author Gregoire Pichon
16  * @author Vincent Bridonneau
17  * @date 2024-07-05
18  * @generated from /builds/solverstack/pastix/refinement/z_refine_gmres.c, normal z -> s, Tue Oct 8 14:17:56 2024
19  *
20  **/
21 #include "common.h"
22 #include "cblas.h"
23 #include "bcsc/bcsc.h"
24 #include "s_refine_functions.h"
25 
26 /**
27  *******************************************************************************
28  *
29  * @ingroup pastix_refine
30  *
31  * s_gmres_smp - Function computing GMRES iterative refinement.
32  *
33  *******************************************************************************
34  *
35  * @param[in] pastix_data
36  * The PaStiX data structure that describes the solver instance.
37  *
38  * @param[out] xp
39  * The solution vector.
40  *
41  * @param[in] bp
42  * The right hand side member (only one).
43  *******************************************************************************
44  *
45  * @return Number of iterations
46  *
47  *******************************************************************************/
49 s_gmres_smp( pastix_data_t *pastix_data,
50  pastix_rhs_t xp,
51  pastix_rhs_t bp )
52 {
53  struct s_solver solver;
54  Clock refine_clk;
55  float *x = (float*)(xp->b);
56  float *b = (float*)(bp->b);
57  float *gmHi, *gmH;
58  float *gmVi, *gmV;
59  float *gmWi, *gmW;
60  float *gmcos, *gmsin;
61  float *gmG;
62  float *sgmWi = NULL;
63 #if defined(PASTIX_DEBUG_GMRES)
64  float *dbg_x, *dbg_r, *dbg_G;
65 #endif
66  float tmp;
67  pastix_fixdbl_t t0, t3;
68  float eps, resid, resid_b;
69  float norm, normb, normx;
70  pastix_int_t n, im, im1, itermax;
71  pastix_int_t i, j, ldw, iters;
72  int outflag, inflag;
73  int savemem = 0;
74  int precond = 1;
75 
76  memset( &solver, 0, sizeof(struct s_solver) );
77  s_refine_init( &solver, pastix_data );
78 
79  /* if ( pastix_data->bcsc->mtxtype == PastixSymmetric ) { */
80  /* /\* Check if we need dot for non symmetric matrices (CEA patch) *\/ */
81  /* solver.dot = &bvec_sdot_seq; */
82  /* } */
83 
84  /* Get the parameters */
85  n = pastix_data->bcsc->n;
86  im = pastix_data->iparm[IPARM_GMRES_IM];
87  im1 = im + 1;
88  itermax = pastix_data->iparm[IPARM_ITERMAX];
89  eps = pastix_data->dparm[DPARM_EPSILON_REFINEMENT];
90  ldw = n;
91 
92  if ( !(pastix_data->steps & STEP_NUMFACT) ) {
93  precond = 0;
94  }
95 
96  if ((!precond) || savemem ) {
97  ldw = 0;
98  }
99 
100  gmcos = (float *)solver.malloc(im * sizeof(float));
101  gmsin = (float *)solver.malloc(im * sizeof(float));
102  gmG = (float *)solver.malloc(im1 * sizeof(float));
103 
104  /**
105  * H stores the h_{i,j} elements of the upper hessenberg matrix H (See Alg. 9.5 p 270)
106  * V stores the v_{i} vectors
107  * W stores the M^{-1} v_{i} vectors to avoid the application of the
108  * preconditioner on the output result (See line 11 of Alg 9.5)
109  *
110  * If no preconditioner is applied, or the user wants to save memory, W
111  * stores only temporarily one vector for the Ax product (ldw is set to 0 to
112  * reuse the same vector at each iteration)
113  */
114  gmH = (float *)solver.malloc(im * im1 * sizeof(float));
115  gmV = (float *)solver.malloc(n * im1 * sizeof(float));
116  if (precond && (!savemem) ) {
117  gmW = (float *)solver.malloc(n * im * sizeof(float));
118  }
119  else {
120  gmW = (float *)solver.malloc(n * sizeof(float));
121  }
122  memset( gmH, 0, im * im1 * sizeof(float) );
123 
124 #if defined(PASTIX_DEBUG_GMRES)
125  dbg_x = (float *)solver.malloc(n * sizeof(float));
126  dbg_r = (float *)solver.malloc(n * sizeof(float));
127  dbg_G = (float *)solver.malloc(im1 * sizeof(float));
128  solver.copy( pastix_data, n, x, dbg_x );
129 #endif
130 
131  normb = solver.norm( pastix_data, n, b );
132  if ( normb == 0. ) {
133  normb = 1;
134  }
135  normx = solver.norm( pastix_data, n, x );
136 
137  /* Allocating a vector at half-precision, NULL pointer otherwise */
138  if ( pastix_data->iparm[IPARM_MIXED] )
139  {
140  sgmWi = solver.malloc( n * sizeof(float) );
141  }
142 
143  clockInit(refine_clk);
144  clockStart(refine_clk);
145 
146  /**
147  * Algorithm from Iterative Methods for Sparse Linear systems, Y. Saad, Second Ed. p267-273
148  *
149  * The version implemented is the Right preconditioned algorithm.
150  */
151  outflag = 1;
152  iters = 0;
153  while (outflag)
154  {
155  /* Initialize v_{0} and w_{0} */
156  gmVi = gmV;
157 
158  /* Compute r0 = b - A * x */
159  solver.copy( pastix_data, n, b, gmVi );
160  if ( normx > 0. ) {
161  solver.spmv( pastix_data, PastixNoTrans, -1., x, 1., gmVi );
162  }
163 
164  /* Compute resid = ||r0||_f */
165  resid = solver.norm( pastix_data, n, gmVi );
166  resid_b = resid / normb;
167 
168  /* If residual is small enough, exit */
169  if ( resid_b <= eps )
170  {
171  outflag = 0;
172  break;
173  }
174 
175  /* Compute v0 = r0 / resid */
176  tmp = (float)( 1.0 / resid );
177  solver.scal( pastix_data, n, tmp, gmVi );
178 
179  gmG[0] = (float)resid;
180  inflag = 1;
181  i = -1;
182  gmHi = gmH - im1;
183  gmWi = gmW - ldw;
184 
185  while( inflag )
186  {
187  clockStop( refine_clk );
188  t0 = clockGet();
189 
190  i++;
191 
192  /* Set H and W pointers to the beginning of columns i */
193  gmHi = gmHi + im1;
194  gmWi = gmWi + ldw;
195 
196  /* Backup v_{i} into w_{i} for the end */
197  solver.copy( pastix_data, n, gmVi, gmWi );
198 
199  /* Compute w_{i} = M^{-1} v_{i} */
200  if ( precond ) {
201  solver.spsv( pastix_data, gmWi, sgmWi );
202  }
203 
204  /* v_{i+1} = A (M^{-1} v_{i}) = A w_{i} */
205  gmVi += n;
206  solver.spmv( pastix_data, PastixNoTrans, 1.0, gmWi, 0., gmVi );
207 
208  /* Classical Gram-Schmidt */
209  for (j=0; j<=i; j++)
210  {
211  /* Compute h_{j,i} = < v_{i+1}, v_{j} > */
212  gmHi[j] = solver.dot( pastix_data, n, gmVi, gmV + j * n );
213 
214  /* Compute v_{i+1} = v_{i+1} - h_{j,i} v_{j} */
215  solver.axpy( pastix_data, n, -1. * gmHi[j], gmV + j * n, gmVi );
216  }
217 
218  /* Compute || v_{i+1} ||_f */
219  norm = solver.norm( pastix_data, n, gmVi );
220  gmHi[i+1] = norm;
221 
222  /* Compute v_{i+1} = v_{i+1} / h_{i+1,i} iff h_{i+1,i} is not too small */
223  if ( norm > 1e-50 )
224  {
225  tmp = (float)(1.0 / norm);
226  solver.scal( pastix_data, n, tmp, gmVi );
227  }
228 
229  /* Apply the previous Givens rotation to the new column (should call LAPACKE_srot_work())*/
230  for (j=0; j<i;j++)
231  {
232  /*
233  * h_{j, i} = cos_j * h_{j, i} + sin_{j} * h_{j+1, i}
234  * h_{j+1,i} = cos_j * h_{j+1,i} - (sin_{j}) * h_{j, i}
235  */
236  tmp = gmHi[j];
237  gmHi[j] = gmcos[j] * tmp + gmsin[j] * gmHi[j+1];
238  gmHi[j+1] = gmcos[j] * gmHi[j+1] - (gmsin[j]) * tmp;
239  }
240 
241  /*
242  * Compute the new Givens rotation (srotg)
243  *
244  * t = sqrtf( h_{i,i}^2 + h_{i+1,i}^2 )
245  * cos = h_{i,i} / t
246  * sin = h_{i+1,i} / t
247  */
248  {
249  tmp = sqrtf( gmHi[i] * gmHi[i] +
250  gmHi[i+1] * gmHi[i+1] );
251 
252  if ( fabsf(tmp) <= eps ) {
253  tmp = (float)eps;
254  }
255  gmcos[i] = gmHi[i] / tmp;
256  gmsin[i] = gmHi[i+1] / tmp;
257  }
258 
259  /* Update the residuals (See p. 168, eq 6.35) */
260  gmG[i+1] = -gmsin[i] * gmG[i];
261  gmG[i] = gmcos[i] * gmG[i];
262 
263  /* Apply the last Givens rotation */
264  gmHi[i] = gmcos[i] * gmHi[i] + gmsin[i] * gmHi[i+1];
265 
266  /* (See p. 169, eq 6.42) */
267  resid = fabsf( gmG[i+1] );
268 
269  resid_b = resid / normb;
270  iters++;
271  if ( (i+1 >= im) ||
272  (resid_b <= eps) ||
273  (iters >= itermax) )
274  {
275  inflag = 0;
276  }
277 
278  clockStop((refine_clk));
279  t3 = clockGet();
280  if ( ( pastix_data->iparm[IPARM_VERBOSE] > PastixVerboseNot ) &&
281  ( pastix_data->procnum == 0 ) ) {
282  solver.output_oneiter( t0, t3, resid_b, iters );
283 
284 #if defined(PASTIX_DEBUG_GMRES)
285  {
286  float normr2;
287 
288  /* Compute y_m = H_m^{-1} g_m (See p. 169) */
289  memcpy( dbg_G, gmG, im1 * sizeof(float) );
290  cblas_strsv( CblasColMajor, CblasUpper, CblasNoTrans, CblasNonUnit,
291  i+1, gmH, im1, dbg_G, 1 );
292 
293  solver.copy( pastix_data, n, b, dbg_r );
294  solver.copy( pastix_data, n, x, dbg_x );
295 
296  /* Accumulate the current v_m */
297  solver.gemv( pastix_data, n, i+1, 1.0, (precond ? gmW : gmV), n, dbg_G, 1.0, dbg_x );
298 
299  /* Compute b - Ax */
300  solver.spmv( pastix_data, PastixNoTrans, -1., dbg_x, 1., dbg_r );
301 
302  normr2 = solver.norm( pastix_data, n, dbg_r );
303  fprintf(stdout, OUT_ITERREFINE_ERR, normr2 / normb );
304  }
305 #endif
306  }
307  }
308 
309  /* Compute y_m = H_m^{-1} g_m (See p. 169) */
310  cblas_strsv( CblasColMajor, CblasUpper, CblasNoTrans, CblasNonUnit,
311  i+1, gmH, im1, gmG, 1 );
312 
313  /**
314  * Compute x_m = x_0 + M^{-1} V_m y_m
315  * = x_0 + W_m y_m
316  */
317  if (precond && savemem) {
318  /**
319  * Since we saved memory, we do not have (M^{-1} V_m) stored,
320  * thus we compute:
321  * w = V_m y_m
322  * w = M^{-1} (V_m y_m)
323  * x = x0 + (M^{-1} (V_m y_m))
324  */
325  solver.gemv( pastix_data, n, i+1, 1.0, gmV, n, gmG, 0., gmW );
326  solver.spsv( pastix_data, gmW, sgmWi );
327  solver.axpy( pastix_data, n, 1., gmW, x );
328  }
329  else {
330  /**
331  * Since we did not saved memory, we do have (M^{-1} V_m) stored in
332  * W_m if precond is true, thus we compute:
333  * x = x0 + W_m y_m, if precond
334  * x = x0 + V_m y_m, if not precond
335  */
336  gmWi = precond ? gmW : gmV;
337  solver.gemv( pastix_data, n, i+1, 1.0, gmWi, n, gmG, 1.0, x );
338  }
339 
340  /**
341  * Exit only if maximum number of iteration is reached.
342  * Exit on residual if checked at the beginning of the outer loop to be
343  * sure that the final residual of Ax-b is equal to the estimator
344  * computed within the inner loop.
345  */
346  if (iters >= itermax)
347  {
348  outflag = 0;
349  }
350  }
351 
352  clockStop( refine_clk );
353  t3 = clockGet();
354 
355  solver.output_final( pastix_data, resid_b, iters, t3, x, x );
356 
357  solver.free(gmcos);
358  solver.free(gmsin);
359  solver.free(gmG);
360  solver.free(gmH);
361  solver.free(gmV);
362  solver.free(gmW);
363  solver.free(sgmWi);
364 #if defined(PASTIX_DEBUG_GMRES)
365  solver.free(dbg_x);
366  solver.free(dbg_r);
367  solver.free(dbg_G);
368 #endif
369 
370  return iters;
371 }
BEGIN_C_DECLS typedef int pastix_int_t
Definition: datatypes.h:51
double pastix_fixdbl_t
Definition: datatypes.h:65
@ DPARM_EPSILON_REFINEMENT
Definition: api.h:161
@ IPARM_MIXED
Definition: api.h:139
@ IPARM_GMRES_IM
Definition: api.h:114
@ IPARM_ITERMAX
Definition: api.h:113
@ IPARM_VERBOSE
Definition: api.h:36
@ PastixNoTrans
Definition: api.h:445
@ PastixVerboseNot
Definition: api.h:220
void s_refine_init(struct s_solver *, pastix_data_t *)
Initialize the function pointers that define the basic operations.
pastix_int_t s_gmres_smp(pastix_data_t *pastix_data, pastix_rhs_t xp, pastix_rhs_t bp)
pastix_int_t * iparm
Definition: pastixdata.h:70
double * dparm
Definition: pastixdata.h:71
pastix_bcsc_t * bcsc
Definition: pastixdata.h:102
pastix_int_t steps
Definition: pastixdata.h:73
Main PaStiX data structure.
Definition: pastixdata.h:68
Main PaStiX RHS structure.
Definition: pastixdata.h:155