如何使用Metal Shader Language计算16通道图像的均值和方差值

问题描述 投票:0回答:2

如何使用Metal计算16通道图像的均值和方差值?

我想分别计算不同通道的均值和方差值!

例如:

// Skeleton compute kernel from the question: the goal is to fill this in so
// that it computes a per-channel mean and variance of a 16-channel image
// stored as a texture2d_array (4 RGBA slices = 16 channels), with each
// slice's statistics computed independently.
kernel void meanandvariance(texture2d_array<float, access::read> in[[texture(0)]],   // source image; one RGBA slice holds 4 channels
                          texture2d_array<float, access::write> out[[texture(1)]],   // destination for the computed statistics

                          ushort3 gid[[thread_position_in_grid]],        // (x, y, slice) position of this thread in the grid
                          ushort tid[[thread_index_in_threadgroup]],     // linear index within the threadgroup
                          ushort3 tg_size[[threads_per_threadgroup]]) {  // threadgroup dimensions

                          // Body intentionally empty -- this is the stub the question asks to implement.
                          }


metal metalkit metal-performance-shaders
2个回答
0
投票

一种可能的做法是:在输入纹理数组和输出纹理数组上创建一系列纹理视图,然后为每个切片分别编码一次MPSImageStatisticsMeanAndVariance内核调用。

但是让我们看看如何自己做。可能的方法有很多,因此我选择了一种简单的方法,并使用了一些有趣的统计结果。

基本上,我们将执行以下操作:

  1. 编写一个可以为图像的单行产生子集均值和方差的内核。
  2. 编写一个可以根据步骤1的部分结果产生整体均值和方差的内核。

这里是内核:

// Pass 1: one thread per (row, slice). Each thread scans its row once and
// emits that row's per-channel mean and population variance, using Welford's
// numerically stable online update -- more stable than summing all pixels and
// dividing by the width. Row r of slice s writes its statistics to column r
// of the intermediate texture: y == 0 holds the mean, y == 1 the variance.
kernel void compute_row_mean_variance_array(texture2d_array<float, access::read> inTexture [[texture(0)]],
                                            texture2d_array<float, access::write> outTexture [[texture(1)]],
                                            uint3 tpig [[thread_position_in_grid]])
{
    uint rowIndex = tpig.x;
    uint sliceIndex = tpig.y;
    uint rowLength = inTexture.get_width();

    // Ignore any threads dispatched beyond the image extent.
    if (rowIndex >= inTexture.get_height() || sliceIndex >= inTexture.get_array_size()) { return; }

    float4 runningMean(0.0f);
    float4 sumSqDev(0.0f);  // Welford's M2: running sum of squared deviations
    for (uint i = 0; i < rowLength; ++i) {
        float4 sample = inTexture.read(ushort2(i, rowIndex), sliceIndex);
        // http://datagenetics.com/blog/november22017/index.html
        float weight = 1.0f / (i + 1);
        float4 previousMean = runningMean;
        runningMean = runningMean + (sample - runningMean) * weight;
        sumSqDev = sumSqDev + (sample - previousMean) * (sample - runningMean);
    }

    // M2 / N gives the population variance of the row.
    sumSqDev = sumSqDev / rowLength;

    outTexture.write(runningMean, ushort2(rowIndex, 0), sliceIndex);
    outTexture.write(sumSqDev, ushort2(rowIndex, 1), sliceIndex);
}

