PaStiX Handbook 6.4.0
Loading...
Searching...
No Matches
s_refine_grad.c
Go to the documentation of this file.
1/**
2 *
3 * @file s_refine_grad.c
4 *
5 * PaStiX refinement functions implementations.
6 *
7 * @copyright 2015-2024 Bordeaux INP, CNRS (LaBRI UMR 5800), Inria,
8 * Univ. Bordeaux. All rights reserved.
9 *
10 * @version 6.4.0
11 * @author Mathieu Faverge
12 * @author Pierre Ramet
13 * @author Xavier Lacoste
14 * @author Theophile Terraz
15 * @author Gregoire Pichon
16 * @author Vincent Bridonneau
17 * @date 2024-07-05
18 * @generated from /builds/2mk6rsew/0/solverstack/pastix/refinement/z_refine_grad.c, normal z -> s, Tue Feb 25 14:36:04 2025
19 *
20 **/
21#include "common.h"
22#include "bcsc/bcsc.h"
23#include "s_refine_functions.h"
24
25/**
26 *******************************************************************************
27 *
28 * @ingroup pastix_refine
29 *
30 * s_grad_smp - Refine the solution using the conjugate gradient method.
31 *
32 *******************************************************************************
33 *
34 * @param[in] pastix_data
35 * The PaStiX data structure that describes the solver instance.
36 *
37 * @param[inout] xp
38 * On entry, the initial guess used to compute the first residual. On exit, the refined solution vector.
39 *
40 * @param[in] bp
41 * The right hand side member (only one).
42 *
43 *******************************************************************************
44 *
45 * @return Number of iterations
46 *
47 *******************************************************************************/
50 pastix_rhs_t xp,
51 pastix_rhs_t bp )
52{
53 struct s_solver solver;
55 Clock refine_clk;
56 pastix_fixdbl_t t0 = 0;
57 pastix_fixdbl_t t3 = 0;
58 int itermax;
59 int nb_iter = 0;
60 int precond = 1;
61 float *x = (float*)(xp->b);
62 float *b = (float*)(bp->b);
63 float *gradr;
64 float *gradp;
65 float *gradz;
66 float *grad2;
67 float *sgrad = NULL;
68 float normb, normx, normr, alpha, beta;
69 float resid_b, eps;
70
71 memset( &solver, 0, sizeof(struct s_solver) );
72 s_refine_init( &solver, pastix_data );
73
74 if ( !(pastix_data->steps & STEP_NUMFACT) ) {
75 precond = 0;
76 }
77
78 n = pastix_data->bcsc->n;
79 itermax = pastix_data->iparm[IPARM_ITERMAX];
80 eps = pastix_data->dparm[DPARM_EPSILON_REFINEMENT];
81
82 /* Initialize vectors */
83 gradr = (float *)solver.malloc(n * sizeof(float));
84 gradp = (float *)solver.malloc(n * sizeof(float));
85 gradz = (float *)solver.malloc(n * sizeof(float));
86 grad2 = (float *)solver.malloc(n * sizeof(float));
87
88 /* Allocate the extra work vector handed to the solve when IPARM_MIXED is set, NULL pointer otherwise */
89 if ( pastix_data->iparm[IPARM_MIXED] )
90 {
91 sgrad = solver.malloc( n * sizeof(float) );
92 }
93
94 clockInit(refine_clk);
95 clockStart(refine_clk);
96
97 normb = solver.norm( pastix_data, n, b );
98 if ( normb == 0. ) {
99 normb = 1;
100 }
101 normx = solver.norm( pastix_data, n, x );
102
103 /* Compute r0 = b - A * x */
104 solver.copy( pastix_data, n, b, gradr );
105 if ( normx > 0. ) {
106 solver.spmv( pastix_data, PastixNoTrans, -1., x, 1., gradr );
107 }
108 normr = solver.norm( pastix_data, n, gradr );
109 resid_b = normr / normb;
110
111 /* z = M^{-1} r */
112 solver.copy( pastix_data, n, gradr, gradz );
113 if ( precond ) {
114 solver.spsv( pastix_data, gradz, sgrad );
115 }
116
117 /* p = z */
118 solver.copy( pastix_data, n, gradz, gradp );
119
120 while ((resid_b > eps) && (nb_iter < itermax))
121 {
122 clockStop((refine_clk));
123 t0 = clockGet();
124 nb_iter++;
125
126 /* grad2 = A * p */
127 solver.spmv( pastix_data, PastixNoTrans, 1.0, gradp, 0., grad2 );
128
129 /* alpha = <r, z> / <Ap, p> */
130 beta = solver.dot( pastix_data, n, gradr, gradz );
131 alpha = solver.dot( pastix_data, n, grad2, gradp );
132 alpha = beta / alpha;
133
134 /* x = x + alpha * p */
135 solver.axpy( pastix_data, n, alpha, gradp, x );
136
137 /* r = r - alpha * A * p */
138 solver.axpy( pastix_data, n, -alpha, grad2, gradr );
139
140 /* z = M^{-1} * r */
141 solver.copy( pastix_data, n, gradr, gradz );
142 if ( precond ) {
143 solver.spsv( pastix_data, gradz, sgrad );
144 }
145
146 /* beta = <r', z'> / <r, z> */
147 alpha = solver.dot( pastix_data, n, gradr, gradz );
148 beta = alpha / beta;
149
150 /* p = z + beta * p */
151 solver.scal( pastix_data, n, beta, gradp );
152 solver.axpy( pastix_data, n, 1., gradz, gradp );
153
154 normr = solver.norm( pastix_data, n, gradr );
155 resid_b = normr / normb;
156
157 clockStop((refine_clk));
158 t3 = clockGet();
159 if ( ( pastix_data->iparm[IPARM_VERBOSE] > PastixVerboseNot ) &&
160 ( pastix_data->procnum == 0 ) ) {
161 solver.output_oneiter( t0, t3, resid_b, nb_iter );
162 }
163 t0 = t3;
164 }
165
166 solver.output_final(pastix_data, resid_b, nb_iter, t3, x, x);
167
168 solver.free((void*) gradr);
169 solver.free((void*) gradp);
170 solver.free((void*) gradz);
171 solver.free((void*) grad2);
172 solver.free((void*) sgrad);
173
174 return nb_iter;
175}
BEGIN_C_DECLS typedef int pastix_int_t
Definition datatypes.h:51
double pastix_fixdbl_t
Definition datatypes.h:65
@ DPARM_EPSILON_REFINEMENT
Definition api.h:161
@ IPARM_MIXED
Definition api.h:139
@ IPARM_ITERMAX
Definition api.h:113
@ IPARM_VERBOSE
Definition api.h:36
@ PastixNoTrans
Definition api.h:445
@ PastixVerboseNot
Definition api.h:220
void s_refine_init(struct s_solver *, pastix_data_t *)
Initiate functions pointers to define basic operations.
pastix_int_t s_grad_smp(pastix_data_t *pastix_data, pastix_rhs_t xp, pastix_rhs_t bp)
pastix_int_t * iparm
Definition pastixdata.h:70
double * dparm
Definition pastixdata.h:71
pastix_bcsc_t * bcsc
Definition pastixdata.h:102
pastix_int_t steps
Definition pastixdata.h:73
Main PaStiX data structure.
Definition pastixdata.h:68
Main PaStiX RHS structure.
Definition pastixdata.h:155