PaStiX Handbook  6.4.0
pastix_rhs.c
Go to the documentation of this file.
1 /**
2  *
3  * @file pastix_rhs.c
4  *
5  * @copyright 2004-2024 Bordeaux INP, CNRS (LaBRI UMR 5800), Inria,
6  * Univ. Bordeaux. All rights reserved.
7  *
8  * @version 6.4.0
9  * @author Mathieu Faverge
10  * @author Pierre Ramet
11  * @author Vincent Bridonneau
12  * @author Alycia Lisito
13  * @date 2024-07-05
14  *
15  **/
16 #include "common.h"
17 #include "bcsc/bvec.h"
18 #if defined(PASTIX_WITH_STARPU)
19 #include "starpu/pastix_starpu.h"
20 #endif
21 #include <lapacke.h>
22 
23 /**
24  *******************************************************************************
25  *
26  * @ingroup bcsc
27  *
28  * @brief Initialize an RHS data structure.
29  *
30  *******************************************************************************
31  *
32  * @param[inout] B_ptr
33  * On entry, an allocated pastix_rhs_t data structure.
34  * On exit, the data is initialized to be used by the pastix_subtask_*
35  * functions.
36  *
37  *******************************************************************************
38  *
39  * @retval PASTIX_SUCCESS on successful exit,
40  * @retval PASTIX_ERR_BADPARAMETER if one parameter is incorrect.
41  *
42  *******************************************************************************/
43 int
45 {
46  pastix_rhs_t B;
47 
48  if ( B_ptr == NULL ) {
49  pastix_print_error( "pastixRhsInit: wrong B parameter" );
51  }
52 
53  *B_ptr = malloc( sizeof(struct pastix_rhs_s) );
54  B = *B_ptr;
55 
56  B->allocated = -1;
57  B->flttype = PastixPattern;
58  B->m = -1;
59  B->n = -1;
60  B->ld = -1;
61  B->b = NULL;
62  B->cblkb = NULL;
63  B->rhs_comm = NULL;
64  B->Ploc2Pglob = NULL;
65 #if defined(PASTIX_WITH_STARPU)
66  B->starpu_desc = NULL;
67 #endif
68 
69  return PASTIX_SUCCESS;
70 }
71 
72 /**
73  *******************************************************************************
74  *
75  * @ingroup bcsc
76  *
77  * @brief Cleanup an RHS data structure.
78  *
79  *******************************************************************************
80  *
81  * @param[inout] B
82  * On entry, the initialized pastix_rhs_t data structure.
83  * On exit, the structure is destroyed and should no longer be used.
84  *
85  *******************************************************************************
86  *
87  * @retval PASTIX_SUCCESS on successful exit,
88  * @retval PASTIX_ERR_BADPARAMETER if one parameter is incorrect.
89  *
90  *******************************************************************************/
91 int
93 {
94  if ( B == NULL ) {
95  pastix_print_error( "pastixRhsFinalize: wrong B parameter" );
97  }
98 
99  if ( B->b != NULL ) {
100  if ( B->allocated > 0 ) {
101  free( B->b );
102  }
103  else {
104  pastix_print_warning( "Calling pastixRhsFinalize before restoring the ordering of vector b.\n"
105  "Please call:\n"
106  " pastix_subtask_applyorder( pastix_data, flttype, PastixDirBackward, m, n,\n"
107  " b, ldb, Bp );\n"
108  "prior to this call to restore it.\n" );
109  }
110  }
111 
112  if ( B->cblkb != NULL ) {
113  memFree_null( B->cblkb );
114  }
115 
116  if ( B->Ploc2Pglob != NULL ) {
117  memFree_null( B->Ploc2Pglob );
118  }
119  if ( B->rhs_comm != NULL ) {
120  memFree_null( B->rhs_comm );
121  }
122 
123 #if defined(PASTIX_WITH_STARPU)
124  {
125  if ( B->starpu_desc != NULL ) {
126  starpu_rhs_destroy( B->starpu_desc );
127  free( B->starpu_desc );
128  }
129  B->starpu_desc = NULL;
130  }
131 #endif
132  free( B );
133  return PASTIX_SUCCESS;
134 }
135 
136 /**
137  *******************************************************************************
138  *
139  * @ingroup pastix_solve
140  *
141  * @brief Reduces the precision of an RHS.
142  *
143  *******************************************************************************
144  *
145  * @param[in] dB
146  * The allocated pastix_rhs_t data structure to convert to lower
147  * precision.
148  *
149  * @param[out] sB
150  * On entry, an allocated pastix_rhs_t data structure.
151  * On exit, the reduced precision pastix_rhs_t of dB.
152  * If sB->allocated == -1 on entry, the internal b vector is
153  * automatically allocated by the function.
154  *
155  *******************************************************************************
156  *
157  * @retval PASTIX_SUCCESS on successful exit,
158  * @retval PASTIX_ERR_BADPARAMETER if one parameter is incorrect.
159  *
160  *******************************************************************************/
161 int
163  pastix_rhs_t sB )
164 {
165  int rc;
166  int tofree = 0;
167 
168  /* Generates halved-precision vector */
169  if ( ( dB->flttype != PastixComplex64 ) &&
170  ( dB->flttype != PastixDouble ) )
171  {
172  pastix_print_error( "bvecDoubletoSingle: Invalid float type for mixed-precision" );
174  }
175 
176  if ( sB->allocated == -1 ) {
177  size_t size = dB->ld * dB->n;
178 
179  memcpy( sB, dB, sizeof( struct pastix_rhs_s ) );
180 
181  sB->allocated = 1;
182  sB->flttype = dB->flttype - 1;
183  sB->b = malloc( size * pastix_size_of( sB->flttype ) );
184  sB->rhs_comm = NULL;
185  tofree = 1;
186  }
187  assert( sB->allocated >= 0 );
188  assert( sB->flttype == (dB->flttype - 1) );
189  assert( sB->b != NULL );
190  assert( sB->m == dB->m );
191  assert( sB->n == dB->n );
192 
193  switch( dB->flttype ) {
194  case PastixComplex64:
195  rc = LAPACKE_zlag2c_work( LAPACK_COL_MAJOR, dB->m, dB->n, dB->b, dB->ld, sB->b, sB->ld );
196  break;
197  case PastixDouble:
198  rc = LAPACKE_dlag2s_work( LAPACK_COL_MAJOR, dB->m, dB->n, dB->b, dB->ld, sB->b, sB->ld );
199  break;
200  default:
201  rc = 1;
202  pastix_print_error( "bvecDoubletoSingle: Invalid input float type for mixed-precision" );
203  }
204 
205  if ( rc ) {
206  if ( tofree ) {
207  free( dB->b );
208  dB->b = NULL;
209  }
210  return PASTIX_ERR_INTERNAL;
211  }
212 
213  return PASTIX_SUCCESS;
214 }
215 
216 /**
217  *******************************************************************************
218  *
219  * @ingroup pastix_solve
220  *
221  * @brief Increases the precision of an RHS.
222  *
223  *******************************************************************************
224  *
225  * @param[in] sB
226  * The allocated pastix_rhs_t data structure to convert to higher
227  * precision.
228  *
229  * @param[out] dB
230  * On entry, an allocated pastix_rhs_t data structure.
231  * On exit, the increased precision pastix_rhs_t of sB.
232  * If dB->allocated == -1 on entry, the internal b vector is
233  * automatically allocated by the function.
234  *
235  *******************************************************************************
236  *
237  * @retval PASTIX_SUCCESS on successful exit,
238  * @retval PASTIX_ERR_BADPARAMETER if one parameter is incorrect.
239  *
240  *******************************************************************************/
241 int
243  pastix_rhs_t dB )
244 {
245  int rc;
246  int tofree = 0;
247 
248  /* Frees halved-precision vector */
249  if ( ( sB->flttype != PastixComplex32 ) &&
250  ( sB->flttype != PastixFloat ) )
251  {
252  pastix_print_error( "bvecSingleToDouble: Invalid input float type for mixed-precision" );
254  }
255 
256  if ( dB->allocated == -1 ) {
257  size_t size = sB->ld * sB->n;
258 
259  memcpy( dB, sB, sizeof( struct pastix_rhs_s ) );
260 
261  dB->allocated = 1;
262  dB->flttype = sB->flttype + 1;
263  dB->b = malloc( size * pastix_size_of( dB->flttype ) );
264  dB->rhs_comm = NULL;
265  tofree = 1;
266  }
267  assert( dB->allocated >= 0 );
268  assert( dB->flttype == (sB->flttype + 1) );
269  assert( dB->b != NULL );
270  assert( dB->m == sB->m );
271  assert( dB->n == sB->n );
272 
273  switch( sB->flttype ) {
274  case PastixComplex32:
275  rc = LAPACKE_clag2z_work( LAPACK_COL_MAJOR, sB->m, sB->n, sB->b, sB->ld, dB->b, dB->ld );
276  break;
277  case PastixFloat:
278  rc = LAPACKE_slag2d_work( LAPACK_COL_MAJOR, sB->m, sB->n, sB->b, sB->ld, dB->b, dB->ld );
279  break;
280  default:
281  rc = 1;
282  pastix_print_error( "bvecSingleToDouble: Invalid float type for mixed-precision" );
283  }
284 
285  if ( rc ) {
286  if ( tofree ) {
287  free( sB->b );
288  sB->b = NULL;
289  }
290  return PASTIX_ERR_INTERNAL;
291  }
292 
293  return PASTIX_SUCCESS;
294 }
int pastixRhsInit(pastix_rhs_t *B_ptr)
Initialize an RHS data structure.
Definition: pastix_rhs.c:44
int pastixRhsFinalize(pastix_rhs_t B)
Cleanup an RHS data structure.
Definition: pastix_rhs.c:92
@ PASTIX_ERR_INTERNAL
Definition: api.h:373
@ PASTIX_SUCCESS
Definition: api.h:367
@ PASTIX_ERR_BADPARAMETER
Definition: api.h:374
int pastixRhsDoubletoSingle(const pastix_rhs_t dB, pastix_rhs_t sB)
Reduces the precision of an RHS.
Definition: pastix_rhs.c:162
int pastixRhsSingleToDouble(const pastix_rhs_t sB, pastix_rhs_t dB)
Increases the precision of an RHS.
Definition: pastix_rhs.c:242
void starpu_rhs_destroy(starpu_rhs_desc_t *desc)
Free the StarPU descriptor of the dense matrix.
Definition: starpu_rhs.c:254
void ** cblkb
Definition: pastixdata.h:162
pastix_int_t * Ploc2Pglob
Definition: pastixdata.h:164
pastix_int_t ld
Definition: pastixdata.h:160
bvec_handle_comm_t * rhs_comm
Definition: pastixdata.h:163
pastix_coeftype_t flttype
Definition: pastixdata.h:157
pastix_int_t m
Definition: pastixdata.h:158
pastix_int_t n
Definition: pastixdata.h:159
int8_t allocated
Definition: pastixdata.h:156
Main PaStiX RHS structure.
Definition: pastixdata.h:155