// Pass 2: one thread per slice. Combines the per-row partial results from
// pass 1 into a single mean/variance per slice, using the identity that the
// overall variance equals the mean of the subset variances plus the variance
// of the subset means (valid here because every row has the same sample
// count). Output layout per slice: pixel (0,0) = mean, pixel (1,0) = variance.
kernel void reduce_mean_variance_array(texture2d_array<float, access::read> inTexture [[texture(0)]],
                                       texture2d_array<float, access::write> outTexture [[texture(1)]],
                                       uint3 tpig [[thread_position_in_grid]])
{
    uint columnCount = inTexture.get_width();
    uint sliceIndex = tpig.x;

    // https://arxiv.org/pdf/1007.1012.pdf
    float4 meanOfMeans(0.0f);
    float4 meanOfVariances(0.0f);
    float4 m2OfMeans(0.0f);  // Welford M2 accumulator over the row means
    for (uint i = 0; i < columnCount; ++i) {
        float weight = 1.0f / (i + 1);

        float4 previousMean = meanOfMeans;
        float4 rowMean = inTexture.read(ushort2(i, 0), sliceIndex);
        meanOfMeans = meanOfMeans + (rowMean - meanOfMeans) * weight;

        float4 rowVariance = inTexture.read(ushort2(i, 1), sliceIndex);
        meanOfVariances = meanOfVariances + (rowVariance - meanOfVariances) * weight;

        m2OfMeans = m2OfMeans + (rowMean - previousMean) * (rowMean - meanOfMeans);
    }
    // Total variance = mean of the variances + variance of the means.
    float4 totalVariance = meanOfVariances + m2OfMeans / columnCount;

    outTexture.write(meanOfMeans, ushort2(0, 0), sliceIndex);
    outTexture.write(totalVariance, ushort2(1, 0), sliceIndex);
}

总而言之,要实现第1步,我们使用“在线”(增量)算法来计算每一行的部分均值/方差,这种方法在数值上比把所有像素值相加再除以宽度更稳定。我编写这个内核时参考了这篇文章。网格中的每个线程将其所在行的统计信息写入中间纹理数组的相应列和切片。

要实现第2步,我们需要找到一种统计上合理的方法,从部分结果中计算出总体统计量。均值的情况非常简单:总体均值就是各子集均值的均值(当每个子集的样本量相同时成立;一般情况下,总体均值是各子集均值的加权和)。方差要棘手一些,但事实证明,总体方差等于各子集方差的均值加上各子集均值的方差(同样要求各子集大小相等)。这个结论很方便,我们可以把它与上面的增量方法结合起来,为每个切片计算出最终的均值和方差,并写入输出纹理的相应切片中。

为了完整起见,这是我用来驱动这些内核的Swift代码:

// Host-side driver: encodes the two compute passes (per-row partials, then
// per-slice reduction), followed on macOS by blit synchronization so the CPU
// can read the managed-storage results.
let library = device.makeDefaultLibrary()!

// Pipeline for pass 1 (per-row partial statistics).
let meanVarKernelFunction = library.makeFunction(name: "compute_row_mean_variance_array")!
let meanVarComputePipelineState = try! device.makeComputePipelineState(function: meanVarKernelFunction)

// Pipeline for pass 2 (reduce the partials to one mean/variance per slice).
let reduceKernelFunction = library.makeFunction(name: "reduce_mean_variance_array")!
let reduceComputePipelineState = try! device.makeComputePipelineState(function: reduceKernelFunction)

let width = sourceTexture.width
let height = sourceTexture.height
let arrayLength = sourceTexture.arrayLength

// Intermediate texture: one column per SOURCE ROW (hence width is deliberately
// reassigned to `height` below), two rows (y=0 mean, y=1 variance), one slice
// per source slice.
let textureDescriptor = MTLTextureDescriptor.texture2DDescriptor(pixelFormat: .rgba32Float, width: width, height: height, mipmapped: false)
textureDescriptor.textureType = .type2DArray
textureDescriptor.arrayLength = arrayLength
textureDescriptor.width = height
textureDescriptor.height = 2
textureDescriptor.usage = [.shaderRead, .shaderWrite]

let partialResultsTexture = device.makeTexture(descriptor: textureDescriptor)!

// Final destination: 2x1 per slice -- pixel (0,0) mean, pixel (1,0) variance.
textureDescriptor.width = 2
textureDescriptor.height = 1
textureDescriptor.usage = .shaderWrite

let destTexture = device.makeTexture(descriptor: textureDescriptor)!

