PaStiX Handbook  6.2.1
core_dgemdm.c
Go to the documentation of this file.
1 /**
2  *
3  * @file core_dgemdm.c
4  *
5  * PaStiX kernel routines.
6  *
7  * @copyright 2010-2015 Univ. of Tennessee, Univ. of California Berkeley and
8  * Univ. of Colorado Denver. All rights reserved.
9  * @copyright 2015-2021 Bordeaux INP, CNRS (LaBRI UMR 5800), Inria,
10  * Univ. Bordeaux. All rights reserved.
11  *
12  * @version 6.1.0
13  * @author Dulceneia Becker
14  * @author Mathieu Faverge
15  * @author Gregoire Pichon
16  * @author Xavier Lacoste
17  * @date 2019-11-12
18  * @generated from /builds/solverstack/pastix/kernels/core_zgemdm.c, normal z -> d, Tue Apr 12 09:38:34 2022
19  *
20  **/
21 #include "common.h"
22 #include "cblas.h"
23 #include "lapacke.h"
24 
25 /**
26  ******************************************************************************
27  *
28  * @brief Perform one of the following matrix-matrix operations
29  *
30  * C := alpha*op( A )*D*op( B ) + beta*C,
31  *
32  * where op( X ) is one of
33  *
34  * op( X ) = X or op( X ) = X',
35  *
36  * alpha and beta are scalars, and A, B, C and D are matrices, with
37  *
38  * op( A ) an m by k matrix,
39  * op( B ) an k by n matrix,
40  * C an m by n matrix and
41  * D an k by k matrix.
42  *
43  *******************************************************************************
44  *
45  * @param[in] transA
46  * @arg PastixNoTrans : No transpose, op( A ) = A;
47  * @arg PastixTrans : Transpose, op( A ) = A'.
48  *
49  * @param[in] transB
50  * @arg PastixNoTrans : No transpose, op( B ) = B;
51  * @arg PastixTrans : Transpose, op( B ) = B'.
52  *
53  * @param[in] M
54  * The number of rows of the matrix op( A ) and of the
55  * matrix C. M must be at least zero.
56  *
57  * @param[in] N
58  * The number of columns of the matrix op( B ) and the
59  * number of columns of the matrix C. N must be at least zero.
60  *
61  * @param[in] K
62  * The number of columns of the matrix op( A ), the number of
63  * rows of the matrix op( B ), and the number of rows and columns
64  * of matrix D. K must be at least zero.
65  *
66  * @param[in] alpha
67  * On entry, alpha specifies the scalar alpha.
68  *
69  * @param[in] A
70  * double array of DIMENSION ( LDA, ka ), where ka is
71  * k when transA = PastixTrans, and is m otherwise.
72  * Before entry with transA = PastixTrans, the leading m by k
73  * part of the array A must contain the matrix A, otherwise
74  * the leading k by m part of the array A must contain the
75  * matrix A.
76  *
77  * @param[in] LDA
78  * On entry, LDA specifies the first dimension of A as declared
79  * in the calling (sub) program. When transA = PastixTrans then
80  * LDA must be at least max( 1, m ), otherwise LDA must be at
81  * least max( 1, k ).
82  *
83  * @param[in] B
84  * double array of DIMENSION ( LDB, kb ), where kb is
85  * n when transB = PastixTrans, and is k otherwise.
86  * Before entry with transB = PastixTrans, the leading k by n
87  * part of the array B must contain the matrix B, otherwise
88  * the leading n by k part of the array B must contain the
89  * matrix B.
90  *
91  * @param[in] LDB
92  * On entry, LDB specifies the first dimension of B as declared
93  * in the calling (sub) program. When transB = PastixTrans then
94  * LDB must be at least max( 1, k ), otherwise LDB must be at
95  * least max( 1, n ).
96  *
97  * @param[in] beta
98  * On entry, beta specifies the scalar beta. When beta is
99  * supplied as zero then C need not be set on input.
100  *
101  * @param[in] C
102  * double array of DIMENSION ( LDC, n ).
103  * Before entry, the leading m by n part of the array C must
104  * contain the matrix C, except when beta is zero, in which
105  * case C need not be set on entry.
106  * On exit, the array C is overwritten by the m by n matrix
107  * ( alpha*op( A )*D*op( B ) + beta*C ).
108  *
109  * @param[in] LDC
110  * On entry, LDC specifies the first dimension of C as declared
111  * in the calling (sub) program. LDC must be at least
112  * max( 1, m ).
113  *
114  * @param[in] D
115  * double array of DIMENSION ( LDD, k ).
116  * Before entry, the leading k by k part of the array D
117  * must contain the matrix D.
118  *
119  * @param[in] incD
120  * On entry, LDD specifies the first dimension of D as declared
121  * in the calling (sub) program. LDD must be at least
122  * max( 1, k ).
123  *
124  * @param[in] WORK
125  * double array, dimension (MAX(1,LWORK))
126  *
127  * @param[in] LWORK
128  * The length of WORK.
129  * On entry, if transA = PastixTrans and transB = PastixTrans then
130  * LWORK >= max(1, K*N). Otherwise LWORK >= max(1, M*K).
131  *
132  *******************************************************************************
133  *
134  * @retval PASTIX_SUCCESS successful exit
135  * @retval <0 if -i, the i-th argument had an illegal value
136  *
137  ******************************************************************************/
138 int
140  int M, int N, int K,
141  double alpha,
142  const double *A, int LDA,
143  const double *B, int LDB,
144  double beta,
145  double *C, int LDC,
146  const double *D, int incD,
147  double *WORK, int LWORK )
148 {
149  int j, Am, Bm;
150  double delta;
151  double *wD2, *w;
152  const double *wD;
153 
154  Am = (transA == PastixNoTrans ) ? M : K;
155  Bm = (transB == PastixNoTrans ) ? K : N;
156 
157  /* Check input arguments */
158  if ((transA < PastixNoTrans) || (transA > PastixTrans)) {
159  return -1;
160  }
161  if ((transB < PastixNoTrans) || (transB > PastixTrans)) {
162  return -2;
163  }
164  if (M < 0) {
165  return -3;
166  }
167  if (N < 0) {
168  return -4;
169  }
170  if (K < 0) {
171  return -5;
172  }
173  if ((LDA < pastix_imax(1,Am)) && (Am > 0)) {
174  return -8;
175  }
176  if ((LDB < pastix_imax(1,Bm)) && (Bm > 0)) {
177  return -10;
178  }
179  if ((LDC < pastix_imax(1,M)) && (M > 0)) {
180  return -13;
181  }
182  if ( incD < 0 ) {
183  return -15;
184  }
185  if ( ( ( transA == PastixNoTrans ) && ( LWORK < (M+1)*K) ) ||
186  ( ( transA != PastixNoTrans ) && ( LWORK < (N+1)*K) ) ){
187  pastix_print_error( "CORE_gemdm: Illegal value of LWORK\n" );
188  if (transA == PastixNoTrans ) {
189  pastix_print_error( "LWORK %d < (M=%d+1)*K=%d ", LWORK, M, K );
190  }
191  if (transA == PastixNoTrans ) {
192  pastix_print_error( "LWORK %d < (N=%d+1)*K=%d ", LWORK, N, K );
193  }
194  return -17;
195  }
196 
197  /* Quick return */
198  if (M == 0 || N == 0 ||
199  ((alpha == 0.0 || K == 0) && beta == 1.0) ) {
200  return PASTIX_SUCCESS;
201  }
202 
203  if ( incD == 1 ) {
204  wD = D;
205  } else {
206  wD2 = WORK;
207  cblas_dcopy(K, D, incD, wD2, 1);
208  wD = wD2;
209  }
210  w = WORK + K;
211 
212  /*
213  * transA == PastixNoTrans
214  */
215  if ( transA == PastixNoTrans )
216  {
217  /* WORK = A * D */
218  for (j=0; j<K; j++, wD++) {
219  delta = *wD;
220  cblas_dcopy(M, &A[LDA*j], 1, &w[M*j], 1);
221  cblas_dscal(M, (delta), &w[M*j], 1);
222  }
223 
224  /* C = alpha * WORK * op(B) + beta * C */
225  cblas_dgemm(CblasColMajor, CblasNoTrans, (CBLAS_TRANSPOSE)transB,
226  M, N, K,
227  (alpha), w, M,
228  B, LDB,
229  (beta), C, LDC);
230  }
231  else
232  {
233  if ( transB == PastixNoTrans ) /* Worst case*/
234  {
235  /* WORK = (D * B)' */
236  for (j=0; j<K; j++, wD++) {
237  delta = *wD;
238  cblas_dcopy(N, &B[j], LDB, &w[N*j], 1);
239  cblas_dscal(N, (delta), &w[N*j], 1);
240  }
241 
242  /* C = alpha * op(A) * WORK' + beta * C */
243  cblas_dgemm(CblasColMajor, (CBLAS_TRANSPOSE)transA, CblasTrans,
244  M, N, K,
245  (alpha), A, LDA,
246  w, N,
247  (beta), C, LDC);
248  }
249  else
250  {
251 #if defined(PRECISION_z) || defined(PRECISION_c)
252  if ( transB == PastixTrans )
253  {
254  /* WORK = D * B' */
255  for (j=0; j<K; j++, wD++) {
256  delta = *wD;
257  cblas_dcopy(N, &B[LDB*j], 1, &w[N*j], 1);
258  LAPACKE_dlacgv_work(N, &w[N*j], 1);
259  cblas_dscal(N, (delta), &w[N*j], 1);
260  }
261  }
262  else
263 #endif
264  {
265  /* WORK = D * B' */
266  for (j=0; j<K; j++, wD++) {
267  delta = *wD;
268  cblas_dcopy(N, &B[LDB*j], 1, &w[N*j], 1);
269  cblas_dscal(N, (delta), &w[N*j], 1);
270  }
271  }
272 
273  /* C = alpha * op(A) * WORK + beta * C */
274  cblas_dgemm(CblasColMajor, (CBLAS_TRANSPOSE)transA, CblasNoTrans,
275  M, N, K,
276  (alpha), A, LDA,
277  w, N,
278  (beta), C, LDC);
279  }
280  }
281  return PASTIX_SUCCESS;
282 }
PastixTrans
@ PastixTrans
Definition: api.h:425
core_dgemdm
int core_dgemdm(pastix_trans_t transA, pastix_trans_t transB, int M, int N, int K, double alpha, const double *A, int LDA, const double *B, int LDB, double beta, double *C, int LDC, const double *D, int incD, double *WORK, int LWORK)
Perform one of the following matrix-matrix operations.
Definition: core_dgemdm.c:139
pastix_trans_t
enum pastix_trans_e pastix_trans_t
Transposition.
PastixNoTrans
@ PastixNoTrans
Definition: api.h:424
PASTIX_SUCCESS
@ PASTIX_SUCCESS
Definition: api.h:346