使用 Aparapi 将两个矩阵相乘 - 不适用于 3D 范围

问题描述 投票:0回答:1

目前正在开发用于矩阵运算和机器学习的 Java 库。 我使用 Aparapi 来利用 GPU。

我编写了这段代码来将两个矩阵相乘:

   public static NDmatrix matMul(float[][] a, float[][] b) {
        int[] aDim = new int[]{a.length, a[0].length};
        int[] bDim = new int[]{b.length, b[0].length};
        if(aDim[1] == bDim[0]){

            int[] Dim = new int[]{aDim[0], bDim[1]};
            int aVSize = aDim[0] * aDim[1];
            float[] aVector = new float[aVSize];
            for(int i = 0; i < aDim[0]; i++)
                System.arraycopy(a[i], 0, aVector, i * aDim[1], aDim[1]);
            int bVSize = bDim[0] * bDim[1];
            float[] bVector = new float[bVSize];
            for(int i = 0; i < bDim[0]; i++)
                System.arraycopy(b[i], 0, bVector, i * bDim[1], bDim[1]);
            int resVSize = Dim[0] * Dim[1];
            float[] resVector = new float[resVSize];
            int d[] = new int[]{aDim[1]};
            Kernel mKernel = new Kernel() {
                final int ht = Dim[0];
                final int wt = Dim[1];
                final int dpt = d[0];

                public void run() {
                    int c = getGlobalId(0);
                    int r = getGlobalId(1);
                    int l = getGlobalId(2);
                    localBarrier();
                    //for(int l = 0; l < dpt; l++)
                    resVector[r * wt + c] = resVector[r * wt + c] + aVector[r * dpt + l] * bVector[l * wt + c];
                }
            };
            mKernel.setExplicit(true);
            mKernel.put(aVector);
            mKernel.put(bVector);
            mKernel.put(resVector);
            mKernel.put(Dim);
            mKernel.put(d);
            mKernel.execute(Range.create3D(Dim[1], Dim[0], d[0]));
            //mKernel.execute(Range.create2D(Dim[1], Dim[0]));
            mKernel.get(resVector);
            mKernel.dispose();
            return new NDmatrix(Dim, resVector, null);
        }
        System.out.println("The number of columns in left matrix and number of rows in right matrix do not match.");
        System.out.println();
        return null;
    }

但是,看起来 resVector[some_index] 只更新一次。 相反,如果我使用 2D 范围和循环(代码中的注释位),则它可以正常工作。 这种行为的原因可能是什么?我如何强制它完全并行工作?

有趣的是,我尝试了一件“有效”的事情 - 更新 resVector[some_index] 后,我调用了 this.put(resVector)。 然而,它无法在 OpenCL 中编译,最终改用 Java 的多线程,最终得到了正确的结果。

java matrix aparapi
1个回答
0
投票

好吧,我想我找到了解决方案,虽然有点疯狂:

public static NDmatrix matMul(float[][] a, float[][] b) {
    int[] aDim = new int[]{a.length, a[0].length};
    int[] bDim = new int[]{b.length, b[0].length};
    if(aDim[1] == bDim[0]){

        int[] Dim = new int[]{aDim[0], bDim[1]};
        int aVSize = aDim[0] * aDim[1];
        float[] aVector = new float[aVSize];
        for(int i = 0; i < aDim[0]; i++)
            System.arraycopy(a[i], 0, aVector, i * aDim[1], aDim[1]);
        int bVSize = bDim[0] * bDim[1];
        float[] bVector = new float[bVSize];
        for(int i = 0; i < bDim[0]; i++)
            System.arraycopy(b[i], 0, bVector, i * bDim[1], bDim[1]);
        int resVSize = Dim[0] * Dim[1];
        float[] resVector = new float[resVSize];
        int d[] = new int[]{aDim[1]};
        return getMult(Dim, aVector, bVector, resVector, d);
    }
    System.out.println("The number of columns in left matrix and number of rows in right matrix do not match.");
    System.out.println();
    return null;
}

@NotNull
private static NDmatrix getMult(int[] Dim, float[] aVector, float[] bVector, float[] resVector, int[] d) {
    Kernel mKernel = new Kernel() {
        final int ht = Dim[0];
        final int wt = Dim[1];
        final int dpt = d[0];

        public void run() {
            int c = getGlobalId(0);
            int r = getGlobalId(1);
            int l = getGlobalId(2);
            localBarrier();
            for(int i = 0; i < dpt; i++) // really?!
                if(i == l)               // why, oh why do I have to do this...
                    resVector[r * wt + c] += aVector[r * dpt + i] * bVector[i * wt + c];
        }
    };
    mKernel.setExplicit(true);
    mKernel.put(aVector);
    mKernel.put(bVector);
    mKernel.put(resVector);
    mKernel.put(Dim);
    mKernel.put(d);
    mKernel.execute(Range.create3D(Dim[1], Dim[0], d[0]));
    mKernel.get(resVector);
    mKernel.dispose();
    return new NDmatrix(Dim, resVector, null);
}

不敢相信这样一个奇怪的循环会以某种方式迫使它做它的事情......

© www.soinside.com 2019 - 2024. All rights reserved.