let commandBuffer = commandQueue.makeCommandBuffer()!

let computeCommandEncoder = commandBuffer.makeComputeCommandEncoder()!

// Pass 1: grid is (source height) x (slice count); each thread handles one row.
computeCommandEncoder.setComputePipelineState(meanVarComputePipelineState)
computeCommandEncoder.setTexture(sourceTexture, index: 0)
computeCommandEncoder.setTexture(partialResultsTexture, index: 1)
let meanVarGridSize = MTLSize(width: sourceTexture.height, height: sourceTexture.arrayLength, depth: 1)
let meanVarThreadgroupSize = MTLSizeMake(meanVarComputePipelineState.threadExecutionWidth, 1, 1)
let meanVarThreadgroupCount = MTLSizeMake((meanVarGridSize.width + meanVarThreadgroupSize.width - 1) / meanVarThreadgroupSize.width,
                                          (meanVarGridSize.height + meanVarThreadgroupSize.height - 1) / meanVarThreadgroupSize.height,
                                          1)
computeCommandEncoder.dispatchThreadgroups(meanVarThreadgroupCount, threadsPerThreadgroup: meanVarThreadgroupSize)

// Pass 2: one thread per slice reduces that slice's partial results.
computeCommandEncoder.setComputePipelineState(reduceComputePipelineState)
computeCommandEncoder.setTexture(partialResultsTexture, index: 0)
computeCommandEncoder.setTexture(destTexture, index: 1)
let reduceThreadgroupSize = MTLSizeMake(1, 1, 1)
let reduceThreadgroupCount = MTLSizeMake(arrayLength, 1, 1)
computeCommandEncoder.dispatchThreadgroups(reduceThreadgroupCount, threadsPerThreadgroup: reduceThreadgroupSize)

computeCommandEncoder.endEncoding()

// NOTE(review): `meanVarKernel` and `sourceTexture2D` below are never defined
// in this listing -- this looks like leftover code from comparing the custom
// kernels against MPSImageStatisticsMeanAndVariance on a plain 2D texture.
// Define them (or delete this section) before reusing this snippet; TODO confirm.
let destTexture2DDesc = MTLTextureDescriptor.texture2DDescriptor(pixelFormat: .rgba32Float, width: 2, height: 1, mipmapped: false)
destTexture2DDesc.usage = .shaderWrite
let destTexture2D = device.makeTexture(descriptor: destTexture2DDesc)!

meanVarKernel.encode(commandBuffer: commandBuffer, sourceTexture: sourceTexture2D, destinationTexture: destTexture2D)

#if os(macOS)
// Managed-storage textures must be synchronized before the CPU reads them.
let blitCommandEncoder = commandBuffer.makeBlitCommandEncoder()!
blitCommandEncoder.synchronize(resource: destTexture)
blitCommandEncoder.synchronize(resource: destTexture2D)
blitCommandEncoder.endEncoding()
#endif

commandBuffer.commit()

commandBuffer.waitUntilCompleted()

在我的实验中,该程序产生的结果与MPSImageStatisticsMeanAndVariance基本一致,误差大约在1e-7量级。它在我的Mac上也比MPS慢约2.5倍,这可能部分是由于未能利用细粒度并行机制来隐藏延迟。


0
投票
#include <metal_stdlib>
using namespace metal;

