PaStiX Handbook  6.3.2
core_dgemdm.c
Go to the documentation of this file.
1 /**
2  *
3  * @file core_dgemdm.c
4  *
5  * PaStiX kernel routines.
6  *
7  * @copyright 2010-2015 Univ. of Tennessee, Univ. of California Berkeley and
8  * Univ. of Colorado Denver. All rights reserved.
9  * @copyright 2015-2023 Bordeaux INP, CNRS (LaBRI UMR 5800), Inria,
10  * Univ. Bordeaux. All rights reserved.
11  *
12  * @version 6.3.2
13  * @author Dulceneia Becker
14  * @author Mathieu Faverge
15  * @author Gregoire Pichon
16  * @author Xavier Lacoste
17  * @date 2023-07-21
18  * @generated from /builds/solverstack/pastix/kernels/core_zgemdm.c, normal z -> d, Wed Dec 13 12:09:13 2023
19  *
20  **/
21 #include "common.h"
22 #include "cblas.h"
23 #include "lapacke.h"
24 
25 /**
26  ******************************************************************************
27  *
28  * @brief Perform one of the following matrix-matrix operations
29  *
30  * C := alpha*op( A )*D*op( B ) + beta*C,
31  *
32  * where op( X ) is one of
33  *
34  * op( X ) = X or op( X ) = X',
35  *
36  * alpha and beta are scalars, and A, B, C and D are matrices, with
37  *
38  * op( A ) an m by k matrix,
39  * op( B ) an k by n matrix,
40  * C an m by n matrix and
41  * D an k by k matrix.
42  *
43  *******************************************************************************
44  *
45  * @param[in] transA
46  * @arg PastixNoTrans : No transpose, op( A ) = A;
47  * @arg PastixTrans : Transpose, op( A ) = A'.
48  *
49  * @param[in] transB
50  * @arg PastixNoTrans : No transpose, op( B ) = B;
51  * @arg PastixTrans : Transpose, op( B ) = B'.
52  *
53  * @param[in] M
54  * The number of rows of the matrix op( A ) and of the
55  * matrix C. M must be at least zero.
56  *
57  * @param[in] N
58  * The number of columns of the matrix op( B ) and the
59  * number of columns of the matrix C. N must be at least zero.
60  *
61  * @param[in] K
62  * The number of columns of the matrix op( A ), the number of
63  * rows of the matrix op( B ), and the number of rows and columns
64  * of matrix D. K must be at least zero.
65  *
66  * @param[in] alpha
67  * On entry, alpha specifies the scalar alpha.
68  *
69  * @param[in] A
70  * double array of DIMENSION ( LDA, ka ), where ka is
71  * k when transA = PastixTrans, and is m otherwise.
72  * Before entry with transA = PastixTrans, the leading m by k
73  * part of the array A must contain the matrix A, otherwise
74  * the leading k by m part of the array A must contain the
75  * matrix A.
76  *
77  * @param[in] LDA
78  * On entry, LDA specifies the first dimension of A as declared
79  * in the calling (sub) program. When transA = PastixTrans then
80  * LDA must be at least max( 1, m ), otherwise LDA must be at
81  * least max( 1, k ).
82  *
83  * @param[in] B
84  * double array of DIMENSION ( LDB, kb ), where kb is
85  * n when transB = PastixTrans, and is k otherwise.
86  * Before entry with transB = PastixTrans, the leading k by n
87  * part of the array B must contain the matrix B, otherwise
88  * the leading n by k part of the array B must contain the
89  * matrix B.
90  *
91  * @param[in] LDB
92  * On entry, LDB specifies the first dimension of B as declared
93  * in the calling (sub) program. When transB = PastixTrans then
94  * LDB must be at least max( 1, k ), otherwise LDB must be at
95  * least max( 1, n ).
96  *
97  * @param[in] beta
98  * On entry, beta specifies the scalar beta. When beta is
99  * supplied as zero then C need not be set on input.
100  *
101  * @param[in] C
102  * double array of DIMENSION ( LDC, n ).
103  * Before entry, the leading m by n part of the array C must
104  * contain the matrix C, except when beta is zero, in which
105  * case C need not be set on entry.
106  * On exit, the array C is overwritten by the m by n matrix
107  * ( alpha*op( A )*D*op( B ) + beta*C ).
108  *
109  * @param[in] LDC
110  * On entry, LDC specifies the first dimension of C as declared
111  * in the calling (sub) program. LDC must be at least
112  * max( 1, m ).
113  *
114  * @param[in] D
115  * double array of DIMENSION ( LDD, k ).
116  * Before entry, the leading k by k part of the array D
117  * must contain the matrix D.
118  *
119  * @param[in] incD
120  * On entry, LDD specifies the first dimension of D as declared
121  * in the calling (sub) program. LDD must be at least
122  * max( 1, k ).
123  *
124  * @param[in] WORK
125  * double array, dimension (MAX(1,LWORK))
126  *
127  * @param[in] LWORK
128  * The length of WORK.
129  * On entry, if transA = PastixTrans and transB = PastixTrans then
130  * LWORK >= max(1, K*N). Otherwise LWORK >= max(1, M*K).
131  *
132  *******************************************************************************
133  *
134  * @retval PASTIX_SUCCESS successful exit
135  * @retval <0 if -i, the i-th argument had an illegal value
136  *
137  ******************************************************************************/
138 int
140  pastix_trans_t transB,
141  int M,
142  int N,
143  int K,
144  double alpha,
145  const double *A,
146  int LDA,
147  const double *B,
148  int LDB,
149  double beta,
150  double *C,
151  int LDC,
152  const double *D,
153  int incD,
154  double *WORK,
155  int LWORK )
156 {
157  int j, Am, Bm, ret;
158  double delta;
159  double *wD2, *w;
160  const double *wD;
161 
162  Am = (transA == PastixNoTrans ) ? M : K;
163  Bm = (transB == PastixNoTrans ) ? K : N;
164 
165  /* Check input arguments */
166  if ((transA < PastixNoTrans) || (transA > PastixTrans)) {
167  return -1;
168  }
169  if ((transB < PastixNoTrans) || (transB > PastixTrans)) {
170  return -2;
171  }
172  if (M < 0) {
173  return -3;
174  }
175  if (N < 0) {
176  return -4;
177  }
178  if (K < 0) {
179  return -5;
180  }
181  if ((LDA < pastix_imax(1,Am)) && (Am > 0)) {
182  return -8;
183  }
184  if ((LDB < pastix_imax(1,Bm)) && (Bm > 0)) {
185  return -10;
186  }
187  if ((LDC < pastix_imax(1,M)) && (M > 0)) {
188  return -13;
189  }
190  if ( incD < 0 ) {
191  return -15;
192  }
193  if ( ( ( transA == PastixNoTrans ) && ( LWORK < (M+1)*K) ) ||
194  ( ( transA != PastixNoTrans ) && ( LWORK < (N+1)*K) ) ){
195  pastix_print_error( "CORE_gemdm: Illegal value of LWORK\n" );
196  if (transA == PastixNoTrans ) {
197  pastix_print_error( "LWORK %d < (M=%d+1)*K=%d ", LWORK, M, K );
198  }
199  if (transA == PastixNoTrans ) {
200  pastix_print_error( "LWORK %d < (N=%d+1)*K=%d ", LWORK, N, K );
201  }
202  return -17;
203  }
204 
205  /* Quick return */
206  if (M == 0 || N == 0 ||
207  ((alpha == 0.0 || K == 0) && beta == 1.0) ) {
208  return PASTIX_SUCCESS;
209  }
210 
211  if ( incD == 1 ) {
212  wD = D;
213  } else {
214  wD2 = WORK;
215  cblas_dcopy(K, D, incD, wD2, 1);
216  wD = wD2;
217  }
218  w = WORK + K;
219 
220  /*
221  * transA == PastixNoTrans
222  */
223  if ( transA == PastixNoTrans )
224  {
225  /* WORK = A * D */
226  for (j=0; j<K; j++, wD++) {
227  delta = *wD;
228  cblas_dcopy(M, &A[LDA*j], 1, &w[M*j], 1);
229  cblas_dscal(M, (delta), &w[M*j], 1);
230  }
231 
232  /* C = alpha * WORK * op(B) + beta * C */
233  cblas_dgemm(CblasColMajor, CblasNoTrans, (CBLAS_TRANSPOSE)transB,
234  M, N, K,
235  (alpha), w, M,
236  B, LDB,
237  (beta), C, LDC);
238  }
239  else
240  {
241  if ( transB == PastixNoTrans ) /* Worst case*/
242  {
243  /* WORK = (D * B)' */
244  for (j=0; j<K; j++, wD++) {
245  delta = *wD;
246  cblas_dcopy(N, &B[j], LDB, &w[N*j], 1);
247  cblas_dscal(N, (delta), &w[N*j], 1);
248  }
249 
250  /* C = alpha * op(A) * WORK' + beta * C */
251  cblas_dgemm(CblasColMajor, (CBLAS_TRANSPOSE)transA, CblasTrans,
252  M, N, K,
253  (alpha), A, LDA,
254  w, N,
255  (beta), C, LDC);
256  }
257  else
258  {
259 #if defined(PRECISION_z) || defined(PRECISION_c)
260  if ( transB == PastixTrans )
261  {
262  /* WORK = D * B' */
263  for (j=0; j<K; j++, wD++) {
264  delta = *wD;
265  cblas_dcopy(N, &B[LDB*j], 1, &w[N*j], 1);
266  ret = LAPACKE_dlacgv_work(N, &w[N*j], 1);
267  assert( ret == 0 );
268  cblas_dscal(N, (delta), &w[N*j], 1);
269  }
270  }
271  else
272 #endif
273  {
274  /* WORK = D * B' */
275  for (j=0; j<K; j++, wD++) {
276  delta = *wD;
277  cblas_dcopy(N, &B[LDB*j], 1, &w[N*j], 1);
278  cblas_dscal(N, (delta), &w[N*j], 1);
279  }
280  }
281 
282  /* C = alpha * op(A) * WORK + beta * C */
283  cblas_dgemm(CblasColMajor, (CBLAS_TRANSPOSE)transA, CblasNoTrans,
284  M, N, K,
285  (alpha), A, LDA,
286  w, N,
287  (beta), C, LDC);
288  }
289  }
290  (void)ret;
291  return PASTIX_SUCCESS;
292 }
int core_dgemdm(pastix_trans_t transA, pastix_trans_t transB, int M, int N, int K, double alpha, const double *A, int LDA, const double *B, int LDB, double beta, double *C, int LDC, const double *D, int incD, double *WORK, int LWORK)
Perform one of the following matrix-matrix operations.
Definition: core_dgemdm.c:139
enum pastix_trans_e pastix_trans_t
Transposition.
@ PastixNoTrans
Definition: api.h:445
@ PastixTrans
Definition: api.h:446
@ PASTIX_SUCCESS
Definition: api.h:367