| /* |
| * Copyright (C) 2015 The Android Open Source Project |
| * |
| * Licensed under the Apache License, Version 2.0 (the "License"); |
| * you may not use this file except in compliance with the License. |
| * You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| |
| package android.renderscript; |
| |
| import android.annotation.IntDef; |
| import java.lang.annotation.Retention; |
| import java.lang.annotation.RetentionPolicy; |
| |
/**
 * ScriptIntrinsicBLAS exposes BLAS (Basic Linear Algebra Subprograms)
 * routines as a RenderScript intrinsic.
 *
 * @hide
 **/
| public final class ScriptIntrinsicBLAS extends ScriptIntrinsic { |
| private Allocation mLUT; |
| |
| private ScriptIntrinsicBLAS(long id, RenderScript rs) { |
| super(id, rs); |
| } |
| |
| private static final int RsBlas_sdsdot = 1; |
| private static final int RsBlas_dsdot = 2; |
| private static final int RsBlas_sdot = 3; |
| private static final int RsBlas_ddot = 4; |
| private static final int RsBlas_cdotu_sub = 5; |
| private static final int RsBlas_cdotc_sub = 6; |
| private static final int RsBlas_zdotu_sub = 7; |
| private static final int RsBlas_zdotc_sub = 8; |
| private static final int RsBlas_snrm2 = 9; |
| private static final int RsBlas_sasum = 10; |
| private static final int RsBlas_dnrm2 = 11; |
| private static final int RsBlas_dasum = 12; |
| private static final int RsBlas_scnrm2 = 13; |
| private static final int RsBlas_scasum = 14; |
| private static final int RsBlas_dznrm2 = 15; |
| private static final int RsBlas_dzasum = 16; |
| private static final int RsBlas_isamax = 17; |
| private static final int RsBlas_idamax = 18; |
| private static final int RsBlas_icamax = 19; |
| private static final int RsBlas_izamax = 20; |
| private static final int RsBlas_sswap = 21; |
| private static final int RsBlas_scopy = 22; |
| private static final int RsBlas_saxpy = 23; |
| private static final int RsBlas_dswap = 24; |
| private static final int RsBlas_dcopy = 25; |
| private static final int RsBlas_daxpy = 26; |
| private static final int RsBlas_cswap = 27; |
| private static final int RsBlas_ccopy = 28; |
| private static final int RsBlas_caxpy = 29; |
| private static final int RsBlas_zswap = 30; |
| private static final int RsBlas_zcopy = 31; |
| private static final int RsBlas_zaxpy = 32; |
| private static final int RsBlas_srotg = 33; |
| private static final int RsBlas_srotmg = 34; |
| private static final int RsBlas_srot = 35; |
| private static final int RsBlas_srotm = 36; |
| private static final int RsBlas_drotg = 37; |
| private static final int RsBlas_drotmg = 38; |
| private static final int RsBlas_drot = 39; |
| private static final int RsBlas_drotm = 40; |
| private static final int RsBlas_sscal = 41; |
| private static final int RsBlas_dscal = 42; |
| private static final int RsBlas_cscal = 43; |
| private static final int RsBlas_zscal = 44; |
| private static final int RsBlas_csscal = 45; |
| private static final int RsBlas_zdscal = 46; |
| private static final int RsBlas_sgemv = 47; |
| private static final int RsBlas_sgbmv = 48; |
| private static final int RsBlas_strmv = 49; |
| private static final int RsBlas_stbmv = 50; |
| private static final int RsBlas_stpmv = 51; |
| private static final int RsBlas_strsv = 52; |
| private static final int RsBlas_stbsv = 53; |
| private static final int RsBlas_stpsv = 54; |
| private static final int RsBlas_dgemv = 55; |
| private static final int RsBlas_dgbmv = 56; |
| private static final int RsBlas_dtrmv = 57; |
| private static final int RsBlas_dtbmv = 58; |
| private static final int RsBlas_dtpmv = 59; |
| private static final int RsBlas_dtrsv = 60; |
| private static final int RsBlas_dtbsv = 61; |
| private static final int RsBlas_dtpsv = 62; |
| private static final int RsBlas_cgemv = 63; |
| private static final int RsBlas_cgbmv = 64; |
| private static final int RsBlas_ctrmv = 65; |
| private static final int RsBlas_ctbmv = 66; |
| private static final int RsBlas_ctpmv = 67; |
| private static final int RsBlas_ctrsv = 68; |
| private static final int RsBlas_ctbsv = 69; |
| private static final int RsBlas_ctpsv = 70; |
| private static final int RsBlas_zgemv = 71; |
| private static final int RsBlas_zgbmv = 72; |
| private static final int RsBlas_ztrmv = 73; |
| private static final int RsBlas_ztbmv = 74; |
| private static final int RsBlas_ztpmv = 75; |
| private static final int RsBlas_ztrsv = 76; |
| private static final int RsBlas_ztbsv = 77; |
| private static final int RsBlas_ztpsv = 78; |
| private static final int RsBlas_ssymv = 79; |
| private static final int RsBlas_ssbmv = 80; |
| private static final int RsBlas_sspmv = 81; |
| private static final int RsBlas_sger = 82; |
| private static final int RsBlas_ssyr = 83; |
| private static final int RsBlas_sspr = 84; |
| private static final int RsBlas_ssyr2 = 85; |
| private static final int RsBlas_sspr2 = 86; |
| private static final int RsBlas_dsymv = 87; |
| private static final int RsBlas_dsbmv = 88; |
| private static final int RsBlas_dspmv = 89; |
| private static final int RsBlas_dger = 90; |
| private static final int RsBlas_dsyr = 91; |
| private static final int RsBlas_dspr = 92; |
| private static final int RsBlas_dsyr2 = 93; |
| private static final int RsBlas_dspr2 = 94; |
| private static final int RsBlas_chemv = 95; |
| private static final int RsBlas_chbmv = 96; |
| private static final int RsBlas_chpmv = 97; |
| private static final int RsBlas_cgeru = 98; |
| private static final int RsBlas_cgerc = 99; |
| private static final int RsBlas_cher = 100; |
| private static final int RsBlas_chpr = 101; |
| private static final int RsBlas_cher2 = 102; |
| private static final int RsBlas_chpr2 = 103; |
| private static final int RsBlas_zhemv = 104; |
| private static final int RsBlas_zhbmv = 105; |
| private static final int RsBlas_zhpmv = 106; |
| private static final int RsBlas_zgeru = 107; |
| private static final int RsBlas_zgerc = 108; |
| private static final int RsBlas_zher = 109; |
| private static final int RsBlas_zhpr = 110; |
| private static final int RsBlas_zher2 = 111; |
| private static final int RsBlas_zhpr2 = 112; |
| private static final int RsBlas_sgemm = 113; |
| private static final int RsBlas_ssymm = 114; |
| private static final int RsBlas_ssyrk = 115; |
| private static final int RsBlas_ssyr2k = 116; |
| private static final int RsBlas_strmm = 117; |
| private static final int RsBlas_strsm = 118; |
| private static final int RsBlas_dgemm = 119; |
| private static final int RsBlas_dsymm = 120; |
| private static final int RsBlas_dsyrk = 121; |
| private static final int RsBlas_dsyr2k = 122; |
| private static final int RsBlas_dtrmm = 123; |
| private static final int RsBlas_dtrsm = 124; |
| private static final int RsBlas_cgemm = 125; |
| private static final int RsBlas_csymm = 126; |
| private static final int RsBlas_csyrk = 127; |
| private static final int RsBlas_csyr2k = 128; |
| private static final int RsBlas_ctrmm = 129; |
| private static final int RsBlas_ctrsm = 130; |
| private static final int RsBlas_zgemm = 131; |
| private static final int RsBlas_zsymm = 132; |
| private static final int RsBlas_zsyrk = 133; |
| private static final int RsBlas_zsyr2k = 134; |
| private static final int RsBlas_ztrmm = 135; |
| private static final int RsBlas_ztrsm = 136; |
| private static final int RsBlas_chemm = 137; |
| private static final int RsBlas_cherk = 138; |
| private static final int RsBlas_cher2k = 139; |
| private static final int RsBlas_zhemm = 140; |
| private static final int RsBlas_zherk = 141; |
| private static final int RsBlas_zher2k = 142; |
| |
| // BLAS extensions start here |
| private static final int RsBlas_bnnm = 1000; |
| |
| /** |
| */ |
| public static ScriptIntrinsicBLAS create(RenderScript rs) { |
| long id = rs.nScriptIntrinsicCreate(13, Element.U32(rs).getID(rs)); |
| return new ScriptIntrinsicBLAS(id, rs); |
| } |
| |
| @IntDef({NO_TRANSPOSE, TRANSPOSE, CONJ_TRANSPOSE}) |
| @Retention(RetentionPolicy.SOURCE) |
| public @interface Transpose {} |
| |
| @IntDef({UPPER, LOWER}) |
| @Retention(RetentionPolicy.SOURCE) |
| public @interface Uplo {} |
| |
| @IntDef({NON_UNIT, UNIT}) |
| @Retention(RetentionPolicy.SOURCE) |
| public @interface Diag {} |
| |
| @IntDef({LEFT, RIGHT}) |
| @Retention(RetentionPolicy.SOURCE) |
| public @interface Side {} |
| |
| public static final int NO_TRANSPOSE = 111; |
| public static final int TRANSPOSE = 112; |
| public static final int CONJ_TRANSPOSE = 113; |
| |
| public static final int UPPER = 121; |
| public static final int LOWER = 122; |
| |
| public static final int NON_UNIT = 131; |
| public static final int UNIT = 132; |
| |
| public static final int LEFT = 141; |
| public static final int RIGHT = 142; |
| |
| static void validateSide(@Side int Side) { |
| if (Side != LEFT && Side != RIGHT) { |
| throw new RSRuntimeException("Invalid side passed to BLAS"); |
| } |
| } |
| |
| static void validateTranspose(@Transpose int Trans) { |
| if (Trans != NO_TRANSPOSE && Trans != TRANSPOSE && |
| Trans != CONJ_TRANSPOSE) { |
| throw new RSRuntimeException("Invalid transpose passed to BLAS"); |
| } |
| } |
| |
| static void validateConjTranspose(@Transpose int Trans) { |
| if (Trans != NO_TRANSPOSE && |
| Trans != CONJ_TRANSPOSE) { |
| throw new RSRuntimeException("Invalid transpose passed to BLAS"); |
| } |
| } |
| |
| static void validateDiag(@Diag int Diag) { |
| if (Diag != NON_UNIT && Diag != UNIT) { |
| throw new RSRuntimeException("Invalid diag passed to BLAS"); |
| } |
| } |
| |
| static void validateUplo(@Uplo int Uplo) { |
| if (Uplo != LEFT && Uplo != RIGHT) { |
| throw new RSRuntimeException("Invalid uplo passed to BLAS"); |
| } |
| } |
| |
| |
| /** |
| * Level 2 BLAS |
| */ |
| |
| static void validateGEMV(Element e, int TransA, Allocation A, Allocation X, int incX, Allocation Y, int incY) { |
| validateTranspose(TransA); |
| int M = A.getType().getY(); |
| int N = A.getType().getX(); |
| if (!A.getType().getElement().isCompatible(e) || |
| !X.getType().getElement().isCompatible(e) || |
| !Y.getType().getElement().isCompatible(e)) { |
| throw new RSRuntimeException("Called BLAS with wrong Element type"); |
| } |
| if (X.getType().getY() > 1 || Y.getType().getY() > 1) { |
| throw new RSRuntimeException("BLAS vectors must have Y dimension of 0 or 1"); |
| } |
| |
| if (incX <= 0 || incY <= 0) { |
| throw new RSRuntimeException("Vector increments must be greater than 0"); |
| } |
| int expectedXDim = -1, expectedYDim = -1; |
| if (TransA == NO_TRANSPOSE) { |
| expectedXDim = 1 + (N - 1) * incX; |
| expectedYDim = 1 + (M - 1) * incY; |
| } else { |
| expectedXDim = 1 + (M - 1) * incX; |
| expectedYDim = 1 + (N - 1) * incY; |
| } |
| if (X.getType().getX() != expectedXDim || |
| Y.getType().getY() != expectedXDim) { |
| throw new RSRuntimeException("Incorrect vector dimensions for GEMV"); |
| } |
| } |
| void SGEMV(@Transpose int TransA, float alpha, Allocation A, Allocation X, int incX, float beta, Allocation Y, int incY) { |
| validateGEMV(Element.F32(mRS), TransA, A, X, incX, Y, incY); |
| int M = A.getType().getY(); |
| int N = A.getType().getX(); |
| mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_sgemv, TransA, 0, 0, 0, 0, M, N, 0, alpha, A.getID(mRS), X.getID(mRS), beta, Y.getID(mRS), incX, incY, 0, 0); |
| } |
| void DGEMV(@Transpose int TransA, double alpha, Allocation A, Allocation X, int incX, double beta, Allocation Y, int incY) { |
| validateGEMV(Element.F64(mRS), TransA, A, X, incX, Y, incY); |
| int M = A.getType().getY(); |
| int N = A.getType().getX(); |
| mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dgemv, TransA, 0, 0, 0, 0, M, N, 0, alpha, A.getID(mRS), X.getID(mRS), beta, Y.getID(mRS), incX, incY, 0, 0); |
| } |
| void CGEMV(@Transpose int TransA, Float2 alpha, Allocation A, Allocation X, int incX, Float2 beta, Allocation Y, int incY) { |
| validateGEMV(Element.F32_2(mRS), TransA, A, X, incX, Y, incY); |
| int M = A.getType().getY(); |
| int N = A.getType().getX(); |
| mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_cgemv, TransA, 0, 0, 0, 0, M, N, 0, alpha.x, alpha.y, A.getID(mRS), X.getID(mRS), beta.x, beta.y, Y.getID(mRS), incX, incY, 0, 0); |
| } |
| void ZGEMV(@Transpose int TransA, Double2 alpha, Allocation A, Allocation X, int incX, Double2 beta, Allocation Y, int incY) { |
| validateGEMV(Element.F64_2(mRS), TransA, A, X, incX, Y, incY); |
| int M = A.getType().getY(); |
| int N = A.getType().getX(); |
| mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zgemv, TransA, 0, 0, 0, 0, M, N, 0, alpha.x, alpha.y, A.getID(mRS), X.getID(mRS), beta.x, beta.y, Y.getID(mRS), incX, incY, 0, 0); |
| } |
| |
| void SGBMV(@Transpose int TransA, int KL, int KU, float alpha, Allocation A, Allocation X, int incX, float beta, Allocation Y, int incY) { |
| // GBMV has the same validation requirements as GEMV + KL and KU >= 0 |
| validateGEMV(Element.F32(mRS), TransA, A, X, incX, Y, incY); |
| if (KL < 0 || KU < 0) { |
| throw new RSRuntimeException("KL and KU must be greater than or equal to 0"); |
| } |
| int M = A.getType().getY(); |
| int N = A.getType().getX(); |
| mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_sgbmv, TransA, 0, 0, 0, 0, M, N, 0, alpha, A.getID(mRS), X.getID(mRS), beta, Y.getID(mRS), incX, incY, KL, KU); |
| } |
| void DGBMV(@Transpose int TransA, int KL, int KU, double alpha, Allocation A, Allocation X, int incX, double beta, Allocation Y, int incY) { |
| // GBMV has the same validation requirements as GEMV + KL and KU >= 0 |
| validateGEMV(Element.F64(mRS), TransA, A, X, incX, Y, incY); |
| if (KL < 0 || KU < 0) { |
| throw new RSRuntimeException("KL and KU must be greater than or equal to 0"); |
| } |
| int M = A.getType().getY(); |
| int N = A.getType().getX(); |
| mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dgbmv, TransA, 0, 0, 0, 0, M, N, 0, alpha, A.getID(mRS), X.getID(mRS), beta, Y.getID(mRS), incX, incY, KL, KU); |
| } |
| void CGBMV(@Transpose int TransA, int KL, int KU, Float2 alpha, Allocation A, Allocation X, int incX, Float2 beta, Allocation Y, int incY) { |
| // GBMV has the same validation requirements as GEMV + KL and KU >= 0 |
| validateGEMV(Element.F32_2(mRS), TransA, A, X, incX, Y, incY); |
| if (KL < 0 || KU < 0) { |
| throw new RSRuntimeException("KL and KU must be greater than or equal to 0"); |
| } |
| int M = A.getType().getY(); |
| int N = A.getType().getX(); |
| mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_cgbmv, TransA, 0, 0, 0, 0, M, N, 0, alpha.x, alpha.y, A.getID(mRS), X.getID(mRS), beta.x, beta.y, Y.getID(mRS), incX, incY, KL, KU); |
| } |
| void ZGBMV(@Transpose int TransA, int KL, int KU, Double2 alpha, Allocation A, Allocation X, int incX, Double2 beta, Allocation Y, int incY) { |
| // GBMV has the same validation requirements as GEMV + KL and KU >= 0 |
| validateGEMV(Element.F64_2(mRS), TransA, A, X, incX, Y, incY); |
| if (KL < 0 || KU < 0) { |
| throw new RSRuntimeException("KL and KU must be greater than or equal to 0"); |
| } |
| int M = A.getType().getY(); |
| int N = A.getType().getX(); |
| mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zgbmv, TransA, 0, 0, 0, 0, M, N, 0, alpha.x, alpha.y, A.getID(mRS), X.getID(mRS), beta.x, beta.y, Y.getID(mRS), incX, incY, KL, KU); |
| } |
| |
| static void validateTRMV(Element e, @Transpose int TransA, Allocation A, Allocation X, int incX) { |
| validateTranspose(TransA); |
| int N = A.getType().getY(); |
| if (A.getType().getX() != N) { |
| throw new RSRuntimeException("A must be a square matrix for TRMV"); |
| } |
| if (!A.getType().getElement().isCompatible(e) || |
| !X.getType().getElement().isCompatible(e)) { |
| throw new RSRuntimeException("Called BLAS with wrong Element type"); |
| } |
| if (X.getType().getY() > 1) { |
| throw new RSRuntimeException("BLAS vectors must have Y dimension of 0 or 1"); |
| } |
| |
| if (incX <= 0) { |
| throw new RSRuntimeException("Vector increments must be greater than 0"); |
| } |
| int expectedXDim = 1 + (N - 1) * incX; |
| if (X.getType().getX() != expectedXDim) { |
| throw new RSRuntimeException("Incorrect vector dimensions for TRMV"); |
| } |
| } |
| |
| static int validateTPMV(Element e, @Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation Ap, Allocation X, int incX) { |
| validateTranspose(TransA); |
| validateUplo(Uplo); |
| validateDiag(Diag); |
| if (!Ap.getType().getElement().isCompatible(e) || |
| !X.getType().getElement().isCompatible(e)) { |
| throw new RSRuntimeException("Called BLAS with wrong Element type"); |
| } |
| if (X.getType().getY() > 1) { |
| throw new RSRuntimeException("BLAS vectors must have Y dimension of 0 or 1"); |
| } |
| |
| if (Ap.getType().getY() > 1) { |
| throw new RSRuntimeException("Ap must have a Y dimension of 0 or 1"); |
| } |
| |
| int N = (int)Math.sqrt((double)Ap.getType().getX() * 2); |
| if (Ap.getType().getX() != ((N * (N+1)) / 2)) { |
| throw new RSRuntimeException("Invalid dimension for Ap"); |
| } |
| |
| int expectedXDim = 1 + (N - 1) * incX; |
| if (X.getType().getX() != expectedXDim) { |
| throw new RSRuntimeException("Incorrect vector dimensions for SYMV"); |
| } |
| |
| return N; |
| } |
| |
| void STRMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation A, Allocation X, int incX) { |
| validateTRMV(Element.F32(mRS), TransA, A, X, incX); |
| int N = A.getType().getY(); |
| mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_strmv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0); |
| } |
| void DTRMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation A, Allocation X, int incX) { |
| validateTRMV(Element.F64(mRS), TransA, A, X, incX); |
| int N = A.getType().getY(); |
| mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dtrmv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0); |
| } |
| void CTRMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation A, Allocation X, int incX) { |
| validateTRMV(Element.F32_2(mRS), TransA, A, X, incX); |
| int N = A.getType().getY(); |
| mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_ctrmv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0); |
| } |
| void ZTRMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation A, Allocation X, int incX) { |
| validateTRMV(Element.F64_2(mRS), TransA, A, X, incX); |
| int N = A.getType().getY(); |
| mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_ztrmv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0); |
| } |
| void STBMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, int K, Allocation A, Allocation X, int incX) { |
| // TBMV has the same requirements as TRMV |
| validateTRMV(Element.F32(mRS), TransA, A, X, incX); |
| int N = A.getType().getY(); |
| mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_stbmv, TransA, 0, 0, Uplo, Diag, 0, N, K, 0, A.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0); |
| } |
| void DTBMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, int K, Allocation A, Allocation X, int incX) { |
| // TBMV has the same requirements as TRMV |
| validateTRMV(Element.F64(mRS), TransA, A, X, incX); |
| int N = A.getType().getY(); |
| mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dtbmv, TransA, 0, 0, Uplo, Diag, 0, N, K, 0, A.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0); |
| } |
| void CTBMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, int K, Allocation A, Allocation X, int incX) { |
| // TBMV has the same requirements as TRMV |
| validateTRMV(Element.F32_2(mRS), TransA, A, X, incX); |
| int N = A.getType().getY(); |
| mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_ctbmv, TransA, 0, 0, Uplo, Diag, 0, N, K, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0); |
| } |
| void ZTBMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, int K, Allocation A, Allocation X, int incX) { |
| // TBMV has the same requirements as TRMV |
| validateTRMV(Element.F64_2(mRS), TransA, A, X, incX); |
| int N = A.getType().getY(); |
| mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_ztbmv, TransA, 0, 0, Uplo, Diag, 0, N, K, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0); |
| } |
| void STPMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation Ap, Allocation X, int incX) { |
| int N = validateTPMV(Element.F32(mRS), Uplo, TransA, Diag, Ap, X, incX); |
| mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_stpmv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, Ap.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0); |
| } |
| void DTPMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation Ap, Allocation X, int incX) { |
| int N = validateTPMV(Element.F64(mRS), Uplo, TransA, Diag, Ap, X, incX); |
| mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dtpmv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, Ap.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0); |
| } |
| void CTPMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation Ap, Allocation X, int incX) { |
| int N = validateTPMV(Element.F32_2(mRS), Uplo, TransA, Diag, Ap, X, incX); |
| mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_ctpmv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0, Ap.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0); |
| } |
| void ZTPMV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation Ap, Allocation X, int incX) { |
| int N = validateTPMV(Element.F64_2(mRS), Uplo, TransA, Diag, Ap, X, incX); |
| mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_ztpmv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0, Ap.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0); |
| } |
| void STRSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation A, Allocation X, int incX) { |
| // TRSV is the same as TRMV |
| validateTRMV(Element.F32(mRS), TransA, A, X, incX); |
| int N = A.getType().getY(); |
| mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_strsv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0); |
| |
| } |
| void DTRSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation A, Allocation X, int incX) { |
| // TRSV is the same as TRMV |
| validateTRMV(Element.F64(mRS), TransA, A, X, incX); |
| int N = A.getType().getY(); |
| mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dtrsv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0); |
| |
| } |
| void CTRSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation A, Allocation X, int incX) { |
| // TRSV is the same as TRMV |
| validateTRMV(Element.F32_2(mRS), TransA, A, X, incX); |
| int N = A.getType().getY(); |
| mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_ctrsv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0); |
| |
| } |
| void ZTRSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation A, Allocation X, int incX) { |
| // TRSV is the same as TRMV |
| validateTRMV(Element.F64_2(mRS), TransA, A, X, incX); |
| int N = A.getType().getY(); |
| mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_ztrsv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0); |
| |
| } |
| void STBSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, int K, Allocation A, Allocation X, int incX) { |
| // TBSV is the same as TRMV |
| validateTRMV(Element.F32(mRS), TransA, A, X, incX); |
| int N = A.getType().getY(); |
| if (K < 0) { |
| throw new RSRuntimeException("Number of diagonals must be positive"); |
| } |
| mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_stbsv, TransA, 0, 0, Uplo, Diag, 0, N, K, 0, A.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0); |
| } |
| void DTBSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, int K, Allocation A, Allocation X, int incX) { |
| // TBSV is the same as TRMV |
| validateTRMV(Element.F64(mRS), TransA, A, X, incX); |
| int N = A.getType().getY(); |
| if (K < 0) { |
| throw new RSRuntimeException("Number of diagonals must be positive"); |
| } |
| mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dtbsv, TransA, 0, 0, Uplo, Diag, 0, N, K, 0, A.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0); |
| } |
| void CTBSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, int K, Allocation A, Allocation X, int incX) { |
| // TBSV is the same as TRMV |
| validateTRMV(Element.F32_2(mRS), TransA, A, X, incX); |
| int N = A.getType().getY(); |
| if (K < 0) { |
| throw new RSRuntimeException("Number of diagonals must be positive"); |
| } |
| mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_ctbsv, TransA, 0, 0, Uplo, Diag, 0, N, K, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0); |
| } |
| void ZTBSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, int K, Allocation A, Allocation X, int incX) { |
| // TBSV is the same as TRMV |
| validateTRMV(Element.F64_2(mRS), TransA, A, X, incX); |
| int N = A.getType().getY(); |
| if (K < 0) { |
| throw new RSRuntimeException("Number of diagonals must be positive"); |
| } |
| mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_ztbsv, TransA, 0, 0, Uplo, Diag, 0, N, K, 0, 0, A.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0); |
| } |
| void STPSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation Ap, Allocation X, int incX) { |
| // TPSV is same as TPMV |
| int N = validateTPMV(Element.F32(mRS), Uplo, TransA, Diag, Ap, X, incX); |
| mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_stpsv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, Ap.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0); |
| } |
| void DTPSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation Ap, Allocation X, int incX) { |
| // TPSV is same as TPMV |
| int N = validateTPMV(Element.F64(mRS), Uplo, TransA, Diag, Ap, X, incX); |
| mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dtpsv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, Ap.getID(mRS), X.getID(mRS), 0, 0, incX, 0, 0, 0); |
| } |
| void CTPSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation Ap, Allocation X, int incX) { |
| // TPSV is same as TPMV |
| int N = validateTPMV(Element.F32_2(mRS), Uplo, TransA, Diag, Ap, X, incX); |
| mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_ctpsv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0, Ap.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0); |
| } |
| void ZTPSV(@Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Allocation Ap, Allocation X, int incX) { |
| // TPSV is same as TPMV |
| int N = validateTPMV(Element.F64_2(mRS), Uplo, TransA, Diag, Ap, X, incX); |
| mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_ztpsv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0, Ap.getID(mRS), X.getID(mRS), 0, 0, 0, incX, 0, 0, 0); |
| } |
| |
/**
 * Level 2 BLAS routines that exist only for real single (S) and double (D)
 * precision.
 */
