PaStiX Handbook  6.4.0
codelet_blok_sgemmsp.c
Go to the documentation of this file.
1 /**
2  *
3  * @file codelet_blok_sgemmsp.c
4  *
5  * StarPU codelets for blas-like functions
6  *
7  * @copyright 2016-2024 Bordeaux INP, CNRS (LaBRI UMR 5800), Inria,
8  * Univ. Bordeaux. All rights reserved.
9  *
10  * @version 6.4.0
11  * @author Mathieu Faverge
12  * @author Pierre Ramet
13  * @author Ian Masliah
14  * @author Tom Moenne-Loccoz
15  * @date 2024-07-05
16  *
17  * @generated from /builds/solverstack/pastix/sopalin/starpu/codelet_blok_zgemmsp.c, normal z -> s, Thu Aug 29 14:20:31 2024
18  *
19  * @addtogroup pastix_starpu
20  * @{
21  *
22  **/
23 #ifndef DOXYGEN_SHOULD_SKIP_THIS
24 #define _GNU_SOURCE
25 #endif /* DOXYGEN_SHOULD_SKIP_THIS */
26 #include "common.h"
27 #include "blend/solver.h"
28 #include "sopalin/sopalin_data.h"
29 #include "pastix_scores.h"
30 #if defined(PASTIX_WITH_CUDA)
31 #include "pastix_scuda.h"
32 #endif
33 #include "pastix_starpu.h"
34 #include "pastix_sstarpu.h"
35 #include "codelets.h"
36 
37 /**
38  * @brief Main structure for all tasks of blok_sgemmsp type
39  */
40 struct cl_blok_sgemmsp_args_s {
41  profile_data_t profile_data;
42  sopalin_data_t *sopalin_data;
43  pastix_trans_t trans;
44  const SolverCblk *cblk;
45  SolverCblk *fcblk;
46  pastix_int_t blok_mk;
47  pastix_int_t blok_nk;
48  pastix_int_t blok_mn;
49 };
50 
51 #if defined( PASTIX_STARPU_PROFILING )
52 /**
53  * @brief Functions to profile the codelet
54  *
55  * Two levels of profiling are available:
56  * 1) A generic one that returns the flops per worker
57  * 2) A more detailed one that generate logs of the performance for each kernel
58  */
59 starpu_profile_t blok_sgemmsp_profile = {
60  .next = NULL,
61  .name = "blok_sgemmsp"
62 };
63 
64 /**
65  * @brief Profiling registration function
66  */
67 void blok_sgemmsp_profile_register( void ) __attribute__( ( constructor ) );
68 void
69 blok_sgemmsp_profile_register( void )
70 {
71  profiling_register_cl( &blok_sgemmsp_profile );
72 }
73 
74 #ifndef DOXYGEN_SHOULD_SKIP_THIS
75 #if defined(PASTIX_STARPU_PROFILING_LOG)
76 static void
77 cl_profiling_cb_blok_sgemmsp( void *callback_arg )
78 {
79  cl_profiling_callback( callback_arg );
80 
81  struct starpu_task *task = starpu_task_get_current();
82  struct starpu_profiling_task_info *info = task->profiling_info;
83 
84  /* Quick return */
85  if ( info == NULL ) {
86  return;
87  }
88 
89  struct cl_blok_sgemmsp_args_s *args = (struct cl_blok_sgemmsp_args_s *) callback_arg;
90  pastix_fixdbl_t flops = args->profile_data.flops;
91  pastix_fixdbl_t duration = starpu_timing_timespec_delay_us( &info->start_time, &info->end_time );
92  pastix_fixdbl_t speed = flops / ( 1000.0 * duration );
93 
94  pastix_int_t M = blok_rownbr_ext( args->cblk->fblokptr + args->blok_mk );
95  pastix_int_t N = blok_rownbr_ext( args->cblk->fblokptr + args->blok_nk );
96  pastix_int_t K = cblk_colnbr( args->cblk );
97 
98  cl_profiling_log_register( task->name, "blok_sgemmsp", M, N, K, flops, speed );
99 }
100 #endif
101 
102 #if defined(PASTIX_STARPU_PROFILING_LOG)
103 static void (*blok_sgemmsp_callback)(void*) = cl_profiling_cb_blok_sgemmsp;
104 #else
105 static void (*blok_sgemmsp_callback)(void*) = cl_profiling_callback;
106 #endif
107 #endif /* DOXYGEN_SHOULD_SKIP_THIS */
108 
109 #endif /* defined( PASTIX_STARPU_PROFILING ) */
110 
111 /**
112  *******************************************************************************
113  *
114  * @brief Cost model function
115  *
116  * The user can switch from the pastix static model to an history based model
117  * computed automatically.
118  *
119  *******************************************************************************
120  *
121  * @param[in] task
122  * TODO
123  *
124  * @param[in] arch
125  * TODO
126  *
127  * @param[in] nimpl
128  * TODO
129  *
130  *******************************************************************************
131  *
132  * @retval TODO
133  *
134  *******************************************************************************/
135 static inline pastix_fixdbl_t
136 fct_blok_sgemmsp_cost( struct starpu_task *task,
137  struct starpu_perfmodel_arch *arch,
138  unsigned nimpl )
139 {
140  struct cl_blok_sgemmsp_args_s *args = (struct cl_blok_sgemmsp_args_s *)(task->cl_arg);
141 
142  pastix_fixdbl_t cost = 0.;
143  pastix_fixdbl_t *coefs;
144  pastix_int_t M = blok_rownbr_ext( args->cblk->fblokptr + args->blok_mk );
145  pastix_int_t N = blok_rownbr_ext( args->cblk->fblokptr + args->blok_nk );
146  pastix_int_t K = cblk_colnbr( args->cblk );
147 
148  switch( arch->devices->type ) {
149  case STARPU_CPU_WORKER:
150  coefs = &(args->sopalin_data->cpu_models->coefficients[PastixFloat-2][PastixKernelGEMMBlok2d2d][0]);
151  break;
152  case STARPU_CUDA_WORKER:
153  coefs = &(args->sopalin_data->gpu_models->coefficients[PastixFloat-2][PastixKernelGEMMBlok2d2d][0]);
154  break;
155  default:
156  assert(0);
157  return 0.;
158  }
159 
160  /* Get cost in us */
161  cost = modelsGetCost3Param( coefs, M, N, K ) * 1e6;
162 
163  (void)nimpl;
164  return cost;
165 }
166 
167 
168 #ifndef DOXYGEN_SHOULD_SKIP_THIS
169 static struct starpu_perfmodel starpu_blok_sgemmsp_model = {
170 #if defined(PASTIX_STARPU_COST_PER_ARCH)
171  .type = STARPU_PER_ARCH,
172  .arch_cost_function = fct_blok_sgemmsp_cost,
173 #else
174  .type = STARPU_HISTORY_BASED,
175 #endif
176  .symbol = "blok_sgemmsp",
177 };
178 #endif /* DOXYGEN_SHOULD_SKIP_THIS */
179 
180 #if !defined(PASTIX_STARPU_SIMULATION)
181 /**
182  *******************************************************************************
183  *
184  * @brief StarPU CPU implementation
185  *
186  *******************************************************************************
187  *
188  * @param[in] descr
189  * TODO
190  *
191  * @param[in] cl_arg
192  * TODO
193  *
194  *******************************************************************************/
195 static void
196 fct_blok_sgemmsp_cpu( void *descr[], void *cl_arg )
197 {
198  struct cl_blok_sgemmsp_args_s *args = (struct cl_blok_sgemmsp_args_s *)cl_arg;
199  const void *A;
200  const void *B;
201  void *C;
202 
203  A = pastix_starpu_blok_get_ptr( descr[0] );
204  B = pastix_starpu_blok_get_ptr( descr[1] );
205  C = pastix_starpu_blok_get_ptr( descr[2] );
206 
207  assert( args->cblk->cblktype & CBLK_TASKS_2D );
208  assert( args->fcblk->cblktype & CBLK_TASKS_2D );
209 
210  args->profile_data.flops = cpublok_sgemmsp( args->trans,
211  args->cblk, args->fcblk,
212  args->blok_mk, args->blok_nk, args->blok_mn,
213  A, B, C,
214  &(args->sopalin_data->solvmtx->lowrank) );
215 }
216 
/**
 * @brief StarPU GPU implementation
 *
 * Same contract as the CPU implementation, except that the kernel is
 * gpublok_sgemmsp() and it is given the local CUDA stream of the worker.
 * descr[0], descr[1] and descr[2] are the A, B and C buffers; cl_arg is the
 * struct cl_blok_sgemmsp_args_s of the task, whose profile_data.flops field
 * is updated on exit.
 */
