PaStiX Handbook  6.3.2
pastix_cuda.h
Go to the documentation of this file.
1 /**
2  * @file pastix_cuda.h
3  *
4  * PaStiX GPU kernel header.
5  *
6  * @copyright 2016-2023 Bordeaux INP, CNRS (LaBRI UMR 5800), Inria,
7  * Univ. Bordeaux. All rights reserved.
8  *
9  * @version 6.3.2
10  * @author Mathieu Faverge
11  * @date 2023-07-21
12  *
13  */
14 #ifndef _pastix_cuda_h_
15 #define _pastix_cuda_h_
16 
17 #ifdef __cplusplus
18 extern "C" {
19 #endif
20 
21 #define MAX_BATCH_COUNT 16 /* Size of the p[] array in struct gemm_params_s, i.e. the most GEMM descriptors one params object can hold (batchCount presumably must not exceed it -- TODO confirm) */
22 
23 typedef struct gemm_param_s{ /* Descriptor of one GEMM entry in a batch; untyped pointers so the s/d/c/z vbatched variants can share this struct */
24  const void *Aptr; /* A operand of this entry, read-only through this pointer; presumably device memory of the variant's element type */
25  void *Cptr; /* C operand of this entry (non-const, presumably the output); presumably device memory */
26  pastix_int_t M; /* presumably the row count of A and C for this entry */
27  pastix_int_t lda; /* presumably the leading dimension of A */
28  pastix_int_t ldc; /* presumably the leading dimension of C */
29 } gemm_param_t;
30 
31 typedef struct gemm_params_s { /* Fixed-size batch of GEMM descriptors, passed by value to the *_vbatched_nt calls. NOTE(review): source line 33, presumably the closing `} gemm_params_t;`, is missing from this listing */
32  gemm_param_t p[MAX_BATCH_COUNT]; /* one descriptor per batched GEMM entry */
34 
35 void /* Double-complex (cuDoubleComplex) instance; the name suggests a variable-size batched GEMM -- TODO confirm semantics against the implementation */
36 pastix_zgemm_vbatched_nt(
37  pastix_trans_t transB, /* transposition mode for B. NOTE(review): source line 38 (a parameter line) is missing from this listing */
39  cuDoubleComplex alpha, /* scalar shared by every batch entry */
40  const cuDoubleComplex * dB, pastix_int_t lddb, /* single read-only B operand shared by every entry, and its leading dimension; presumably device memory */
41  cuDoubleComplex beta, /* scalar shared by every batch entry */
42  pastix_int_t max_m, pastix_int_t batchCount, cudaStream_t stream, /* presumably the largest p[i].M; number of used entries in params; CUDA stream to use */
43  gemm_params_t params ); /* per-entry A/C descriptors, passed by value */
44 
45 void /* Single-complex (cuFloatComplex) instance; the name suggests a variable-size batched GEMM -- TODO confirm semantics against the implementation */
46 pastix_cgemm_vbatched_nt(
47  pastix_trans_t transB, /* transposition mode for B. NOTE(review): source line 48 (a parameter line) is missing from this listing */
49  cuFloatComplex alpha, /* scalar shared by every batch entry */
50  const cuFloatComplex * dB, pastix_int_t lddb, /* single read-only B operand shared by every entry, and its leading dimension; presumably device memory */
51  cuFloatComplex beta, /* scalar shared by every batch entry */
52  pastix_int_t max_m, pastix_int_t batchCount, cudaStream_t stream, /* presumably the largest p[i].M; number of used entries in params; CUDA stream to use */
53  gemm_params_t params ); /* per-entry A/C descriptors, passed by value */
54 
55 void /* Double-precision real instance; the name suggests a variable-size batched GEMM -- TODO confirm semantics against the implementation */
56 pastix_dgemm_vbatched_nt(
57  pastix_trans_t transB, /* transposition mode for B. NOTE(review): source line 58 (a parameter line) is missing from this listing */
59  double alpha, /* scalar shared by every batch entry */
60  const double * dB, pastix_int_t lddb, /* single read-only B operand shared by every entry, and its leading dimension; presumably device memory */
61  double beta, /* scalar shared by every batch entry */
62  pastix_int_t max_m, pastix_int_t batchCount, cudaStream_t stream, /* presumably the largest p[i].M; number of used entries in params; CUDA stream to use */
63  gemm_params_t params ); /* per-entry A/C descriptors, passed by value */
64 
65 void /* Single-precision real instance; the name suggests a variable-size batched GEMM -- TODO confirm semantics against the implementation */
66 pastix_sgemm_vbatched_nt(
67  pastix_trans_t transB, /* transposition mode for B. NOTE(review): source line 68 (a parameter line) is missing from this listing */
69  float alpha, /* scalar shared by every batch entry */
70  const float * dB, pastix_int_t lddb, /* single read-only B operand shared by every entry, and its leading dimension; presumably device memory */
71  float beta, /* scalar shared by every batch entry */
72  pastix_int_t max_m, pastix_int_t batchCount, cudaStream_t stream, /* presumably the largest p[i].M; number of used entries in params; CUDA stream to use */
73  gemm_params_t params ); /* per-entry A/C descriptors, passed by value */
74 
75 
76 void /* Double-complex GEMM taking block tables; presumably a blocked/sparse GEMM kernel tuned for Fermi-class GPUs -- TODO confirm */
77 pastix_fermi_zgemmsp(
78  char TRANSA, char TRANSB, int m , int n , int k , /* presumably BLAS-style transpose flags and GEMM dimensions; note: int here vs pastix_int_t in the vbatched API */
79  cuDoubleComplex alpha, const cuDoubleComplex *gpu_A, int lda, /* alpha scalar; read-only A operand and its leading dimension */
80  const cuDoubleComplex *gpu_B, int ldb, /* read-only B operand and its leading dimension */
81  cuDoubleComplex beta, cuDoubleComplex *gpu_C, int ldc, /* beta scalar; C operand (non-const, presumably the output) and its leading dimension */
82  int blocknbr, const int *blocktab, int fblocknbr, const int *fblocktab, /* presumably entry counts and read-only tables describing the block structure -- TODO document layout */
83  cudaStream_t stream ); /* CUDA stream to use */
84 
85 void /* Single-complex GEMM taking block tables; presumably a blocked/sparse GEMM kernel tuned for Fermi-class GPUs -- TODO confirm */
86 pastix_fermi_cgemmsp(
87  char TRANSA, char TRANSB, int m , int n , int k , /* presumably BLAS-style transpose flags and GEMM dimensions; note: int here vs pastix_int_t in the vbatched API */
88  cuFloatComplex alpha, const cuFloatComplex *gpu_A, int lda, /* alpha scalar; read-only A operand and its leading dimension */
89  const cuFloatComplex *gpu_B, int ldb, /* read-only B operand and its leading dimension */
90  cuFloatComplex beta, cuFloatComplex *gpu_C, int ldc, /* beta scalar; C operand (non-const, presumably the output) and its leading dimension */
91  int blocknbr, const int *blocktab, int fblocknbr, const int *fblocktab, /* presumably entry counts and read-only tables describing the block structure -- TODO document layout */
92  cudaStream_t stream ); /* CUDA stream to use */
93 
94 void /* Double-precision real GEMM taking block tables; presumably a blocked/sparse GEMM kernel tuned for Fermi-class GPUs -- TODO confirm */
95 pastix_fermi_dgemmsp(
96  char TRANSA, char TRANSB, int m , int n , int k , /* presumably BLAS-style transpose flags and GEMM dimensions; note: int here vs pastix_int_t in the vbatched API */
97  double alpha, const double *gpu_A, int lda, /* alpha scalar; read-only A operand and its leading dimension */
98  const double *gpu_B, int ldb, /* read-only B operand and its leading dimension */
99  double beta, double *gpu_C, int ldc, /* beta scalar; C operand (non-const, presumably the output) and its leading dimension */
100  int blocknbr, const int *blocktab, int fblocknbr, const int *fblocktab, /* presumably entry counts and read-only tables describing the block structure -- TODO document layout */
101  cudaStream_t stream ); /* CUDA stream to use */
102 
103 void /* Single-precision real GEMM taking block tables; presumably a blocked/sparse GEMM kernel tuned for Fermi-class GPUs -- TODO confirm */
104 pastix_fermi_sgemmsp(
105  char TRANSA, char TRANSB, int m , int n , int k , /* presumably BLAS-style transpose flags and GEMM dimensions; note: int here vs pastix_int_t in the vbatched API */
106  float alpha, const float *gpu_A, int lda, /* alpha scalar; read-only A operand and its leading dimension */
107  const float *gpu_B, int ldb, /* read-only B operand and its leading dimension */
108  float beta, float *gpu_C, int ldc, /* beta scalar; C operand (non-const, presumably the output) and its leading dimension */
109  int blocknbr, const int *blocktab, int fblocknbr, const int *fblocktab, /* presumably entry counts and read-only tables describing the block structure -- TODO document layout */
110  cudaStream_t stream ); /* CUDA stream to use */
111 
112 #ifdef __cplusplus
113 }
114 #endif
115 
116 
117 #endif /* _pastix_cuda_h_ */
BEGIN_C_DECLS typedef int pastix_int_t
Definition: datatypes.h:51
enum pastix_trans_e pastix_trans_t
Transposition.