| static int validateSYMV(Element e, @Uplo int Uplo, Allocation A, Allocation X, Allocation Y, int incX, int incY) { |
| validateUplo(Uplo); |
| int N = A.getType().getY(); |
| if (A.getType().getX() != N) { |
| throw new RSRuntimeException("A must be a square matrix for SYMV"); |
| } |
| if (!A.getType().getElement().isCompatible(e) || |
| !X.getType().getElement().isCompatible(e) || |
| !Y.getType().getElement().isCompatible(e) ) { |
| throw new RSRuntimeException("Called BLAS with wrong Element type"); |
| } |
| if (X.getType().getY() > 1 || Y.getType().getY() > 1) { |
| throw new RSRuntimeException("BLAS vectors must have Y dimension of 0 or 1"); |
| } |
| |
| if (incX <= 0 || incY <= 0) { |
| throw new RSRuntimeException("Vector increments must be greater than 0"); |
| } |
| int expectedXDim = 1 + (N - 1) * incX; |
| if (X.getType().getX() != expectedXDim) { |
| throw new RSRuntimeException("Incorrect vector dimensions for SYMV"); |
| } |
| int expectedYDim = 1 + (N - 1) * incY; |
| if (Y.getType().getX() != expectedYDim) { |
| throw new RSRuntimeException("Incorrect vector dimensions for SYMV"); |
| } |
| return N; |
| } |
| static int validateSPMV(Element e, @Uplo int Uplo, Allocation Ap, Allocation X, int incX, Allocation Y, int incY) { |
| validateUplo(Uplo); |
| if (!Ap.getType().getElement().isCompatible(e) || |
| !X.getType().getElement().isCompatible(e) || |
| !Y.getType().getElement().isCompatible(e)) { |
| throw new RSRuntimeException("Called BLAS with wrong Element type"); |
| } |
| if (X.getType().getY() > 1 || Y.getType().getY() > 1) { |
| throw new RSRuntimeException("BLAS vectors must have Y dimension of 0 or 1"); |
| } |
| |
| if (Ap.getType().getY() > 1) { |
| throw new RSRuntimeException("Ap must have a Y dimension of 0 or 1"); |
| } |
| |
| int N = (int)Math.sqrt((double)Ap.getType().getX() * 2); |
| if (Ap.getType().getX() != ((N * (N+1)) / 2)) { |
| throw new RSRuntimeException("Invalid dimension for Ap"); |
| } |
| |
| int expectedXDim = 1 + (N - 1) * incX; |
| if (X.getType().getX() != expectedXDim) { |
| throw new RSRuntimeException("Incorrect vector dimensions for SPMV"); |
| } |
| int expectedYDim = 1 + (N - 1) * incY; |
| if (Y.getType().getX() != expectedYDim) { |
| throw new RSRuntimeException("Incorrect vector dimensions for SPMV"); |
| } |
| |
| return N; |
| } |
| static void validateGER(Element e, Allocation X, int incX, Allocation Y, int incY, Allocation A) { |
| if (!A.getType().getElement().isCompatible(e) || |
| !X.getType().getElement().isCompatible(e) || |
| !Y.getType().getElement().isCompatible(e) ) { |
| throw new RSRuntimeException("Called BLAS with wrong Element type"); |
| } |
| |
| if (X.getType().getY() > 1 || Y.getType().getY() > 1) { |
| throw new RSRuntimeException("BLAS vectors must have Y dimension of 0 or 1"); |
| } |
| |
| int M = A.getType().getY(); |
| int N = A.getType().getX(); |
| |
| if (N < 1 || M < 1) { |
| throw new RSRuntimeException("M and N must be 1 or greater for GER"); |
| } |
| |
| int expectedXDim = 1 + (N - 1) * incX; |
| if (X.getType().getX() != expectedXDim) { |
| throw new RSRuntimeException("Incorrect vector dimensions for GER"); |
| } |
| int expectedYDim = 1 + (N - 1) * incY; |
| if (Y.getType().getX() != expectedYDim) { |
| throw new RSRuntimeException("Incorrect vector dimensions for GER"); |
| } |
| |
| |
| } |
| static int validateSYR(Element e, @Uplo int Uplo, Allocation X, int incX, Allocation A) { |
| validateUplo(Uplo); |
| if (!A.getType().getElement().isCompatible(e) || |
| !X.getType().getElement().isCompatible(e)) { |
| throw new RSRuntimeException("Called BLAS with wrong Element type"); |
| } |
| |
| int N = A.getType().getX(); |
| |
| if (X.getType().getY() > 1) { |
| throw new RSRuntimeException("BLAS vectors must have Y dimension of 0 or 1"); |
| } |
| if (N != A.getType().getY()) { |
| throw new RSRuntimeException("A must be a symmetric matrix"); |
| } |
| |
| int expectedXDim = 1 + (N - 1) * incX; |
| if (X.getType().getX() != expectedXDim) { |
| throw new RSRuntimeException("Incorrect vector dimensions for SYR"); |
| } |
| return N; |
| } |
| static int validateSPR(Element e, @Uplo int Uplo, Allocation X, int incX, Allocation Ap) { |
| validateUplo(Uplo); |
| if (!Ap.getType().getElement().isCompatible(e) || |
| !X.getType().getElement().isCompatible(e)) { |
| throw new RSRuntimeException("Called BLAS with wrong Element type"); |
| } |
| if (X.getType().getY() > 1) { |
| throw new RSRuntimeException("BLAS vectors must have Y dimension of 0 or 1"); |
| } |
| |
| if (Ap.getType().getY() > 1) { |
| throw new RSRuntimeException("Ap must have a Y dimension of 0 or 1"); |
| } |
| |
| int N = (int)Math.sqrt((double)Ap.getType().getX() * 2); |
| if (Ap.getType().getX() != ((N * (N+1)) / 2)) { |
| throw new RSRuntimeException("Invalid dimension for Ap"); |
| } |
| |
| int expectedXDim = 1 + (N - 1) * incX; |
| if (X.getType().getX() != expectedXDim) { |
| throw new RSRuntimeException("Incorrect vector dimensions for SPMV"); |
| } |
| |
| return N; |
| } |
| |
| static int validateSYR2(Element e, @Uplo int Uplo, Allocation X, int incX, Allocation Y, int incY, Allocation A) { |
| validateUplo(Uplo); |
| if (!A.getType().getElement().isCompatible(e) || |
| !X.getType().getElement().isCompatible(e) || |
| !Y.getType().getElement().isCompatible(e)) { |
| throw new RSRuntimeException("Called BLAS with wrong Element type"); |
| } |
| |
| if (X.getType().getY() > 1 || Y.getType().getY() > 1) { |
| throw new RSRuntimeException("BLAS vectors must have Y dimension of 0 or 1"); |
| } |
| |
| int N = A.getType().getX(); |
| |
| if (N != A.getType().getY()) { |
| throw new RSRuntimeException("A must be a symmetric matrix"); |
| } |
| |
| int expectedXDim = 1 + (N - 1) * incX; |
| int expectedYDim = 1 + (N - 1) * incY; |
| if (X.getType().getX() != expectedXDim || Y.getType().getX() != expectedYDim) { |
| throw new RSRuntimeException("Incorrect vector dimensions for SYR"); |
| } |
| return N; |
| |
| } |
| static int validateSPR2(Element e, @Uplo int Uplo, Allocation X, int incX, Allocation Y, int incY, Allocation Ap) { |
| validateUplo(Uplo); |
| if (!Ap.getType().getElement().isCompatible(e) || |
| !X.getType().getElement().isCompatible(e) || |
| !Y.getType().getElement().isCompatible(e)) { |
| throw new RSRuntimeException("Called BLAS with wrong Element type"); |
| } |
| if (X.getType().getY() > 1 || Y.getType().getY() > 1) { |
| throw new RSRuntimeException("BLAS vectors must have Y dimension of 0 or 1"); |
| } |
| |
| if (Ap.getType().getY() > 1) { |
| throw new RSRuntimeException("Ap must have a Y dimension of 0 or 1"); |
| } |
| |
| int N = (int)Math.sqrt((double)Ap.getType().getX() * 2); |
| if (Ap.getType().getX() != ((N * (N+1)) / 2)) { |
| throw new RSRuntimeException("Invalid dimension for Ap"); |
| } |
| |
| int expectedXDim = 1 + (N - 1) * incX; |
| int expectedYDim = 1 + (N - 1) * incY; |
| if (X.getType().getX() != expectedXDim || Y.getType().getX() != expectedYDim) { |
| throw new RSRuntimeException("Incorrect vector dimensions for SPMV"); |
| } |
| |
| return N; |
| } |
| |
| void SSYMV(@Uplo int Uplo, float alpha, Allocation A, Allocation X, int incX, float beta, Allocation Y, int incY) { |
| int N = validateSYMV(Element.F32(mRS), Uplo, A, X, Y, incX, incY); |
| mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_ssymv, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, A.getID(mRS), X.getID(mRS), beta, Y.getID(mRS), incX, incY, 0, 0); |
| } |
| void SSBMV(@Uplo int Uplo, int K, float alpha, Allocation A, Allocation X, int incX, float beta, Allocation Y, int incY) { |
| // SBMV is the same as SYMV |
| int N = validateSYMV(Element.F32(mRS), Uplo, A, X, Y, incX, incY); |
| mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_ssbmv, 0, 0, 0, Uplo, 0, 0, N, K, alpha, A.getID(mRS), X.getID(mRS), beta, Y.getID(mRS), incX, incY, 0, 0); |
| } |
| void SSPMV(@Uplo int Uplo, float alpha, Allocation Ap, Allocation X, int incX, float beta, Allocation Y, int incY) { |
| int N = validateSPMV(Element.F32(mRS), Uplo, Ap, X, incX, Y, incY); |
| mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_sspmv, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, Ap.getID(mRS), X.getID(mRS), beta, Y.getID(mRS), incX, incY, 0, 0); |
| } |
| void SGER(float alpha, Allocation X, int incX, Allocation Y, int incY, Allocation A) { |
| int M = A.getType().getY(); |
| int N = A.getType().getX(); |
| mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_sger, 0, 0, 0, 0, 0, M, N, 0, alpha, X.getID(mRS), Y.getID(mRS), 0.f, A.getID(mRS), incX, incY, 0, 0); |
| } |
| void SSYR(@Uplo int Uplo, float alpha, Allocation X, int incX, Allocation A) { |
| int N = validateSYR(Element.F32(mRS), Uplo, X, incX, A); |
| mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_ssyr, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, X.getID(mRS), A.getID(mRS), 0.f, 0, incX, 0, 0, 0); |
| } |
| void SSPR(@Uplo int Uplo, float alpha, Allocation X, int incX, Allocation Ap) { |
| int N = validateSPR(Element.F32(mRS), Uplo, X, incX, Ap); |
| mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_sspr, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, X.getID(mRS), Ap.getID(mRS), 0.f, 0, incX, 0, 0, 0); |
| } |
| void SSYR2(@Uplo int Uplo, float alpha, Allocation X, int incX, Allocation Y, int incY, Allocation A) { |
| int N = validateSYR2(Element.F32(mRS), Uplo, X, incX, Y, incY, A); |
| mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_ssyr2, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, X.getID(mRS), Y.getID(mRS), 0, A.getID(mRS), incX, incY, 0, 0); |
| } |
| void SSPR2(@Uplo int Uplo, float alpha, Allocation X, int incX, Allocation Y, int incY, Allocation Ap) { |
| int N = validateSPR2(Element.F32(mRS), Uplo, X, incX, Y, incY, Ap); |
| mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_sspr2, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, X.getID(mRS), Y.getID(mRS), 0, Ap.getID(mRS), incX, incY, 0, 0); |
| } |
| void DSYMV(@Uplo int Uplo, double alpha, Allocation A, Allocation X, int incX, double beta, Allocation Y, int incY) { |
| int N = validateSYMV(Element.F64(mRS), Uplo, A, X, Y, incX, incY); |
| mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dsymv, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, A.getID(mRS), X.getID(mRS), beta, Y.getID(mRS), incX, incY, 0, 0); |
| } |
| void DSBMV(@Uplo int Uplo, int K, double alpha, Allocation A, Allocation X, int incX, double beta, Allocation Y, int incY) { |
| // SBMV is the same as SYMV |
| int N = validateSYMV(Element.F64(mRS), Uplo, A, X, Y, incX, incY); |
| mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dsbmv, 0, 0, 0, Uplo, 0, 0, N, K, alpha, A.getID(mRS), X.getID(mRS), beta, Y.getID(mRS), incX, incY, 0, 0); |
| } |
| void DSPMV(@Uplo int Uplo, double alpha, Allocation Ap, Allocation X, int incX, double beta, Allocation Y, int incY) { |
| int N = validateSPMV(Element.F64(mRS), Uplo, Ap, X, incX, Y, incY); |
| mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dspmv, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, Ap.getID(mRS), X.getID(mRS), beta, Y.getID(mRS), incX, incY, 0, 0); |
| } |
| void DGER(double alpha, Allocation X, int incX, Allocation Y, int incY, Allocation A) { |
| int M = A.getType().getY(); |
| int N = A.getType().getX(); |
| mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dger, 0, 0, 0, 0, 0, M, N, 0, alpha, X.getID(mRS), Y.getID(mRS), 0.f, A.getID(mRS), incX, incY, 0, 0); |
| } |
| void DSYR(@Uplo int Uplo, double alpha, Allocation X, int incX, Allocation A) { |
| int N = validateSYR(Element.F64(mRS), Uplo, X, incX, A); |
| mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dsyr, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, X.getID(mRS), A.getID(mRS), 0.f, 0, incX, 0, 0, 0); |
| } |
| void DSPR(@Uplo int Uplo, double alpha, Allocation X, int incX, Allocation Ap) { |
| int N = validateSPR(Element.F64(mRS), Uplo, X, incX, Ap); |
| mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dspr, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, X.getID(mRS), Ap.getID(mRS), 0.f, 0, incX, 0, 0, 0); |
| } |
| void DSYR2(@Uplo int Uplo, double alpha, Allocation X, int incX, Allocation Y, int incY, Allocation A) { |
| int N = validateSYR2(Element.F64(mRS), Uplo, X, incX, Y, incY, A); |
| mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dsyr2, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, X.getID(mRS), Y.getID(mRS), 0, A.getID(mRS), incX, incY, 0, 0); |
| } |
| void DSPR2(@Uplo int Uplo, double alpha, Allocation X, int incX, Allocation Y, int incY, Allocation Ap) { |
| int N = validateSPR2(Element.F64(mRS), Uplo, X, incX, Y, incY, Ap); |
| mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dspr2, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, X.getID(mRS), Y.getID(mRS), 0, Ap.getID(mRS), incX, incY, 0, 0); |
| } |
| |
| |
| /** |
| * Level 2, C and Z only |
| */ |
| |
| static void validateGERU(Element e, Allocation X, int incX, Allocation Y, int incY, Allocation A) { |
| if (!A.getType().getElement().isCompatible(e) || |
| !X.getType().getElement().isCompatible(e) || |
| !Y.getType().getElement().isCompatible(e)) { |
| throw new RSRuntimeException("Called BLAS with wrong Element type"); |
| } |
| if (X.getType().getY() > 1 || Y.getType().getY() > 1) { |
| throw new RSRuntimeException("BLAS vectors must have Y dimension of 0 or 1"); |
| } |
| |
| int M = A.getType().getY(); |
| int N = A.getType().getX(); |
| |
| int expectedXDim = 1 + (N - 1) * incX; |
| if (X.getType().getX() != expectedXDim) { |
| throw new RSRuntimeException("Incorrect vector dimensions for GERU"); |
| } |
| int expectedYDim = 1 + (N - 1) * incY; |
| if (Y.getType().getX() != expectedYDim) { |
| throw new RSRuntimeException("Incorrect vector dimensions for GERU"); |
| } |
| |
| } |
| |
| void CHEMV(@Uplo int Uplo, Float2 alpha, Allocation A, Allocation X, int incX, Float2 beta, Allocation Y, int incY) { |
| // HEMV is the same as SYR2 validation-wise |
| int N = validateSYR2(Element.F32_2(mRS), Uplo, X, incX, Y, incY, A); |
| mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_chemv, 0, 0, 0, Uplo, 0, 0, N, 0, alpha.x, alpha.y, A.getID(mRS), X.getID(mRS), beta.x, beta.y, Y.getID(mRS), incX, incY, 0, 0); |
| } |
| void CHBMV(@Uplo int Uplo, int K, Float2 alpha, Allocation A, Allocation X, int incX, Float2 beta, Allocation Y, int incY) { |
| // HBMV is the same as SYR2 validation-wise |
| int N = validateSYR2(Element.F32_2(mRS), Uplo, X, incX, Y, incY, A); |
| if (K < 0) { |
| throw new RSRuntimeException("K must be 0 or greater for HBMV"); |
| } |
| mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_chbmv, 0, 0, 0, Uplo, 0, 0, N, K, alpha.x, alpha.y, A.getID(mRS), X.getID(mRS), beta.x, beta.y, Y.getID(mRS), incX, incY, 0, 0); |
| } |
| void CHPMV(@Uplo int Uplo, Float2 alpha, Allocation Ap, Allocation X, int incX, Float2 beta, Allocation Y, int incY) { |
| // HPMV is the same as SPR2 |
| int N = validateSPR2(Element.F32_2(mRS), Uplo, X, incX, Y, incY, Ap); |
| mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_chpmv, 0, 0, 0, Uplo, 0, 0, N, 0, alpha.x, alpha.y, Ap.getID(mRS), X.getID(mRS), beta.x, beta.y, Y.getID(mRS), incX, incY, 0, 0); |
| } |
| void CGERU(Float2 alpha, Allocation X, int incX, Allocation Y, int incY, Allocation A) { |
| validateGERU(Element.F32_2(mRS), X, incX, Y, incY, A); |
| int M = A.getType().getY(); |
| int N = A.getType().getX(); |
| mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_cgeru, 0, 0, 0, 0, 0, M, N, 0, alpha.x, alpha.y, X.getID(mRS), Y.getID(mRS), 0, 0, A.getID(mRS), incX, incY, 0, 0); |
| } |
| void CGERC(Float2 alpha, Allocation X, int incX, Allocation Y, int incY, Allocation A) { |
| // same as GERU |
| validateGERU(Element.F32_2(mRS), X, incX, Y, incY, A); |
| int M = A.getType().getY(); |
| int N = A.getType().getX(); |
| mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_cgerc, 0, 0, 0, 0, 0, M, N, 0, alpha.x, alpha.y, X.getID(mRS), Y.getID(mRS), 0, 0, A.getID(mRS), incX, incY, 0, 0); |
| } |
| void CHER(@Uplo int Uplo, float alpha, Allocation X, int incX, Allocation A) { |
| // same as SYR |
| int N = validateSYR(Element.F32(mRS), Uplo, X, incX, A); |
| mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_cher, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, 0, X.getID(mRS), 0, 0, 0, A.getID(mRS), incX, 0, 0, 0); |
| } |
| void CHPR(@Uplo int Uplo, float alpha, Allocation X, int incX, Allocation Ap) { |
| // equivalent to SPR for validation |
| int N = validateSPR(Element.F32_2(mRS), Uplo, X, incX, Ap); |
| mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_chpr, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, 0, X.getID(mRS), 0, 0, 0, Ap.getID(mRS), incX, 0, 0, 0); |
| } |
| void CHER2(@Uplo int Uplo, Float2 alpha, Allocation X, int incX, Allocation Y, int incY, Allocation A) { |
| // same as SYR2 |
| int N = validateSYR2(Element.F32_2(mRS), Uplo, X, incX, Y, incY, A); |
| mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_cher2, 0, 0, 0, Uplo, 0, 0, N, 0, alpha.x, alpha.y, X.getID(mRS), Y.getID(mRS), 0, 0, A.getID(mRS), incX, incY, 0, 0); |
| } |
| void CHPR2(@Uplo int Uplo, Float2 alpha, Allocation X, int incX, Allocation Y, int incY, Allocation Ap) { |
| // same as SPR2 |
| int N = validateSPR2(Element.F32_2(mRS), Uplo, X, incX, Y, incY, Ap); |
| mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_chpr2, 0, 0, 0, Uplo, 0, 0, N, 0, alpha.x, alpha.y, X.getID(mRS), Y.getID(mRS), 0, 0, Ap.getID(mRS), incX, incY, 0, 0); |
| } |
| void ZHEMV(@Uplo int Uplo, Double2 alpha, Allocation A, Allocation X, int incX, Double2 beta, Allocation Y, int incY) { |
| // HEMV is the same as SYR2 validation-wise |
| int N = validateSYR2(Element.F64_2(mRS), Uplo, X, incX, Y, incY, A); |
| mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zhemv, 0, 0, 0, Uplo, 0, 0, N, 0, alpha.x, alpha.y, A.getID(mRS), X.getID(mRS), beta.x, beta.y, Y.getID(mRS), incX, incY, 0, 0); |
| } |
| void ZHBMV(@Uplo int Uplo, int K, Double2 alpha, Allocation A, Allocation X, int incX, Double2 beta, Allocation Y, int incY) { |
| // HBMV is the same as SYR2 validation-wise |
| int N = validateSYR2(Element.F64_2(mRS), Uplo, X, incX, Y, incY, A); |
| if (K < 0) { |
| throw new RSRuntimeException("K must be 0 or greater for HBMV"); |
| } |
| mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zhbmv, 0, 0, 0, Uplo, 0, 0, N, K, alpha.x, alpha.y, A.getID(mRS), X.getID(mRS), beta.x, beta.y, Y.getID(mRS), incX, incY, 0, 0); |
| } |
| void ZHPMV(@Uplo int Uplo, Double2 alpha, Allocation Ap, Allocation X, int incX, Double2 beta, Allocation Y, int incY) { |
| // HPMV is the same as SPR2 |
| int N = validateSPR2(Element.F64_2(mRS), Uplo, X, incX, Y, incY, Ap); |
| mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zhpmv, 0, 0, 0, Uplo, 0, 0, N, 0, alpha.x, alpha.y, Ap.getID(mRS), X.getID(mRS), beta.x, beta.y, Y.getID(mRS), incX, incY, 0, 0); |
| } |
| void ZGERU(Double2 alpha, Allocation X, int incX, Allocation Y, int incY, Allocation A) { |
| validateGERU(Element.F64_2(mRS), X, incX, Y, incY, A); |
| int M = A.getType().getY(); |
| int N = A.getType().getX(); |
| mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zgeru, 0, 0, 0, 0, 0, M, N, 0, alpha.x, alpha.y, X.getID(mRS), Y.getID(mRS), 0, 0, A.getID(mRS), incX, incY, 0, 0); |
| } |
| void ZGERC(Double2 alpha, Allocation X, int incX, Allocation Y, int incY, Allocation A) { |
| // same as GERU |
| validateGERU(Element.F64_2(mRS), X, incX, Y, incY, A); |
| int M = A.getType().getY(); |
| int N = A.getType().getX(); |
| mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zgerc, 0, 0, 0, 0, 0, M, N, 0, alpha.x, alpha.y, X.getID(mRS), Y.getID(mRS), 0, 0, A.getID(mRS), incX, incY, 0, 0); |
| } |
| void ZHER(@Uplo int Uplo, double alpha, Allocation X, int incX, Allocation A) { |
| // same as SYR |
| int N = validateSYR(Element.F64(mRS), Uplo, X, incX, A); |
| mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zher, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, 0, X.getID(mRS), 0, 0, 0, A.getID(mRS), incX, 0, 0, 0); |
| } |
| void ZHPR(@Uplo int Uplo, double alpha, Allocation X, int incX, Allocation Ap) { |
| // equivalent to SPR for validation |
| int N = validateSPR(Element.F64_2(mRS), Uplo, X, incX, Ap); |
| mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zhpr, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, 0, X.getID(mRS), 0, 0, 0, Ap.getID(mRS), incX, 0, 0, 0); |
| } |
| void ZHER2(@Uplo int Uplo, Double2 alpha, Allocation X, int incX, Allocation Y, int incY, Allocation A) { |
| // same as SYR2 |
| int N = validateSYR2(Element.F64_2(mRS), Uplo, X, incX, Y, incY, A); |
| mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zher2, 0, 0, 0, Uplo, 0, 0, N, 0, alpha.x, alpha.y, X.getID(mRS), Y.getID(mRS), 0, 0, A.getID(mRS), incX, incY, 0, 0); |
| } |
| void ZHPR2(@Uplo int Uplo, Double2 alpha, Allocation X, int incX, Allocation Y, int incY, Allocation Ap) { |
| // same as SPR2 |
| int N = validateSPR2(Element.F64_2(mRS), Uplo, X, incX, Y, incY, Ap); |
| mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zhpr2, 0, 0, 0, Uplo, 0, 0, N, 0, alpha.x, alpha.y, X.getID(mRS), Y.getID(mRS), 0, 0, Ap.getID(mRS), incX, incY, 0, 0); |
| } |
| |
| |
| /** |
| * Level 3 BLAS |
| */ |
| |
| static void validateL3(Element e, int TransA, int TransB, int Side, Allocation A, Allocation B, Allocation C) { |
| int aX = -1, aY = -1, bX = -1, bY = -1, cX = -1, cY = -1; |
| if ((A != null && !A.getType().getElement().isCompatible(e)) || |
| (B != null && !B.getType().getElement().isCompatible(e)) || |
| (C != null && !C.getType().getElement().isCompatible(e))) { |
| throw new RSRuntimeException("Called BLAS with wrong Element type"); |
| } |
| if (C != null) { |
| cX = C.getType().getY(); |
| cY = C.getType().getX(); |
| } |
| if (Side == RIGHT) { |
| if (B != null) { |
| bX = A.getType().getY(); |
| bY = A.getType().getX(); |
| } |
| if (A != null) { |
| aX = B.getType().getY(); |
| aY = B.getType().getX(); |
| } |
| } else { |
| if (A != null) { |
| if (TransA == TRANSPOSE) { |
| aY = A.getType().getY(); |
| aX = A.getType().getX(); |
| } else { |
| aX = A.getType().getY(); |
| aY = A.getType().getX(); |
| } |
| } |
| if (B != null) { |
| if (TransB == TRANSPOSE) { |
| bY = B.getType().getY(); |
| bX = B.getType().getX(); |
| } else { |
| bX = B.getType().getY(); |
| bY = B.getType().getX(); |
| } |
| } |
| } |
| if (A != null && B != null && C != null) { |
| if (aY != bX || aX != cX || bY != cY) { |
| throw new RSRuntimeException("Called BLAS with invalid dimensions"); |
| } |
| } else if (A != null && C != null) { |
| // A and C only |
| if (aX != cY || aY != cX) { |
| throw new RSRuntimeException("Called BLAS with invalid dimensions"); |
| } |
| } else if (A != null && B != null) { |
| // A and B only |
| } |
| |
| } |
| |
| public void SGEMM(@Transpose int TransA, @Transpose int TransB, float alpha, Allocation A, |
| Allocation B, float beta, Allocation C) { |
| validateTranspose(TransA); |
| validateTranspose(TransB); |
| validateL3(Element.F32(mRS), TransA, TransB, 0, A, B, C); |
| |
| int M = -1, N = -1, K = -1; |
| if (TransA == TRANSPOSE) { |
| M = A.getType().getX(); |
| K = A.getType().getY(); |
| } else { |
| M = A.getType().getY(); |
| K = A.getType().getX(); |
| } |
| if (TransB == TRANSPOSE) { |
| N = B.getType().getY(); |
| } else { |
| N = B.getType().getX(); |
| } |
| mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_sgemm, TransA, TransB, 0, 0, 0, M, N, K, alpha, A.getID(mRS), B.getID(mRS), |
| beta, C.getID(mRS), 0, 0, 0, 0); |
| } |
| public void DGEMM(@Transpose int TransA, @Transpose int TransB, double alpha, Allocation A, |
| Allocation B, double beta, Allocation C) { |
| validateTranspose(TransA); |
| validateTranspose(TransB); |
| validateL3(Element.F64(mRS), TransA, TransB, 0, A, B, C); |
| int M = -1, N = -1, K = -1; |
| if (TransA == TRANSPOSE) { |
| M = A.getType().getX(); |
| K = A.getType().getY(); |
| } else { |
| M = A.getType().getY(); |
| K = A.getType().getX(); |
| } |
| if (TransB == TRANSPOSE) { |
| N = B.getType().getY(); |
| } else { |
| N = B.getType().getX(); |
| } |
| mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dgemm, TransA, TransB, 0, 0, 0, M, N, K, alpha, A.getID(mRS), B.getID(mRS), |
| beta, C.getID(mRS), 0, 0, 0, 0); |
| } |
| public void CGEMM(@Transpose int TransA, @Transpose int TransB, Float2 alpha, Allocation A, |
| Allocation B, Float2 beta, Allocation C) { |
| validateTranspose(TransA); |
| validateTranspose(TransB); |
| validateL3(Element.F32_2(mRS), TransA, TransB, 0, A, B, C); |
| int M = -1, N = -1, K = -1; |
| if (TransA == TRANSPOSE) { |
| M = A.getType().getX(); |
| K = A.getType().getY(); |
| } else { |
| M = A.getType().getY(); |
| K = A.getType().getX(); |
| } |
| if (TransB == TRANSPOSE) { |
| N = B.getType().getY(); |
| } else { |
| N = B.getType().getX(); |
| } |
| mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_cgemm, TransA, TransB, 0, 0, 0, M, N, K, alpha.x, alpha.y, A.getID(mRS), B.getID(mRS), |
| beta.x, beta.y, C.getID(mRS), 0, 0, 0, 0); |
| } |
| |
| public void ZGEMM(@Transpose int TransA, @Transpose int TransB, Double2 alpha, Allocation A, |
| Allocation B, Double2 beta, Allocation C) { |
| validateTranspose(TransA); |
| validateTranspose(TransB); |
| validateL3(Element.F64_2(mRS), TransA, TransB, 0, A, B, C); |
| int M = -1, N = -1, K = -1; |
| if (TransA == TRANSPOSE) { |
| M = A.getType().getX(); |
| K = A.getType().getY(); |
| } else { |
| M = A.getType().getY(); |
| K = A.getType().getX(); |
| } |
| if (TransB == TRANSPOSE) { |
| N = B.getType().getY(); |
| } else { |
| N = B.getType().getX(); |
| } |
| mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zgemm, TransA, TransB, 0, 0, 0, M, N, K, alpha.x, alpha.y, A.getID(mRS), B.getID(mRS), |
| beta.x, beta.y, C.getID(mRS), 0, 0, 0, 0); |
| } |
| |
| public void SSYMM(@Side int Side, @Uplo int Uplo, float alpha, Allocation A, |
| Allocation B, float beta, Allocation C) { |
| validateSide(Side); |
| validateUplo(Uplo); |
| validateL3(Element.F32(mRS), 0, 0, Side, A, B, C); |
| mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_ssymm, 0, 0, Side, Uplo, 0, C.getType().getY(), C.getType().getX(), 0, alpha, A.getID(mRS), B.getID(mRS), |
| beta, C.getID(mRS), 0, 0, 0, 0); |
| } |
| public void DSYMM(@Side int Side, @Uplo int Uplo, double alpha, Allocation A, |
| Allocation B, double beta, Allocation C) { |
| validateSide(Side); |
| validateUplo(Uplo); |
| validateL3(Element.F64(mRS), 0, 0, Side, A, B, C); |
| mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dsymm, 0, 0, Side, Uplo, 0, C.getType().getY(), C.getType().getX(), 0, alpha, A.getID(mRS), B.getID(mRS), |
| beta, C.getID(mRS), 0, 0, 0, 0); |
| } |
| public void CSYMM(@Side int Side, @Uplo int Uplo, Float2 alpha, Allocation A, |
| Allocation B, Float2 beta, Allocation C) { |
| validateSide(Side); |
| validateUplo(Uplo); |
| validateL3(Element.F32_2(mRS), 0, 0, Side, A, B, C); |
| mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_csymm, 0, 0, Side, Uplo, 0, C.getType().getY(), C.getType().getX(), 0, alpha.x, alpha.y, A.getID(mRS), B.getID(mRS), |
| beta.x, beta.y, C.getID(mRS), 0, 0, 0, 0); |
| } |
| public void ZSYMM(@Side int Side, @Uplo int Uplo, Double2 alpha, Allocation A, |
| Allocation B, Double2 beta, Allocation C) { |
| validateSide(Side); |
| validateUplo(Uplo); |
| validateL3(Element.F64_2(mRS), 0, 0, Side, A, B, C); |
| mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zsymm, 0, 0, Side, Uplo, 0, C.getType().getY(), C.getType().getX(), 0, alpha.x, alpha.y, A.getID(mRS), B.getID(mRS), |
| beta.x, beta.y, C.getID(mRS), 0, 0, 0, 0); |
| } |
| |
| public void SSYRK(@Uplo int Uplo, @Transpose int Trans, float alpha, Allocation A, float beta, Allocation C) { |
| validateTranspose(Trans); |
| validateUplo(Uplo); |
| validateL3(Element.F32(mRS), Trans, 0, 0, A, null, C); |
| int K = -1; |
| if (Trans == TRANSPOSE) { |
| K = A.getType().getY(); |
| } else { |
| K = A.getType().getX(); |
| } |
| |
| mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_ssyrk, Trans, 0, 0, Uplo, 0, 0, C.getType().getX(), K, alpha, A.getID(mRS), 0, beta, C.getID(mRS), 0, 0, 0, 0); |
| } |
| |
| public void DSYRK(@Uplo int Uplo, @Transpose int Trans, double alpha, Allocation A, double beta, Allocation C) { |
| validateTranspose(Trans); |
| validateUplo(Uplo); |
| validateL3(Element.F64(mRS), Trans, 0, 0, A, null, C); |
| int K = -1; |
| if (Trans == TRANSPOSE) { |
| K = A.getType().getY(); |
| } else { |
| K = A.getType().getX(); |
| } |
| mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_dsyrk, Trans, 0, 0, Uplo, 0, 0, C.getType().getX(), K, alpha, A.getID(mRS), 0, beta, C.getID(mRS), 0, 0, 0, 0); |
| } |
| public void CSYRK(@Uplo int Uplo, @Transpose int Trans, float alphaX, float alphaY, Allocation A, float betaX, float betaY, Allocation C) { |
| validateTranspose(Trans); |
| validateUplo(Uplo); |
| validateL3(Element.F32_2(mRS), Trans, 0, 0, A, null, C); |
| int K = -1; |
| if (Trans == TRANSPOSE) { |
| K = A.getType().getY(); |
| } else { |
| K = A.getType().getX(); |
| } |
| mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_csyrk, Trans, 0, 0, Uplo, 0, 0, C.getType().getX(), K, alphaX, alphaY, A.getID(mRS), 0, betaX, betaY, |
| C.getID(mRS), 0, 0, 0, 0); |
| } |
| public void ZSYRK(@Uplo int Uplo, @Transpose int Trans, double alphaX, double alphaY, Allocation A, double betaX, double betaY, Allocation C) { |
| validateTranspose(Trans); |
| validateUplo(Uplo); |
| validateL3(Element.F64_2(mRS), Trans, 0, 0, A, null, C); |
| int K = -1; |
| if (Trans == TRANSPOSE) { |
| K = A.getType().getY(); |
| } else { |
| K = A.getType().getX(); |
| } |
| mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zsyrk, Trans, 0, 0, Uplo, 0, 0, C.getType().getX(), K, alphaX, alphaY, A.getID(mRS), 0, betaX, betaY, |
| C.getID(mRS), 0, 0, 0, 0); |
| } |
| |
| static void validateSYR2K(Element e, @Transpose int Trans, Allocation A, Allocation B, Allocation C) { |
| validateTranspose(Trans); |
| if (!A.getType().getElement().isCompatible(e) || |
| !B.getType().getElement().isCompatible(e) || |
| !C.getType().getElement().isCompatible(e)) { |
| throw new RSRuntimeException("Called BLAS with wrong Element type"); |
| } |
| int Cdim = -1; |
| // A is n x k if no transpose, k x n if transpose |
| // C is n x n |
| if (Trans == TRANSPOSE) { |
| // check columns versus C |
| Cdim = A.getType().getX(); |
| } else { |
| // check rows versus C |
| Cdim = A.getType().getY(); |
| } |
| if (C.getType().getX() != Cdim && C.getType().getY() != Cdim) { |
| throw new RSRuntimeException("Invalid symmetric matrix in SYR2K"); |
| } |
| // A dims == B dims |
| if (A.getType().getX() != B.getType().getX() || A.getType().getY() != B.getType().getY()) { |
| throw new RSRuntimeException("Invalid A and B in SYR2K"); |
| } |
| } |
| public void SSYR2K(@Uplo int Uplo, @Transpose int Trans, float alpha, Allocation A, Allocation B, float beta, Allocation C) { |
| validateUplo(Uplo); |
| validateSYR2K(Element.F32(mRS), Trans, A, B, C); |
| int K = -1; |
| if (Trans == TRANSPOSE) { |
| K = A.getType().getY(); |
| } else { |
| K = A.getType().getX(); |
| } |
| mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_ssyr2k, Trans, 0, 0, Uplo, 0, 0, C.getType().getX(), K, alpha, A.getID(mRS), B.getID(mRS), beta, C.getID(mRS), 0, 0, 0, 0); |
| } |
| public void DSYR2K(@Uplo int Uplo, @Transpose int Trans, double alpha, Allocation A, Allocation B, double beta, Allocation C) { |
| validateUplo(Uplo); |
| validateSYR2K(Element.F64(mRS), Trans, A, B, C); |
| int K = -1; |
| if (Trans == TRANSPOSE) { |
| K = A.getType().getY(); |
| } else { |
| K = A.getType().getX(); |
| } |
| mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_ssyr2k, Trans, 0, 0, Uplo, 0, 0, C.getType().getX(), K, alpha, A.getID(mRS), B.getID(mRS), beta, C.getID(mRS), 0, 0, 0, 0); |
| } |
| public void CSYR2K(@Uplo int Uplo, @Transpose int Trans, Float2 alpha, Allocation A, Allocation B, Float2 beta, Allocation C) { |
| validateUplo(Uplo); |
| validateSYR2K(Element.F32_2(mRS), Trans, A, B, C); |
| int K = -1; |
| if (Trans == TRANSPOSE) { |
| K = A.getType().getY(); |
| } else { |
| K = A.getType().getX(); |
| } |
| mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_ssyr2k, Trans, 0, 0, Uplo, 0, 0, C.getType().getX(), K, alpha.x, alpha.y, A.getID(mRS), B.getID(mRS), beta.x, beta.y, C.getID(mRS), 0, 0, 0, 0); |
| } |
| public void ZSYR2K(@Uplo int Uplo, @Transpose int Trans, Double2 alpha, Allocation A, Allocation B, Double2 beta, Allocation C) { |
| validateUplo(Uplo); |
| validateSYR2K(Element.F64_2(mRS), Trans, A, B, C); |
| int K = -1; |
| if (Trans == TRANSPOSE) { |
| K = A.getType().getY(); |
| } else { |
| K = A.getType().getX(); |
| } |
| mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_ssyr2k, Trans, 0, 0, Uplo, 0, 0, C.getType().getX(), K, alpha.x, alpha.y, A.getID(mRS), B.getID(mRS), beta.x, beta.y, C.getID(mRS), 0, 0, 0, 0); |
| } |
| |
| static void validateTRMM(Element e, @Side int Side, @Transpose int TransA, Allocation A, Allocation B) { |
| validateSide(Side); |
| validateTranspose(TransA); |
| int aX = -1, aY = -1, bX = -1, bY = -1; |
| if (!A.getType().getElement().isCompatible(e) || |
| !B.getType().getElement().isCompatible(e)) { |
| throw new RSRuntimeException("Called BLAS with wrong Element type"); |
| } |
| if (TransA == TRANSPOSE) { |
| aY = A.getType().getY(); |
| aX = A.getType().getX(); |
| } else { |
| aY = A.getType().getX(); |
| aX = A.getType().getY(); |
| } |
| bX = B.getType().getY(); |
| bY = B.getType().getX(); |
| if (Side == LEFT) { |
| if (aX == 0 || aY != bX) { |
| throw new RSRuntimeException("Called TRMM with invalid matrices"); |
| } |
| } else { |
| if (bY != aX || aY == 0) { |
| throw new RSRuntimeException("Called TRMM with invalid matrices"); |
| } |
| } |
| } |
| public void STRMM(@Side int Side, @Uplo int Uplo, @Transpose int TransA, @Diag int Diag, float alpha, Allocation A, Allocation B) { |
| validateUplo(Uplo); |
| validateDiag(Diag); |
| validateTRMM(Element.F32(mRS), Side, TransA, A, B); |
| mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_strmm, TransA, 0, Side, Uplo, Diag, B.getType().getY(), B.getType().getX(), 0, |
| alpha, A.getID(mRS), B.getID(mRS), 0.f, 0, 0, 0, 0, 0); |
| } |
| public void DTRMM(@Side int Side, @Uplo int Uplo, @Transpose int TransA, @Diag int Diag, double alpha, Allocation A, Allocation B) { |
| validateUplo(Uplo); |
| validateDiag(Diag); |
| validateTRMM(Element.F64(mRS), Side, TransA, A, B); |
| mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_strmm, TransA, 0, Side, Uplo, Diag, B.getType().getY(), B.getType().getX(), 0, |
| alpha, A.getID(mRS), B.getID(mRS), 0.f, 0, 0, 0, 0, 0); |
| } |
| public void CTRMM(@Side int Side, @Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Float2 alpha, Allocation A, Allocation B) { |
| validateUplo(Uplo); |
| validateDiag(Diag); |
| validateTRMM(Element.F32_2(mRS), Side, TransA, A, B); |
| mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_strmm, TransA, 0, Side, Uplo, Diag, B.getType().getY(), B.getType().getX(), 0, |
| alpha.x, alpha.y, A.getID(mRS), B.getID(mRS), 0, 0, 0, 0, 0, 0, 0); |
| } |
| public void ZTRMM(@Side int Side, @Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Double2 alpha, Allocation A, Allocation B) { |
| validateUplo(Uplo); |
| validateDiag(Diag); |
| validateTRMM(Element.F64_2(mRS), Side, TransA, A, B); |
| mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_strmm, TransA, 0, Side, Uplo, Diag, B.getType().getY(), B.getType().getX(), 0, |
| alpha.x, alpha.y, A.getID(mRS), B.getID(mRS), 0, 0, 0, 0, 0, 0, 0); |
| } |
| |
| static void validateTRSM(Element e, @Side int Side, @Transpose int TransA, Allocation A, Allocation B) { |
| int adim = -1, bX = -1, bY = -1; |
| validateSide(Side); |
| validateTranspose(TransA); |
| if (!A.getType().getElement().isCompatible(e) || |
| !B.getType().getElement().isCompatible(e)) { |
| throw new RSRuntimeException("Called BLAS with wrong Element type"); |
| } |
| adim = A.getType().getX(); |
| if (adim != A.getType().getY()) { |
| // this may be unnecessary, the restriction could potentially be relaxed |
| // A needs to contain at least that symmetric matrix but could theoretically be larger |
| // for now we assume adapters are sufficient, will reevaluate in the future |
| throw new RSRuntimeException("Called TRSM with a non-symmetric matrix A"); |
| } |
| bX = B.getType().getY(); |
| bY = B.getType().getX(); |
| if (Side == LEFT) { |
| // A is M*M |
| if (adim != bY) { |
| throw new RSRuntimeException("Called TRSM with invalid matrix dimensions"); |
| } |
| } else { |
| // A is N*N |
| if (adim != bX) { |
| throw new RSRuntimeException("Called TRSM with invalid matrix dimensions"); |
| } |
| } |
| } |
| public void STRSM(@Side int Side, @Uplo int Uplo, @Transpose int TransA, @Diag int Diag, float alpha, Allocation A, Allocation B) { |
| validateUplo(Uplo); |
| validateDiag(Diag); |
| validateTRSM(Element.F32(mRS), Side, TransA, A, B); |
| mRS.nScriptIntrinsicBLAS_Single(getID(mRS), RsBlas_strsm, TransA, 0, Side, Uplo, Diag, B.getType().getY(), B.getType().getX(), 0, |
| alpha, A.getID(mRS), B.getID(mRS), 0, 0, 0, 0, 0, 0); |
| } |
| public void DTRSM(@Side int Side, @Uplo int Uplo, @Transpose int TransA, @Diag int Diag, double alpha, Allocation A, Allocation B) { |
| validateUplo(Uplo); |
| validateDiag(Diag); |
| validateTRSM(Element.F64(mRS), Side, TransA, A, B); |
| mRS.nScriptIntrinsicBLAS_Double(getID(mRS), RsBlas_strsm, TransA, 0, Side, Uplo, Diag, B.getType().getY(), B.getType().getX(), 0, |
| alpha, A.getID(mRS), B.getID(mRS), 0, 0, 0, 0, 0, 0); |
| } |
| public void CTRSM(@Side int Side, @Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Float2 alpha, Allocation A, Allocation B) { |
| validateUplo(Uplo); |
| validateDiag(Diag); |
| validateTRSM(Element.F32_2(mRS), Side, TransA, A, B); |
| mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_strsm, TransA, 0, Side, Uplo, Diag, B.getType().getY(), B.getType().getX(), 0, |
| alpha.x, alpha.y, A.getID(mRS), B.getID(mRS), 0, 0, 0, 0, 0, 0, 0); |
| } |
| public void ZTRSM(@Side int Side, @Uplo int Uplo, @Transpose int TransA, @Diag int Diag, Double2 alpha, Allocation A, Allocation B) { |
| validateUplo(Uplo); |
| validateDiag(Diag); |
| validateTRSM(Element.F64_2(mRS), Side, TransA, A, B); |
| mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_strsm, TransA, 0, Side, Uplo, Diag, B.getType().getY(), B.getType().getX(), 0, |
| alpha.x, alpha.y, A.getID(mRS), B.getID(mRS), 0, 0, 0, 0, 0, 0, 0); |
| } |
| |
| static void validateHEMM(Element e, @Side int Side, Allocation A, Allocation B, Allocation C) { |
| validateSide(Side); |
| |
| if (!A.getType().getElement().isCompatible(e) || |
| !B.getType().getElement().isCompatible(e) || |
| !C.getType().getElement().isCompatible(e)) { |
| throw new RSRuntimeException("Called BLAS with wrong Element type"); |
| } |
| |
| // A must be square; can potentially be relaxed similar to TRSM |
| int adim = A.getType().getX(); |
| if (adim != A.getType().getY()) { |
| throw new RSRuntimeException("Called HEMM with non-square A"); |
| } |
| if ((Side == LEFT && adim != B.getType().getY()) || |
| (Side == RIGHT && adim != B.getType().getX())) { |
| throw new RSRuntimeException("Called HEMM with invalid B"); |
| } |
| if (B.getType().getX() != C.getType().getX() || |
| B.getType().getY() != C.getType().getY()) { |
| throw new RSRuntimeException("Called HEMM with mismatched B and C"); |
| } |
| } |
| public void CHEMM(@Side int Side, @Uplo int Uplo, float alpha, Allocation A, Allocation B, float beta, Allocation C) { |
| validateUplo(Uplo); |
| validateHEMM(Element.F32_2(mRS), Side, A, B, C); |
| mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_chemm, 0, 0, Side, Uplo, 0, C.getType().getY(), C.getType().getX(), 0, |
| alpha, 0, A.getID(mRS), B.getID(mRS), beta, 0, C.getID(mRS), 0, 0, 0, 0); |
| } |
| public void ZHEMM(@Side int Side, @Uplo int Uplo, double alpha, Allocation A, Allocation B, double beta, Allocation C) { |
| validateUplo(Uplo); |
| validateHEMM(Element.F32_2(mRS), Side, A, B, C); |
| mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zhemm, 0, 0, Side, Uplo, 0, C.getType().getY(), C.getType().getX(), 0, |
| alpha, 0, A.getID(mRS), B.getID(mRS), beta, 0, C.getID(mRS), 0, 0, 0, 0); |
| } |
| |
| static void validateHERK(Element e, @Transpose int Trans, Allocation A, Allocation C) { |
| if (!A.getType().getElement().isCompatible(e) || |
| !C.getType().getElement().isCompatible(e)) { |
| throw new RSRuntimeException("Called BLAS with wrong Element type"); |
| } |
| validateConjTranspose(Trans); |
| int cdim = C.getType().getX(); |
| if (cdim != C.getType().getY()) { |
| throw new RSRuntimeException("Called HERK with non-square C"); |
| } |
| if (Trans == NO_TRANSPOSE) { |
| if (cdim != A.getType().getX()) { |
| throw new RSRuntimeException("Called HERK with invalid A"); |
| } |
| } else { |
| if (cdim != A.getType().getY()) { |
| throw new RSRuntimeException("Called HERK with invalid A"); |
| } |
| } |
| } |
| public void CHERK(@Uplo int Uplo, @Transpose int Trans, float alpha, Allocation A, float beta, Allocation C) { |
| validateUplo(Uplo); |
| validateHERK(Element.F32_2(mRS), Trans, A, C); |
| int k = 0; |
| if (Trans == TRANSPOSE) { |
| k = A.getType().getY(); |
| } else { |
| k = A.getType().getX(); |
| } |
| mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_cherk, Trans, 0, 0, Uplo, 0, 0, C.getType().getX(), k, |
| alpha, 0, A.getID(mRS), 0, beta, 0, C.getID(mRS), 0, 0, 0, 0); |
| } |
| public void ZHERK(@Uplo int Uplo, @Transpose int Trans, double alpha, Allocation A, double beta, Allocation C) { |
| validateUplo(Uplo); |
| validateHERK(Element.F64_2(mRS), Trans, A, C); |
| int k = 0; |
| if (Trans == TRANSPOSE) { |
| k = A.getType().getY(); |
| } else { |
| k = A.getType().getX(); |
| } |
| mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zherk, Trans, 0, 0, Uplo, 0, 0, C.getType().getX(), k, |
| alpha, 0, A.getID(mRS), 0, beta, 0, C.getID(mRS), 0, 0, 0, 0); |
| } |
| |
| static void validateHER2K(Element e, @Transpose int Trans, Allocation A, Allocation B, Allocation C) { |
| if (!A.getType().getElement().isCompatible(e) || |
| !B.getType().getElement().isCompatible(e) || |
| !C.getType().getElement().isCompatible(e)) { |
| throw new RSRuntimeException("Called BLAS with wrong Element type"); |
| } |
| validateConjTranspose(Trans); |
| int cdim = C.getType().getX(); |
| if (cdim != C.getType().getY()) { |
| throw new RSRuntimeException("Called HER2K with non-square C"); |
| } |
| if (Trans == NO_TRANSPOSE) { |
| if (A.getType().getY() != cdim) { |
| throw new RSRuntimeException("Called HER2K with invalid matrices"); |
| } |
| } else { |
| if (A.getType().getX() != cdim) { |
| throw new RSRuntimeException("Called HER2K with invalid matrices"); |
| } |
| } |
| if (A.getType().getX() != B.getType().getX() || A.getType().getY() != B.getType().getY()) { |
| throw new RSRuntimeException("Called HER2K with invalid A and B matrices"); |
| } |
| } |
| public void CHER2K(@Uplo int Uplo, @Transpose int Trans, Float2 alpha, Allocation A, Allocation B, float beta, Allocation C) { |
| validateUplo(Uplo); |
| validateHER2K(Element.F32_2(mRS), Trans, A, B, C); |
| int k = 0; |
| if (Trans == NO_TRANSPOSE) { |
| k = A.getType().getX(); |
| } else { |
| k = A.getType().getY(); |
| } |
| mRS.nScriptIntrinsicBLAS_Complex(getID(mRS), RsBlas_cher2k, Trans, 0, 0, Uplo, 0, 0, C.getType().getX(), k, alpha.x, alpha.y, |
| A.getID(mRS), B.getID(mRS), beta, 0, C.getID(mRS), 0, 0, 0, 0); |
| } |
| public void ZHER2K(@Uplo int Uplo, @Transpose int Trans, Double2 alpha, Allocation A, Allocation B, double beta, Allocation C) { |
| validateUplo(Uplo); |
| validateHER2K(Element.F64_2(mRS), Trans, A, B, C); |
| int k = 0; |
| if (Trans == NO_TRANSPOSE) { |
| k = A.getType().getX(); |
| } else { |
| k = A.getType().getY(); |
| } |
| mRS.nScriptIntrinsicBLAS_Z(getID(mRS), RsBlas_zher2k, Trans, 0, 0, Uplo, 0, 0, C.getType().getX(), k, alpha.x, alpha.y, |
| A.getID(mRS), B.getID(mRS), beta, 0, C.getID(mRS), 0, 0, 0, 0); |
| } |
| |
| |
| /** |
| * |
| * 8-bit GEMM-like operation for neural networks |
| * |
| * @hide |
| **/ |
| public void BNNM(Allocation A, int a_offset, Allocation B, int b_offset, Allocation C, int c_offset, int c_mult) { |
| validateL3(Element.U8(mRS), NO_TRANSPOSE, TRANSPOSE, 0, A, B, C); |
| |
| int M = -1, N = -1, K = -1; |
| M = A.getType().getY(); |
| N = B.getType().getY(); |
| K = A.getType().getX(); |
| |
| |
| mRS.nScriptIntrinsicBLAS_BNNM(getID(mRS), M, N, K, A.getID(mRS), a_offset, B.getID(mRS), b_offset, C.getID(mRS), c_offset, c_mult); |
| |
| } |
| |
| } |