#if defined(PASTIX_WITH_CUDA)
static void
fct_blok_sgemmsp_gpu( void *descr[], void *cl_arg )
{
    struct cl_blok_sgemmsp_args_s *params = cl_arg;

    const void *A = pastix_starpu_blok_get_ptr( descr[0] );
    const void *B = pastix_starpu_blok_get_ptr( descr[1] );
    void       *C = pastix_starpu_blok_get_ptr( descr[2] );

    /* Both column blocks are expected to carry the CBLK_TASKS_2D flag */
    assert( params->cblk->cblktype  & CBLK_TASKS_2D );
    assert( params->fcblk->cblktype & CBLK_TASKS_2D );

    params->profile_data.flops =
        gpublok_sgemmsp( params->trans,
                         params->cblk, params->fcblk,
                         params->blok_mk, params->blok_nk, params->blok_mn,
                         A, B, C,
                         &(params->sopalin_data->solvmtx->lowrank),
                         starpu_cuda_get_local_stream() );
}
#endif /* defined(PASTIX_WITH_CUDA) */
244 #endif /* !defined(PASTIX_STARPU_SIMULATION) */
245 
246 #ifndef DOXYGEN_SHOULD_SKIP_THIS
247 CODELETS_GPU( blok_sgemmsp, 3, STARPU_CUDA_ASYNC );
248 #endif /* DOXYGEN_SHOULD_SKIP_THIS */
249 
250 /**
251  *******************************************************************************
252  *
253  * @brief TODO
254  *
255  *******************************************************************************
256  *
257  * @param[in] sopalin_data
258  * TODO
259  *
260  * @param[in] sideA
261  * TODO
262  *
263  * @param[in] sideB
264  * TODO
265  *
266  * @param[in] trans
267  * TODO
268  *
269  * @param[in] cblk
270  * TODO
271  *
272  * @param[in] fcblk
273  * TODO
274  *
275  * @param[in] blokA
276  * TODO
277  *
278  * @param[in] blokB
279  * TODO
280  *
281  * @param[in] prio
282  * TODO
283  *
284  *******************************************************************************/
285 void
286 starpu_task_blok_sgemmsp( sopalin_data_t *sopalin_data,
287  pastix_coefside_t sideA,
288  pastix_coefside_t sideB,
289  pastix_trans_t trans,
290  SolverCblk *cblk,
291  SolverCblk *fcblk,
292  const SolverBlok *blokA,
293  const SolverBlok *blokB,
294  int prio )
295 {
296  struct cl_blok_sgemmsp_args_s *cl_arg = NULL;
297  long long execute_where = cl_blok_sgemmsp_any.where;
298  int need_exec = 1;
299 #if defined(PASTIX_DEBUG_STARPU) || defined(PASTIX_STARPU_PROFILING_LOG)
300  char *task_name;
301 #endif
302 
303  pastix_int_t frownum;
304  pastix_int_t lrownum;
305  pastix_int_t blok_mn = 0, j = 0;
306  pastix_int_t blok_mk = blokA - cblk->fblokptr;
307  pastix_int_t blok_nk = blokB - cblk->fblokptr;
308  SolverBlok *blokC = fcblk->fblokptr;
309 
310  assert( blok_nk <= blok_mk );
311 
312  do {
313  frownum = blokC->frownum;
314  lrownum = blokC->lrownum;
315  blok_mn += j;
316  j = 1;
317 
318  /* Increase lrownum as long as blocks are facing the same cblk */
319  while( (blokC < fcblk[1].fblokptr-1) &&
320  (blokC[0].fcblknm == blokC[1].fcblknm) &&
321  (blokC[0].lcblknm == blokC[1].lcblknm) )
322  {
323  blokC++; j++;
324  lrownum = blokC->lrownum;
325  }
326  blokC++;
327  }
328  while( !((blokA->frownum >= frownum) &&
329  (blokA->lrownum <= lrownum)) );
330 
331  blokC = fcblk->fblokptr + blok_mn;
332 
333  assert( blokA->lcblknm == blokB->lcblknm );
334  assert( blokB->fcblknm == blokC->lcblknm );
335  assert( blokC->frownum <= blokA->frownum );
336  assert( blokA[-1].fcblknm != blokA[0].fcblknm );
337  assert( blokB[-1].fcblknm != blokB[0].fcblknm );
338  assert( (blok_mn == 0) || (blokC[-1].fcblknm != blokC[0].fcblknm) );
339 
340  /*
341  * Check if it needs to be submitted
342  */
343 #if defined(PASTIX_WITH_MPI)
344  {
345  int need_submit = 0;
346  if ( cblk->ownerid == sopalin_data->solvmtx->clustnum ) {
347  need_submit = 1;
348  }
349  if ( (fcblk->cblktype & CBLK_FANIN) ||
350  (fcblk->ownerid == sopalin_data->solvmtx->clustnum) ) {
351  need_submit = 1;
352  }
353  else {
354  need_exec = 0;
355  }
356  if ( starpu_mpi_cached_receive( blokC->handler[sideA] ) ) {
357  need_submit = 1;
358  }
359  if ( !need_submit ) {
360  return;
361  }
362  }
363 #endif
364 
365  /*
366  * Create the arguments array
367  */
368  if ( need_exec ) {
369  cl_arg = malloc( sizeof( struct cl_blok_sgemmsp_args_s ) );
370  cl_arg->sopalin_data = sopalin_data;
371 #if defined(PASTIX_STARPU_PROFILING)
372  cl_arg->profile_data.measures = blok_sgemmsp_profile.measures;
373  cl_arg->profile_data.flops = NAN;
374 #endif
375  cl_arg->trans = trans;
376  cl_arg->cblk = cblk;
377  cl_arg->fcblk = fcblk;
378  cl_arg->blok_mk = blok_mk;
379  cl_arg->blok_nk = blok_nk;
380  cl_arg->blok_mn = blok_mn;
381 
382 #if defined(PASTIX_WITH_CUDA)
383  if ( (cblk->cblktype & CBLK_COMPRESSED) ||
384  (fcblk->cblktype & CBLK_COMPRESSED) )
385  {
386  /* Disable CUDA */
387  execute_where &= (~STARPU_CUDA);
388  }
389 #endif
390  }
391 
392 #if defined(PASTIX_DEBUG_STARPU) || defined(PASTIX_STARPU_PROFILING_LOG)
393  /* This actually generates a memory leak */
394  asprintf( &task_name, "%s( %ld, %ld, %ld, %ld )",
395  cl_blok_sgemmsp_any.name,
396  (long)(blokA - sopalin_data->solvmtx->bloktab),
397  (long)(blokB - sopalin_data->solvmtx->bloktab),
398  (long)(blokC - sopalin_data->solvmtx->bloktab),
399  (long)sideA );
400 #endif
401 
402  pastix_starpu_insert_task(
403  &cl_blok_sgemmsp_any,
404  STARPU_CL_ARGS, cl_arg, sizeof( struct cl_blok_sgemmsp_args_s ),
405  STARPU_EXECUTE_WHERE, execute_where,
406 #if defined(PASTIX_STARPU_PROFILING)
407  STARPU_CALLBACK_WITH_ARG_NFREE, blok_sgemmsp_callback, cl_arg,
408 #endif
409  STARPU_R, blokA->handler[sideA],
410  STARPU_R, blokB->handler[sideB],
411  STARPU_RW, blokC->handler[sideA],
412 #if defined(PASTIX_DEBUG_STARPU) || defined(PASTIX_STARPU_PROFILING_LOG)
413  STARPU_NAME, task_name,
414 #endif
415 #if defined(PASTIX_STARPU_HETEROPRIO)
416  STARPU_PRIORITY, BucketGEMM2D,
417 #else
418  STARPU_PRIORITY, prio,
419 #endif
420  0);
421  (void)prio;
422 }
423 
424 /**
425  * @}
426  */
BEGIN_C_DECLS typedef int pastix_int_t
Definition: datatypes.h:51
double pastix_fixdbl_t
Definition: datatypes.h:65
@ PastixKernelGEMMBlok2d2d
Definition: kernels_enums.h:67
pastix_fixdbl_t cpublok_sgemmsp(pastix_trans_t transB, const SolverCblk *cblk, SolverCblk *fcblk, pastix_int_t blok_mk, pastix_int_t blok_nk, pastix_int_t blok_mn, const void *A, const void *B, void *C, const pastix_lr_t *lowrank)
Compute the CPU gemm associated to a couple of off-diagonal blocks.
enum pastix_trans_e pastix_trans_t
Transposition.
enum pastix_coefside_e pastix_coefside_t
Data blocks used in the kernel.
static void fct_blok_sgemmsp_cpu(void *descr[], void *cl_arg)
StarPU CPU implementation.
static pastix_fixdbl_t fct_blok_sgemmsp_cost(struct starpu_task *task, struct starpu_perfmodel_arch *arch, unsigned nimpl)
Cost model function.
void starpu_task_blok_sgemmsp(sopalin_data_t *sopalin_data, pastix_coefside_t sideA, pastix_coefside_t sideB, pastix_trans_t trans, SolverCblk *cblk, SolverCblk *fcblk, const SolverBlok *blokA, const SolverBlok *blokB, int prio)
Submit the StarPU task of the block-wise sparse GEMM update.
Base structure to all codelet arguments that include the profiling data.
static double cost(symbol_cblk_t *cblk)
Computes the cost of a cblk.
static pastix_int_t blok_rownbr_ext(const SolverBlok *blok)
Compute the number of rows of a contiguous block in front of the same cblk.
Definition: solver.h:407
pastix_int_t lrownum
Definition: solver.h:148
void * handler[2]
Definition: solver.h:142
static pastix_int_t cblk_colnbr(const SolverCblk *cblk)
Compute the number of columns in a column block.
Definition: solver.h:329
pastix_int_t fcblknm
Definition: solver.h:144
pastix_int_t frownum
Definition: solver.h:147
SolverBlok * fblokptr
Definition: solver.h:168
pastix_int_t lcblknm
Definition: solver.h:143
int8_t cblktype
Definition: solver.h:164
int ownerid
Definition: solver.h:181
Solver block structure.
Definition: solver.h:141
Solver column block structure.
Definition: solver.h:161