// Instance normalization over a texture array: for each slice, computes the
// mean and variance of all pixels via a threadgroup reduction, then writes
// out (x - mean) / sigma * scale + shift, clamped to [-10, 10].
//
// NOTE(review): correctness appears to depend on the grid being exactly ONE
// threadgroup per slice, since the strided loops start at gid.x/gid.y
// (thread_position_in_grid) rather than the position within the threadgroup.
// With more than one threadgroup per slice, each group would see only part of
// the image yet divide by width * height. TODO confirm against the dispatch.
// Also, `shared_mem` is fixed at 256 entries, so tg_size.x * tg_size.y must
// not exceed 256 -- presumably guaranteed by the caller; verify.
kernel void instance_norm(constant float4* scale[[buffer(0)]],
                          constant float4* shift[[buffer(1)]],
                          texture2d_array<float, access::read> in[[texture(0)]],
                          texture2d_array<float, access::write> out[[texture(1)]],

                          ushort3 gid[[thread_position_in_grid]],
                          ushort tid[[thread_index_in_threadgroup]],
                          ushort3 tg_size[[threads_per_threadgroup]]) {

    ushort width = in.get_width();
    ushort height = in.get_height();
    const ushort thread_count = tg_size.x * tg_size.y;

    // Scratch for the two reductions; one float4 per thread (max 256 threads).
    threadgroup float4 shared_mem [256];

    // Phase 1: each thread accumulates a strided partial sum over slice gid.z.
    float4 sum = 0;
    for(ushort xIndex = gid.x; xIndex < width; xIndex += tg_size.x) {
        for(ushort yIndex = gid.y; yIndex < height; yIndex += tg_size.y) {
            sum += in.read(ushort2(xIndex, yIndex), gid.z);
        }
    }
    shared_mem[tid] = sum;

    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Reduce to 32 values: threads 0..31 fold in every 32nd slot above them.
    // (For tid >= 32 sum stays 0, so the += below leaves those slots unchanged.)
    sum = 0;
    if (tid < 32) {
        for (ushort i = tid + 32; i < thread_count; i += 32) {
            sum += shared_mem[i];
        }
    }
    shared_mem[tid] += sum;

    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Calculate mean: thread 0 folds the remaining (up to) 32 partials and
    // divides by the total pixel count of the slice.
    sum = 0;
    if (tid == 0) {
        ushort top = min(ushort(32), thread_count);
        for (ushort i = 0; i < top; i += 1) {
            sum += shared_mem[i];
        }
        shared_mem[0] = sum / (width * height);
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    const float4 mean = shared_mem[0];

    // Barrier before shared_mem is reused for the variance reduction.
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Phase 2: same strided pattern, accumulating squared deviations.
    sum = 0;
    for(ushort xIndex = gid.x; xIndex < width; xIndex += tg_size.x) {
        for(ushort yIndex = gid.y; yIndex < height; yIndex += tg_size.y) {
            sum += pow(in.read(ushort2(xIndex, yIndex), gid.z) - mean, 2);
        }
    }

    shared_mem[tid] = sum;
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Reduce to 32 values (same scheme as the mean reduction above).
    sum = 0;
    if (tid < 32) {
        for (ushort i = tid + 32; i < thread_count; i += 32) {
            sum += shared_mem[i];
        }
    }
    shared_mem[tid] += sum;

    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Calculate variance (population variance: divide by width * height).
    sum = 0;
    if (tid == 0) {
        ushort top = min(ushort(32), thread_count);
        for (ushort i = 0; i < top; i += 1) {
            sum += shared_mem[i];
        }
        shared_mem[0] = sum / (width * height);
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Epsilon 1e-4 guards against division by zero for constant regions.
    const float4 sigma = sqrt(shared_mem[0] + float4(1e-4));

    // Phase 3: normalize, apply per-slice scale/shift, clamp, and write out.
    float4 multiplier = scale[gid.z] / sigma;
    for(ushort xIndex = gid.x; xIndex < width; xIndex += tg_size.x) {
        for(ushort yIndex = gid.y; yIndex < height; yIndex += tg_size.y) {
            float4 val = in.read(ushort2(xIndex, yIndex), gid.z);
            out.write(clamp((val - mean) * multiplier + shift[gid.z], -10.0, 10.0), ushort2(xIndex, yIndex), gid.z);
        }
    }

}

这是Bender工具的实现方式,但我认为它的结果并不正确,有人能指出问题所在吗?

https://github.com/xmartlabs/Bender/blob/master/Sources/Metal/instanceNorm.metal

© www.soinside.com 2019 - 2024. All rights reserved.