PaStiX Handbook 6.4.0
Loading...
Searching...
No Matches
s_refine_gmres.c
Go to the documentation of this file.
1/**
2 *
3 * @file s_refine_gmres.c
4 *
5 * PaStiX refinement functions implementations.
6 *
7 * @copyright 2015-2024 Bordeaux INP, CNRS (LaBRI UMR 5800), Inria,
8 * Univ. Bordeaux. All rights reserved.
9 *
10 * @version 6.4.0
11 * @author Mathieu Faverge
12 * @author Pierre Ramet
13 * @author Theophile Terraz
14 * @author Xavier Lacoste
15 * @author Gregoire Pichon
16 * @author Vincent Bridonneau
17 * @date 2024-07-05
18 * @generated from /builds/2mk6rsew/0/solverstack/pastix/refinement/z_refine_gmres.c, normal z -> s, Tue Feb 25 14:36:04 2025
19 *
20 **/
21#include "common.h"
22#include "cblas.h"
23#include "bcsc/bcsc.h"
24#include "s_refine_functions.h"
25
26/**
27 *******************************************************************************
28 *
29 * @ingroup pastix_refine
30 *
31 * s_gmres_smp - Function computing GMRES iterative refinement.
32 *
33 *******************************************************************************
34 *
35 * @param[in] pastix_data
36 * The PaStiX data structure that describes the solver instance.
37 *
38 * @param[out] xp
39 * The solution vector.
40 *
41 * @param[in] bp
42 * The right hand side member (only one).
43 *******************************************************************************
44 *
45 * @return Number of iterations
46 *
47 *******************************************************************************/
50 pastix_rhs_t xp,
51 pastix_rhs_t bp )
52{
53 struct s_solver solver;
54 Clock refine_clk;
55 float *x = (float*)(xp->b);
56 float *b = (float*)(bp->b);
57 float *gmHi, *gmH;
58 float *gmVi, *gmV;
59 float *gmWi, *gmW;
60 float *gmcos, *gmsin;
61 float *gmG;
62 float *sgmWi = NULL;
63#if defined(PASTIX_DEBUG_GMRES)
64 float *dbg_x, *dbg_r, *dbg_G;
65#endif
66 float tmp;
67 pastix_fixdbl_t t0, t3;
68 float eps, resid, resid_b;
69 float norm, normb, normx;
70 pastix_int_t n, im, im1, itermax;
71 pastix_int_t i, j, ldw, iters;
72 int outflag, inflag;
73 int savemem = 0;
74 int precond = 1;
75
76 memset( &solver, 0, sizeof(struct s_solver) );
77 s_refine_init( &solver, pastix_data );
78
79 /* if ( pastix_data->bcsc->mtxtype == PastixSymmetric ) { */
80 /* /\* Check if we need dot for non symmetric matrices (CEA patch) *\/ */
81 /* solver.dot = &bvec_sdot_seq; */
82 /* } */
83
84 /* Get the parameters */
85 n = pastix_data->bcsc->n;
86 im = pastix_data->iparm[IPARM_GMRES_IM];
87 im1 = im + 1;
88 itermax = pastix_data->iparm[IPARM_ITERMAX];
89 eps = pastix_data->dparm[DPARM_EPSILON_REFINEMENT];
90 ldw = n;
91
92 if ( !(pastix_data->steps & STEP_NUMFACT) ) {
93 precond = 0;
94 }
95
96 if ((!precond) || savemem ) {
97 ldw = 0;
98 }
99
100 gmcos = (float *)solver.malloc(im * sizeof(float));
101 gmsin = (float *)solver.malloc(im * sizeof(float));
102 gmG = (float *)solver.malloc(im1 * sizeof(float));
103
104 /**
105 * H stores the h_{i,j} elements of the upper hessenberg matrix H (See Alg. 9.5 p 270)
106 * V stores the v_{i} vectors
107 * W stores the M^{-1} v_{i} vectors to avoid the application of the
108 * preconditioner on the output result (See line 11 of Alg 9.5)
109 *
110 * If no preconditioner is applied, or the user wants to save memory, W
111 * stores only temporarily one vector for the Ax product (ldw is set to 0 to
112 * reuse the same vector at each iteration)
113 */
114 gmH = (float *)solver.malloc(im * im1 * sizeof(float));
115 gmV = (float *)solver.malloc(n * im1 * sizeof(float));
116 if (precond && (!savemem) ) {
117 gmW = (float *)solver.malloc(n * im * sizeof(float));
118 }
119 else {
120 gmW = (float *)solver.malloc(n * sizeof(float));
121 }
122 memset( gmH, 0, im * im1 * sizeof(float) );
123
124#if defined(PASTIX_DEBUG_GMRES)
125 dbg_x = (float *)solver.malloc(n * sizeof(float));
126 dbg_r = (float *)solver.malloc(n * sizeof(float));
127 dbg_G = (float *)solver.malloc(im1 * sizeof(float));
128 solver.copy( pastix_data, n, x, dbg_x );
129#endif
130
131 normb = solver.norm( pastix_data, n, b );
132 if ( normb == 0. ) {
133 normb = 1;
134 }
135 normx = solver.norm( pastix_data, n, x );
136
137 /* Allocating a vector at half-precision, NULL pointer otherwise */
138 if ( pastix_data->iparm[IPARM_MIXED] )
139 {
140 sgmWi = solver.malloc( n * sizeof(float) );
141 }
142
143 clockInit(refine_clk);
144 clockStart(refine_clk);
145
146 /**
147 * Algorithm from Iterative Methods for Sparse Linear systems, Y. Saad, Second Ed. p267-273
148 *
149 * The version implemented is the Right preconditioned algorithm.
150 */
151 outflag = 1;
152 iters = 0;
153 while (outflag)
154 {
155 /* Initialize v_{0} and w_{0} */
156 gmVi = gmV;
157
158 /* Compute r0 = b - A * x */
159 solver.copy( pastix_data, n, b, gmVi );
160 if ( normx > 0. ) {
161 solver.spmv( pastix_data, PastixNoTrans, -1., x, 1., gmVi );
162 }
163
164 /* Compute resid = ||r0||_f */
165 resid = solver.norm( pastix_data, n, gmVi );
166 resid_b = resid / normb;
167
168 /* If residual is small enough, exit */
169 if ( resid_b <= eps )
170 {
171 outflag = 0;
172 break;
173 }
174
175 /* Compute v0 = r0 / resid */
176 tmp = (float)( 1.0 / resid );
177 solver.scal( pastix_data, n, tmp, gmVi );
178
179 gmG[0] = (float)resid;
180 inflag = 1;
181 i = -1;
182 gmHi = gmH - im1;
183 gmWi = gmW - ldw;
184
185 while( inflag )
186 {
187 clockStop( refine_clk );
188 t0 = clockGet();
189
190 i++;
191
192 /* Set H and W pointers to the beginning of columns i */
193 gmHi = gmHi + im1;
194 gmWi = gmWi + ldw;
195
196 /* Backup v_{i} into w_{i} for the end */
197 solver.copy( pastix_data, n, gmVi, gmWi );
198
199 /* Compute w_{i} = M^{-1} v_{i} */
200 if ( precond ) {
201 solver.spsv( pastix_data, gmWi, sgmWi );
202 }
203
204 /* v_{i+1} = A (M^{-1} v_{i}) = A w_{i} */
205 gmVi += n;
206 solver.spmv( pastix_data, PastixNoTrans, 1.0, gmWi, 0., gmVi );
207
208 /* Classical Gram-Schmidt */
209 for (j=0; j<=i; j++)
210 {
211 /* Compute h_{j,i} = < v_{i+1}, v_{j} > */
212 gmHi[j] = solver.dot( pastix_data, n, gmVi, gmV + j * n );
213
214 /* Compute v_{i+1} = v_{i+1} - h_{j,i} v_{j} */
215 solver.axpy( pastix_data, n, -1. * gmHi[j], gmV + j * n, gmVi );
216 }
217
218 /* Compute || v_{i+1} ||_f */
219 norm = solver.norm( pastix_data, n, gmVi );
220 gmHi[i+1] = norm;
221
222 /* Compute v_{i+1} = v_{i+1} / h_{i+1,i} iff h_{i+1,i} is not too small */
223 if ( norm > 1e-50 )
224 {
225 tmp = (float)(1.0 / norm);
226 solver.scal( pastix_data, n, tmp, gmVi );
227 }
228
229 /* Apply the previous Givens rotation to the new column (should call LAPACKE_srot_work())*/
230 for (j=0; j<i;j++)
231 {
232 /*
233 * h_{j, i} = cos_j * h_{j, i} + sin_{j} * h_{j+1, i}
234 * h_{j+1,i} = cos_j * h_{j+1,i} - (sin_{j}) * h_{j, i}
235 */
236 tmp = gmHi[j];
237 gmHi[j] = gmcos[j] * tmp + gmsin[j] * gmHi[j+1];
238 gmHi[j+1] = gmcos[j] * gmHi[j+1] - (gmsin[j]) * tmp;
239 }
240
241 /*
242 * Compute the new Givens rotation (srotg)
243 *
244 * t = sqrtf( h_{i,i}^2 + h_{i+1,i}^2 )
245 * cos = h_{i,i} / t
246 * sin = h_{i+1,i} / t
247 */
248 {
249 tmp = sqrtf( gmHi[i] * gmHi[i] +
250 gmHi[i+1] * gmHi[i+1] );
251
252 if ( fabsf(tmp) <= eps ) {
253 tmp = (float)eps;
254 }
255 gmcos[i] = gmHi[i] / tmp;
256 gmsin[i] = gmHi[i+1] / tmp;
257 }
258
259 /* Update the residuals (See p. 168, eq 6.35) */
260 gmG[i+1] = -gmsin[i] * gmG[i];
261 gmG[i] = gmcos[i] * gmG[i];
262
263 /* Apply the last Givens rotation */
264 gmHi[i] = gmcos[i] * gmHi[i] + gmsin[i] * gmHi[i+1];
265
266 /* (See p. 169, eq 6.42) */
267 resid = fabsf( gmG[i+1] );
268
269 resid_b = resid / normb;
270 iters++;
271 if ( (i+1 >= im) ||
272 (resid_b <= eps) ||
273 (iters >= itermax) )
274 {
275 inflag = 0;
276 }
277
278 clockStop((refine_clk));
279 t3 = clockGet();
280 if ( ( pastix_data->iparm[IPARM_VERBOSE] > PastixVerboseNot ) &&
281 ( pastix_data->procnum == 0 ) ) {
282 solver.output_oneiter( t0, t3, resid_b, iters );
283
284#if defined(PASTIX_DEBUG_GMRES)
285 {
286 float normr2;
287
288 /* Compute y_m = H_m^{-1} g_m (See p. 169) */
289 memcpy( dbg_G, gmG, im1 * sizeof(float) );
290 cblas_strsv( CblasColMajor, CblasUpper, CblasNoTrans, CblasNonUnit,
291 i+1, gmH, im1, dbg_G, 1 );
292
293 solver.copy( pastix_data, n, b, dbg_r );
294 solver.copy( pastix_data, n, x, dbg_x );
295
296 /* Accumulate the current v_m */
297 solver.gemv( pastix_data, n, i+1, 1.0, (precond ? gmW : gmV), n, dbg_G, 1.0, dbg_x );
298
299 /* Compute b - Ax */
300 solver.spmv( pastix_data, PastixNoTrans, -1., dbg_x, 1., dbg_r );
301
302 normr2 = solver.norm( pastix_data, n, dbg_r );
303 fprintf(stdout, OUT_ITERREFINE_ERR, normr2 / normb );
304 }
305#endif
306 }
307 }
308
309 /* Compute y_m = H_m^{-1} g_m (See p. 169) */
310 cblas_strsv( CblasColMajor, CblasUpper, CblasNoTrans, CblasNonUnit,
311 i+1, gmH, im1, gmG, 1 );
312
313 /**
314 * Compute x_m = x_0 + M^{-1} V_m y_m
315 * = x_0 + W_m y_m
316 */
317 if (precond && savemem) {
318 /**
319 * Since we saved memory, we do not have (M^{-1} V_m) stored,
320 * thus we compute:
321 * w = V_m y_m
322 * w = M^{-1} (V_m y_m)
323 * x = x0 + (M^{-1} (V_m y_m))
324 */
325 solver.gemv( pastix_data, n, i+1, 1.0, gmV, n, gmG, 0., gmW );
326 solver.spsv( pastix_data, gmW, sgmWi );
327 solver.axpy( pastix_data, n, 1., gmW, x );
328 }
329 else {
330 /**
331 * Since we did not saved memory, we do have (M^{-1} V_m) stored in
332 * W_m if precond is true, thus we compute:
333 * x = x0 + W_m y_m, if precond
334 * x = x0 + V_m y_m, if not precond
335 */
336 gmWi = precond ? gmW : gmV;
337 solver.gemv( pastix_data, n, i+1, 1.0, gmWi, n, gmG, 1.0, x );
338 }
339
340 /**
341 * Exit only if maximum number of iteration is reached.
342 * Exit on residual if checked at the beginning of the outer loop to be
343 * sure that the final residual of Ax-b is equal to the estimator
344 * computed within the inner loop.
345 */
346 if (iters >= itermax)
347 {
348 outflag = 0;
349 }
350 }
351
352 clockStop( refine_clk );
353 t3 = clockGet();
354
355 solver.output_final( pastix_data, resid_b, iters, t3, x, x );
356
357 solver.free(gmcos);
358 solver.free(gmsin);
359 solver.free(gmG);
360 solver.free(gmH);
361 solver.free(gmV);
362 solver.free(gmW);
363 solver.free(sgmWi);
364#if defined(PASTIX_DEBUG_GMRES)
365 solver.free(dbg_x);
366 solver.free(dbg_r);
367 solver.free(dbg_G);
368#endif
369
370 return iters;
371}
typedef int pastix_int_t
Definition datatypes.h:51
double pastix_fixdbl_t
Definition datatypes.h:65
@ DPARM_EPSILON_REFINEMENT
Definition api.h:161
@ IPARM_MIXED
Definition api.h:139
@ IPARM_GMRES_IM
Definition api.h:114
@ IPARM_ITERMAX
Definition api.h:113
@ IPARM_VERBOSE
Definition api.h:36
@ PastixNoTrans
Definition api.h:445
@ PastixVerboseNot
Definition api.h:220
void s_refine_init(struct s_solver *, pastix_data_t *)
Initiate functions pointers to define basic operations.
pastix_int_t s_gmres_smp(pastix_data_t *pastix_data, pastix_rhs_t xp, pastix_rhs_t bp)
pastix_int_t * iparm
Definition pastixdata.h:70
double * dparm
Definition pastixdata.h:71
pastix_bcsc_t * bcsc
Definition pastixdata.h:102
pastix_int_t steps
Definition pastixdata.h:73
Main PaStiX data structure.
Definition pastixdata.h:68
Main PaStiX RHS structure.
Definition pastixdata.h:155