there are so many cores


Matrix multiply for the JIT

The new design for matrix multiply looks good. I’ve been testing on an ATI HD 5870 for several days, shaking out bugs. The code generation handles the following parameters:

  • M, N, and K matrix dimensions, optionally inlined
  • row and column data layout (transposed matrices)
  • single and double precision floating point (may be mixed together)
  • combinations of images and memory buffers of any vector length
  • square outer blocking in work group
  • rectangular inner blocking in registers (width and height)
  • inner product loop order
  • using global versus group/local index space ID functions
  • optimized for general and simple matrix multiply
  • number of matrix multiplies packed as tiles into kernel

Endogenous parameters for the expectation-maximization auto-tuning are in bold. The rest are treated as exogenous with respect to the statistical optimization. I tried making more parameters endogenous but found this invoked the “curse of dimensionality” and gave inconsistent results. For this problem, brute force search over the exogenous parameters with stable auto-tuning optimization for the rest is the more practical approach.
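The split between brute force and auto-tuning can be sketched in plain C. This is not the expectation-maximization tuner: a simple hill climb stands in for it, and `fake_benchmark` is a hypothetical synthetic function standing in for "generate a kernel, run it, measure throughput". Treating vector length and precision as exogenous and the inner block width and height as endogenous is also an assumption made only for illustration.

```c
/* Result of tuning the (assumed) endogenous parameters. */
typedef struct { int w, h; double score; } Tuned;

/* One full configuration: exogenous choices plus the tuned inner block. */
typedef struct { int vec_len, use_double; Tuned inner; } Config;

/* Hypothetical stand-in for compiling and timing a real kernel.
   The surface has a single peak whose position depends on vec_len. */
static double fake_benchmark(int vec_len, int use_double, int w, int h)
{
    double dw = w - (2.0 + vec_len);
    double dh = h - 4.0;
    double scale = use_double ? 0.5 : 1.0;
    return scale * (80.0 + 5.0 * vec_len - dw * dw - dh * dh);
}

/* Hill climb over inner block width and height (1..8 each).
   Stands in for the statistical auto-tuning; it is not EM. */
static Tuned tune_inner(int vec_len, int use_double)
{
    static const int step[4][2] = { {1, 0}, {-1, 0}, {0, 1}, {0, -1} };
    Tuned cur = { 1, 1, fake_benchmark(vec_len, use_double, 1, 1) };
    for (;;) {
        Tuned best = cur;
        for (int i = 0; i < 4; i++) {
            int w = cur.w + step[i][0];
            int h = cur.h + step[i][1];
            if (w < 1 || w > 8 || h < 1 || h > 8)
                continue;
            double s = fake_benchmark(vec_len, use_double, w, h);
            if (s > best.score) {
                best.w = w;
                best.h = h;
                best.score = s;
            }
        }
        if (best.score <= cur.score)
            return cur;          /* no neighbor improves: local optimum */
        cur = best;
    }
}

/* Brute force over the exogenous parameters, tuning the rest each time. */
Config search_all(void)
{
    static const int vec_lens[3] = { 1, 2, 4 };
    Config best = { 0, 0, { 0, 0, -1.0e300 } };
    for (int v = 0; v < 3; v++) {
        for (int d = 0; d < 2; d++) {
            Tuned t = tune_inner(vec_lens[v], d);
            if (t.score > best.inner.score) {
                best.vec_len = vec_lens[v];
                best.use_double = d;
                best.inner = t;
            }
        }
    }
    return best;
}
```

Calling `search_all()` runs one inner tuning pass per exogenous combination (six here) and keeps the best. The outer loop grows multiplicatively with each exogenous parameter, but every inner search stays low-dimensional.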

The current design is what I always wanted but was scared to do. It handles all combinations, mixtures of precisions, memory types, vector lengths, etc. I haven’t seen this done anywhere else. Optimized math kernel designs generally constrain the input and output arrays to specific types and formats.

Implementing this also gave me a better sense of why the auto-tuning converges. Why should the benchmark surface be convex?

Here’s my intuition. The optimization is really optimizing register use. To some degree (perhaps due to quirks of the shader compiler), every parameter mentioned above affects the amount of private GPU memory used. This affects both GPU thread scheduling and arithmetic intensity. If a kernel uses too many registers, too few threads fit on the GPU at once for efficient scheduling. If a kernel uses too few registers, then each work-item is doing too little work and arithmetic intensity is low. Somewhere in the middle is the optimal number of registers.

This is why the expectation-maximization converges when auto-tuning.
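A toy model makes the register trade-off concrete. Everything in it is an assumption for illustration: an n x n inner block is taken to need n*n accumulators plus 2n operand registers, and the register budget and the number of wavefronts needed to hide latency are made-up numbers, not measurements from the HD 5870.

```c
/* Hypothetical hardware numbers, chosen only to make the trade-off visible. */
#define REG_BUDGET   256  /* registers available per lane (assumed) */
#define WAVES_NEEDED 8    /* resident wavefronts needed to hide latency (assumed) */

/* Registers for an n x n inner block: n*n accumulators plus 2n operands. */
static int regs_used(int n)
{
    return n * n + 2 * n;
}

/* Multiply-adds performed per operand value loaded. */
static double intensity(int n)
{
    return (double)(n * n) / (2 * n);
}

/* Fraction of latency hidden, limited by how many wavefronts fit. */
static double hiding(int n)
{
    int waves = REG_BUDGET / regs_used(n);
    return waves >= WAVES_NEEDED ? 1.0 : (double)waves / WAVES_NEEDED;
}

/* Throughput proxy: more work per load helps until scheduling suffers. */
double model_score(int n)
{
    return intensity(n) * hiding(n);
}

/* Block size in 1..max_n with the highest proxy score. */
int best_block(int max_n)
{
    int best = 1;
    for (int n = 2; n <= max_n; n++)
        if (model_score(n) > model_score(best))
            best = n;
    return best;
}
```

With these made-up numbers the score rises from n = 1 to n = 5 and falls after that, so `best_block(8)` returns 5. The single interior peak is the shape the intuition above predicts; the real surface depends on the shader compiler and the hardware.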

I will have to make some pretty charts and graphs, maybe an animation, illustrating this for the conference.

Here’s an example kernel the JIT back-end can generate. It’s something tricky that would be very difficult to do by hand. It is ten 400x400x400 matrix multiplies with mixed precision arrays of images and memory buffers. Instead of the usual float4 quad, it uses float2 (sometimes faster). This runs at about 300 GFLOPS.

#pragma OPENCL EXTENSION cl_khr_fp64 : enable
__kernel void evergreenmatmul_4_8_4_2_0_0_10_0_400_400_400_1_0_10_10_4_4_23(
    __write_only image2d_t matrixC,
    __global const float2* matrixA,
    __read_only image2d_t matrixB,
    __local float2* localA)
{
    const sampler_t sampler = CLK_FILTER_NEAREST | CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_NONE;
    __global const float2* ptrMatrixA;
    __local const float2* ptrLocalA;
    float4 accum00;
    float4 accum01;
    float4 accum02;
    float4 accum03;
    float2 valueA00;
    float2 valueA01;
    float2 valueA02;
    float2 valueA03;
    float2 valueA10;
    float2 valueA11;
    float2 valueA12;
    float2 valueA13;
    double2 valueB00;
    double2 valueB01;
    double2 valueB02;
    double2 valueB03;
    double2 valueB10;
    double2 valueB11;
    double2 valueB12;
    double2 valueB13;
    for (int packIdx = 0; packIdx < 10; packIdx++)
    {
        accum00 = (float4)(0);
        accum01 = (float4)(0);
        accum02 = (float4)(0);
        accum03 = (float4)(0);
        ptrMatrixA = (((matrixA + (packIdx * ((400 * 400) / 2))) + (get_global_id(1) * 2)) + (get_local_id(0) * (400 * 2)));
        for (int outerIdx = 0; outerIdx < (400 / 40); outerIdx++)
        {
            *(localA + (((get_local_id(1) * 44) + get_local_id(0)) * 2)) = *ptrMatrixA;
            *((localA + (((get_local_id(1) * 44) + get_local_id(0)) * 2)) + 1) = *(ptrMatrixA + 400);
            *(localA + ((((get_local_id(1) * 44) + get_local_id(0)) + 11) * 2)) = *(ptrMatrixA + (400 / 2));
            *((localA + ((((get_local_id(1) * 44) + get_local_id(0)) + 11) * 2)) + 1) = *((ptrMatrixA + (400 / 2)) + 400);
            *(localA + ((((get_local_id(1) * 44) + get_local_id(0)) + 22) * 2)) = *(ptrMatrixA + 1);
            *((localA + ((((get_local_id(1) * 44) + get_local_id(0)) + 22) * 2)) + 1) = *((ptrMatrixA + 1) + 400);
            *(localA + ((((get_local_id(1) * 44) + get_local_id(0)) + 33) * 2)) = *((ptrMatrixA + 1) + (400 / 2));
            *((localA + ((((get_local_id(1) * 44) + get_local_id(0)) + 33) * 2)) + 1) = *(((ptrMatrixA + 1) + (400 / 2)) + 400);
            barrier(CLK_LOCAL_MEM_FENCE);
            ptrMatrixA += (400 * 20);
            ptrLocalA = (localA + (get_local_id(1) * 88));
            for (int innerIdx = 0; innerIdx < 10; innerIdx++)
            {
                valueA00 = *ptrLocalA;
                valueA10 = *(ptrLocalA + 1);
                valueA01 = *(ptrLocalA + 22);
                valueA11 = *((ptrLocalA + 22) + 1);
                valueA02 = *(ptrLocalA + 44);
                valueA12 = *((ptrLocalA + 44) + 1);
                valueA03 = *(ptrLocalA + 66);
                valueA13 = *((ptrLocalA + 66) + 1);
                ptrLocalA += 2;
                valueB00 = as_double2(read_imageui(matrixB, sampler, (int2)((get_global_id(0) * 2), ((packIdx * 400) + (((outerIdx * 10) + innerIdx) * 4)))));
                valueB10 = as_double2(read_imageui(matrixB, sampler, (int2)(((get_global_id(0) * 2) + 1), ((packIdx * 400) + (((outerIdx * 10) + innerIdx) * 4)))));
                valueB01 = as_double2(read_imageui(matrixB, sampler, (int2)((get_global_id(0) * 2), (((packIdx * 400) + (((outerIdx * 10) + innerIdx) * 4)) + 1))));
                valueB11 = as_double2(read_imageui(matrixB, sampler, (int2)(((get_global_id(0) * 2) + 1), (((packIdx * 400) + (((outerIdx * 10) + innerIdx) * 4)) + 1))));
                valueB02 = as_double2(read_imageui(matrixB, sampler, (int2)((get_global_id(0) * 2), (((packIdx * 400) + (((outerIdx * 10) + innerIdx) * 4)) + 2))));
                valueB12 = as_double2(read_imageui(matrixB, sampler, (int2)(((get_global_id(0) * 2) + 1), (((packIdx * 400) + (((outerIdx * 10) + innerIdx) * 4)) + 2))));
                valueB03 = as_double2(read_imageui(matrixB, sampler, (int2)((get_global_id(0) * 2), (((packIdx * 400) + (((outerIdx * 10) + innerIdx) * 4)) + 3))));
                valueB13 = as_double2(read_imageui(matrixB, sampler, (int2)(((get_global_id(0) * 2) + 1), (((packIdx * 400) + (((outerIdx * 10) + innerIdx) * 4)) + 3))));
                accum00.s0 = mad(valueA00.s0, valueB00.s0, accum00.s0);
                accum01.s0 = mad(valueA00.s1, valueB00.s0, accum01.s0);
                accum02.s0 = mad(valueA02.s0, valueB00.s0, accum02.s0);
                accum03.s0 = mad(valueA02.s1, valueB00.s0, accum03.s0);
                accum00.s1 = mad(valueA00.s0, valueB00.s1, accum00.s1);
                accum01.s1 = mad(valueA00.s1, valueB00.s1, accum01.s1);
                accum02.s1 = mad(valueA02.s0, valueB00.s1, accum02.s1);
                accum03.s1 = mad(valueA02.s1, valueB00.s1, accum03.s1);
                accum00.s2 = mad(valueA00.s0, valueB10.s0, accum00.s2);
                accum01.s2 = mad(valueA00.s1, valueB10.s0, accum01.s2);
                accum02.s2 = mad(valueA02.s0, valueB10.s0, accum02.s2);
                accum03.s2 = mad(valueA02.s1, valueB10.s0, accum03.s2);
                accum00.s3 = mad(valueA00.s0, valueB10.s1, accum00.s3);
                accum01.s3 = mad(valueA00.s1, valueB10.s1, accum01.s3);
                accum02.s3 = mad(valueA02.s0, valueB10.s1, accum02.s3);
                accum03.s3 = mad(valueA02.s1, valueB10.s1, accum03.s3);
                accum00.s0 = mad(valueA01.s0, valueB01.s0, accum00.s0);
                accum01.s0 = mad(valueA01.s1, valueB01.s0, accum01.s0);
                accum02.s0 = mad(valueA03.s0, valueB01.s0, accum02.s0);
                accum03.s0 = mad(valueA03.s1, valueB01.s0, accum03.s0);
                accum00.s1 = mad(valueA01.s0, valueB01.s1, accum00.s1);
                accum01.s1 = mad(valueA01.s1, valueB01.s1, accum01.s1);
                accum02.s1 = mad(valueA03.s0, valueB01.s1, accum02.s1);
                accum03.s1 = mad(valueA03.s1, valueB01.s1, accum03.s1);
                accum00.s2 = mad(valueA01.s0, valueB11.s0, accum00.s2);
                accum01.s2 = mad(valueA01.s1, valueB11.s0, accum01.s2);
                accum02.s2 = mad(valueA03.s0, valueB11.s0, accum02.s2);
                accum03.s2 = mad(valueA03.s1, valueB11.s0, accum03.s2);
                accum00.s3 = mad(valueA01.s0, valueB11.s1, accum00.s3);
                accum01.s3 = mad(valueA01.s1, valueB11.s1, accum01.s3);
                accum02.s3 = mad(valueA03.s0, valueB11.s1, accum02.s3);
                accum03.s3 = mad(valueA03.s1, valueB11.s1, accum03.s3);
                accum00.s0 = mad(valueA10.s0, valueB02.s0, accum00.s0);
                accum01.s0 = mad(valueA10.s1, valueB02.s0, accum01.s0);
                accum02.s0 = mad(valueA12.s0, valueB02.s0, accum02.s0);
                accum03.s0 = mad(valueA12.s1, valueB02.s0, accum03.s0);
                accum00.s1 = mad(valueA10.s0, valueB02.s1, accum00.s1);
                accum01.s1 = mad(valueA10.s1, valueB02.s1, accum01.s1);
                accum02.s1 = mad(valueA12.s0, valueB02.s1, accum02.s1);
                accum03.s1 = mad(valueA12.s1, valueB02.s1, accum03.s1);
                accum00.s2 = mad(valueA10.s0, valueB12.s0, accum00.s2);
                accum01.s2 = mad(valueA10.s1, valueB12.s0, accum01.s2);
                accum02.s2 = mad(valueA12.s0, valueB12.s0, accum02.s2);
                accum03.s2 = mad(valueA12.s1, valueB12.s0, accum03.s2);
                accum00.s3 = mad(valueA10.s0, valueB12.s1, accum00.s3);
                accum01.s3 = mad(valueA10.s1, valueB12.s1, accum01.s3);
                accum02.s3 = mad(valueA12.s0, valueB12.s1, accum02.s3);
                accum03.s3 = mad(valueA12.s1, valueB12.s1, accum03.s3);
                accum00.s0 = mad(valueA11.s0, valueB03.s0, accum00.s0);
                accum01.s0 = mad(valueA11.s1, valueB03.s0, accum01.s0);
                accum02.s0 = mad(valueA13.s0, valueB03.s0, accum02.s0);
                accum03.s0 = mad(valueA13.s1, valueB03.s0, accum03.s0);
                accum00.s1 = mad(valueA11.s0, valueB03.s1, accum00.s1);
                accum01.s1 = mad(valueA11.s1, valueB03.s1, accum01.s1);
                accum02.s1 = mad(valueA13.s0, valueB03.s1, accum02.s1);
                accum03.s1 = mad(valueA13.s1, valueB03.s1, accum03.s1);
                accum00.s2 = mad(valueA11.s0, valueB13.s0, accum00.s2);
                accum01.s2 = mad(valueA11.s1, valueB13.s0, accum01.s2);
                accum02.s2 = mad(valueA13.s0, valueB13.s0, accum02.s2);
                accum03.s2 = mad(valueA13.s1, valueB13.s0, accum03.s2);
                accum00.s3 = mad(valueA11.s0, valueB13.s1, accum00.s3);
                accum01.s3 = mad(valueA11.s1, valueB13.s1, accum01.s3);
                accum02.s3 = mad(valueA13.s0, valueB13.s1, accum02.s3);
                accum03.s3 = mad(valueA13.s1, valueB13.s1, accum03.s3);
            }
        }
        write_imagef(matrixC, (int2)(get_global_id(0), ((packIdx * 400) + (get_global_id(1) * 4))), accum00);
        write_imagef(matrixC, (int2)(get_global_id(0), (((packIdx * 400) + (get_global_id(1) * 4)) + 1)), accum01);
        write_imagef(matrixC, (int2)(get_global_id(0), (((packIdx * 400) + (get_global_id(1) * 4)) + 2)), accum02);
        write_imagef(matrixC, (int2)(get_global_id(0), (((packIdx * 400) + (get_global_id(1) * 4)) + 3)), accum03);
    }
}
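For a sense of scale, the quoted throughput can be turned into a run time. This sketch assumes the usual convention of counting two floating point operations per multiply-add, so an M x N x K multiply costs 2*M*N*K operations; the function names are hypothetical.

```c
/* Floating point operations for `packed` matrix multiplies of size m x n x k,
   counting each multiply-add as two operations (assumed convention). */
double matmul_flops(double m, double n, double k, int packed)
{
    return 2.0 * m * n * k * packed;
}

/* Seconds needed to perform `flops` operations at a given GFLOPS rate. */
double seconds_at(double flops, double gflops)
{
    return flops / (gflops * 1.0e9);
}
```

Under that convention, ten 400x400x400 multiplies are 1.28e9 operations, and at about 300 GFLOPS the whole kernel takes roughly 4.3 milliseconds.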