PaStiX Handbook  6.4.0
z_refine_bicgstab.c
Go to the documentation of this file.
1 /**
2  *
3  * @file z_refine_bicgstab.c
4  *
5  * PaStiX refinement functions implementations.
6  *
7  * @copyright 2015-2024 Bordeaux INP, CNRS (LaBRI UMR 5800), Inria,
8  * Univ. Bordeaux. All rights reserved.
9  *
10  * @version 6.4.0
11  * @author Mathieu Faverge
12  * @author Pierre Ramet
13  * @author Xavier Lacoste
14  * @author Theophile Terraz
15  * @author Gregoire Pichon
16  * @author Vincent Bridonneau
17  * @date 2024-07-05
18  * @generated from /builds/solverstack/pastix/refinement/z_refine_bicgstab.c, normal z -> z, Thu Aug 29 14:20:51 2024
19  *
20  **/
21 #include "common.h"
22 #include "bcsc/bcsc.h"
23 #include "z_refine_functions.h"
24 
25 /**
26  *******************************************************************************
27  *
28  * @ingroup pastix_refine
29  *
30  * z_bicgstab_smp - Function computing bicgstab iterative refinement.
31  *
32  *******************************************************************************
33  *
34  * @param[in] pastix_data
35  * The PaStiX data structure that describes the solver instance.
36  *
37  * @param[out] xp
38  * The solution vector.
39  *
40  * @param[in] bp
41  * The right hand side member (only one).
42  *
43  *******************************************************************************
44  *
45  * @return Number of iterations
46  *
47  *******************************************************************************/
50  pastix_rhs_t xp,
51  pastix_rhs_t bp )
52 {
53  struct z_solver solver;
54  pastix_int_t n;
55  Clock refine_clk;
56  pastix_fixdbl_t t0 = 0;
57  pastix_fixdbl_t t3 = 0;
58  int itermax;
59  int nb_iter = 0;
60  int precond = 1;
61  pastix_complex64_t *x = (pastix_complex64_t*)(xp->b);
62  pastix_complex64_t *b = (pastix_complex64_t*)(bp->b);
63  pastix_complex64_t *gradr; /* Current solution */
64  pastix_complex64_t *gradr2;
65  pastix_complex64_t *gradp;
66  pastix_complex64_t *grady;
67  pastix_complex64_t *gradv;
68  pastix_complex64_t *grads;
69  pastix_complex64_t *gradz;
70  pastix_complex64_t *gradt;
71  pastix_complex64_t *grad2;
72  pastix_complex64_t *grad3;
73  pastix_complex32_t *sgrad = NULL;
74  pastix_complex64_t v1, v2, w;
75  double normb, normx, normr, alpha, beta;
76  double resid_b, eps;
77 
78  memset( &solver, 0, sizeof(struct z_solver) );
79  z_refine_init( &solver, pastix_data );
80 
81  if ( !(pastix_data->steps & STEP_NUMFACT) ) {
82  precond = 0;
83  }
84 
85  n = pastix_data->bcsc->n;
86  itermax = pastix_data->iparm[IPARM_ITERMAX];
87  eps = pastix_data->dparm[DPARM_EPSILON_REFINEMENT];
88 
89  gradr = (pastix_complex64_t *)solver.malloc(n * sizeof(pastix_complex64_t));
90  gradr2 = (pastix_complex64_t *)solver.malloc(n * sizeof(pastix_complex64_t));
91  gradp = (pastix_complex64_t *)solver.malloc(n * sizeof(pastix_complex64_t));
92  grady = (pastix_complex64_t *)solver.malloc(n * sizeof(pastix_complex64_t));
93  gradv = (pastix_complex64_t *)solver.malloc(n * sizeof(pastix_complex64_t));
94  grads = (pastix_complex64_t *)solver.malloc(n * sizeof(pastix_complex64_t));
95  gradz = (pastix_complex64_t *)solver.malloc(n * sizeof(pastix_complex64_t));
96  gradt = (pastix_complex64_t *)solver.malloc(n * sizeof(pastix_complex64_t));
97  grad2 = (pastix_complex64_t *)solver.malloc(n * sizeof(pastix_complex64_t));
98  grad3 = (pastix_complex64_t *)solver.malloc(n * sizeof(pastix_complex64_t));
99 
100  /* Allocating a vector at half-precision, NULL pointer otherwise */
101  if ( pastix_data->iparm[IPARM_MIXED] )
102  {
103  sgrad = solver.malloc( n * sizeof(pastix_complex32_t) );
104  }
105 
106  clockInit(refine_clk);clockStart(refine_clk);
107 
108  normb = solver.norm( pastix_data, n, b );
109  if ( normb == 0. ) {
110  normb = 1;
111  }
112  normx = solver.norm( pastix_data, n, x );
113 
114  /* r = b - Ax */
115  solver.copy( pastix_data, n, b, gradr );
116  if ( normx > 0. ) {
117  solver.spmv( pastix_data, PastixNoTrans, -1., x, 1., gradr );
118  }
119  normr = solver.norm( pastix_data, n, gradr );
120 
121  /* r2 = r */
122  solver.copy( pastix_data, n, gradr, gradr2 );
123  /* p = r */
124  solver.copy( pastix_data, n, gradr, gradp );
125 
126  /* resid_b = ||r|| / ||b|| */
127  resid_b = normr / normb;
128 
129  while ((resid_b > eps) && (nb_iter < itermax))
130  {
131  clockStop((refine_clk));
132  t0 = clockGet();
133  nb_iter++;
134 
135  /* y = M-1 * p */
136  solver.copy( pastix_data, n, gradp, grady );
137  if ( precond ) {
138  solver.spsv( pastix_data, grady, sgrad );
139  }
140 
141  /* v = Ay */
142  solver.spmv( pastix_data, PastixNoTrans, 1.0, grady, 0., gradv );
143 
144  /* alpha = (r, r2) / (v, r2) */
145  alpha = solver.dot( pastix_data, n, gradv, gradr2 );
146  beta = solver.dot( pastix_data, n, gradr, gradr2 );
147  alpha = beta / alpha;
148 
149  /* s = r - alpha * v */
150  solver.copy( pastix_data, n, gradr, grads );
151  solver.axpy( pastix_data, n, -alpha, gradv, grads );
152 
153  /* z = M^{-1} s */
154  solver.copy( pastix_data, n, grads, gradz );
155  if ( precond ) {
156  solver.spsv( pastix_data, gradz, sgrad );
157  }
158 
159  /* t = Az */
160  solver.spmv( pastix_data, PastixNoTrans, 1.0, gradz, 0., gradt );
161 
162  /* w = (M-1t, M-1s) / (M-1t, M-1t) */
163  /* grad2 = M-1t */
164  solver.copy( pastix_data, n, gradt, grad2 );
165  if ( precond ) {
166  solver.spsv( pastix_data, grad2, sgrad );
167  }
168 
169  /* v1 = (M-1t, M-1s) */
170  /* v2 = (M-1t, M-1t) */
171  v1 = solver.dot( pastix_data, n, grad2, gradz );
172  v2 = solver.dot( pastix_data, n, grad2, grad2 );
173  w = v1 / v2;
174 
175  /* x = x + alpha * y + w * z */
176  /* x = x + alpha * y */
177  solver.axpy( pastix_data, n, alpha, grady, x );
178 
179  /* x = x + w * z */
180  solver.axpy( pastix_data, n, w, gradz, x );
181 
182  /* r = s - w * t*/
183  solver.copy( pastix_data, n, grads, gradr );
184  solver.axpy( pastix_data, n, -w, gradt, gradr );
185 
186  /* beta = (r', r2) / (r, r2) * (alpha / w) */
187  /* v1 = (r', r2) */
188  v1 = solver.dot( pastix_data, n, gradr, gradr2 );
189  v2 = alpha / w;
190 
191  beta = v1 / beta;
192  beta = beta * v2;
193 
194  /* p = r + beta * (p - w * v) */
195  /* p = p - w * v */
196  solver.axpy( pastix_data, n, -w, gradv, gradp );
197 
198  /* p = r + beta * p */
199  solver.scal( pastix_data, n, beta, gradp );
200  solver.axpy( pastix_data, n, 1., gradr, gradp );
201 
202  normr = solver.norm( pastix_data, n, gradr );
203  resid_b = normr / normb;
204 
205  clockStop((refine_clk));
206  t3 = clockGet();
207  if ( ( pastix_data->iparm[IPARM_VERBOSE] > PastixVerboseNot ) &&
208  ( pastix_data->procnum == 0 ) ) {
209  solver.output_oneiter( t0, t3, resid_b, nb_iter );
210  }
211  }
212 
213  solver.output_final(pastix_data, resid_b, nb_iter, t3, x, x);
214 
215  solver.free((void*) gradr);
216  solver.free((void*) gradr2);
217  solver.free((void*) gradp);
218  solver.free((void*) grady);
219  solver.free((void*) gradv);
220  solver.free((void*) grads);
221  solver.free((void*) gradz);
222  solver.free((void*) gradt);
223  solver.free((void*) grad2);
224  solver.free((void*) grad3);
225  solver.free((void*) sgrad);
226 
227  return nb_iter;
228 }
typedef int pastix_int_t
Definition: datatypes.h:51
float _Complex pastix_complex32_t
Definition: datatypes.h:76
double pastix_fixdbl_t
Definition: datatypes.h:65
@ DPARM_EPSILON_REFINEMENT
Definition: api.h:161
@ IPARM_MIXED
Definition: api.h:139
@ IPARM_ITERMAX
Definition: api.h:113
@ IPARM_VERBOSE
Definition: api.h:36
@ PastixNoTrans
Definition: api.h:445
@ PastixVerboseNot
Definition: api.h:220
void z_refine_init(struct z_solver *, pastix_data_t *)
Initializes the function pointers that define the basic operations.
pastix_int_t z_bicgstab_smp(pastix_data_t *pastix_data, pastix_rhs_t xp, pastix_rhs_t bp)
pastix_int_t * iparm
Definition: pastixdata.h:70
double * dparm
Definition: pastixdata.h:71
pastix_bcsc_t * bcsc
Definition: pastixdata.h:102
pastix_int_t steps
Definition: pastixdata.h:73
Main PaStiX data structure.
Definition: pastixdata.h:68
Main PaStiX RHS structure.
Definition: pastixdata.h:155