如何将张量以2x2以外的模式混合在一起?

问题描述 投票:0回答:1

我已经创建了将张量混合在一起的代码,首先将它们混合在一起成行,然后将这些行混合在一起成最终输出。它适用于2x2模式中的4个张量,但不能执行2x3(6张量),3x3(9张量),4x4(16张量)模式。

张量的形式为(B x C x H x W),其中B是批处理大小,C是通道,H是高度,W是宽度。

对于平铺到行(tile_overlay())和行到最终图像(row_overlay()),我创建一个基本张量,向其添加平铺/行。我怀疑代码的问题在于如何获取基本张量的尺寸,如何跟踪在基本张量上的行/小块的放置位置或这两个方面的问题。

import torch
from PIL import Image
import torchvision.transforms as transforms


def preprocess(image_name, image_size):
    image = Image.open(image_name).convert('RGB')
    if type(image_size) is not tuple:
        image_size = tuple([int((float(image_size) / max(image.size))*x) for x in (image.height, image.width)])
    Loader = transforms.Compose([transforms.Resize(image_size), transforms.ToTensor()])
    tensor = (Loader(image) * 256).unsqueeze(0)
    return tensor

def deprocess(output_tensor):
    output_tensor = output_tensor.clone().squeeze(0).cpu() / 256
    output_tensor.clamp_(0, 1)
    Image2PIL = transforms.ToPILImage()
    image = Image2PIL(output_tensor.cpu())
    return image




def prepare_tile(tile, overlap, side='both'):
    lin_mask_left = torch.linspace(0,1,overlap).repeat(tile.size(3),1).repeat(3,1,1).unsqueeze(0)
    lin_mask_right = torch.linspace(1,0,overlap).repeat(tile.size(3),1).repeat(3,1,1).unsqueeze(0)
    if side == 'both' or side == 'right':
        tile[:,:,:,overlap:] = tile[:,:,:,overlap:] * lin_mask_right
    if side == 'both' or side == 'left':
        tile[:,:,:,:overlap] = tile[:,:,:,:overlap] * lin_mask_left 
    return tile

def overlay_tiles(tile_list, rows, overlap):        
    c = 1
    f_tiles = []
    base_length = 0
    for i, tile in enumerate(tile_list):
        if c == 1:    
             f_tile = prepare_tile(tile.clone(), overlap, side='right')
             if i + 1<= rows[1]:
                 base_length += tile.clone().size(3) - overlap
        elif c == rows[1]:
             f_tile = prepare_tile(tile.clone(), overlap, side='left')
             if i + 1<= rows[1]:
                 base_length += tile.size(3) - overlap
        elif c > 0 and c < rows[1]:
             f_tile = prepare_tile(tile.clone(), overlap, side='both')
             if i + 1<= rows[1]:
                 base_length += tile.size(3) - (overlap*2)
        f_tiles.append(f_tile)  
        if c == rows[1]:
             c = 0
        c+=1

    base_length += overlap           
    base_tensor = torch.zeros(3, tile_list[0].size(2), base_length).unsqueeze(0)
    row_list = []
    for row in range(rows[1]):
        row_list.append(base_tensor.clone())

    row_val, num_tiles = 0, 0
    l_max = tile_list[0].size(3)
    for y in range(rows[0]):       
        for x in range(rows[1]):        
            if num_tiles % rows[1] != 0:
                l_max += (f_tiles[num_tiles].size(3)-overlap)*x
                l_min = l_max - f_tiles[num_tiles].size(3)
                row_list[row_val][:, :, :, l_min:l_max] = row_list[row_val][:, :, :, l_min:l_max] + f_tiles[num_tiles]
            else:
                row_list[row_val][:, :, :, :f_tiles[num_tiles].size(3)] = f_tiles[num_tiles]  
                l_max = tile_list[0].size(3)
            num_tiles+=1 
        row_val+=1  
    return row_list


def prepare_row(row_tensor, overlap, side='both'):
    lin_mask_top = torch.linspace(0,1,overlap).repeat(row_tensor.size(3),1).rot90(3).repeat(3,1,1).unsqueeze(0)
    lin_mask_bottom = torch.linspace(1,0,overlap).repeat(row_tensor.size(3),1).rot90(3).repeat(3,1,1).unsqueeze(0)
    if side == 'both' or side == 'top':
        row_tensor[:,:,:overlap,:] = row_tensor[:,:,:overlap,:]  * lin_mask_top
    if side == 'both' or side == 'bottom':
        row_tensor[:,:,overlap:,:] = row_tensor[:,:,overlap:,:] * lin_mask_bottom   
    return row_tensor

def overlay_rows(row_list, rows, overlap):
    c = 1
    f_rows = []
    base_height = 0
    for i, row_tensor in enumerate(row_list):
        if c == 1:    
             f_row = prepare_row(row_tensor.clone(), overlap, side='bottom')
             if i + 1<= rows[0]:
                 base_height += row_tensor.size(2) - overlap
        elif c == rows[1]:
             f_row = prepare_row(row_tensor.clone(), overlap, side='top')
             if i + 1<= rows[0]:
                 base_height += row_tensor.size(2) - overlap
        elif c > 0 and c < rows[0]:
             f_row = prepare_row(row_tensor.clone(), overlap, side='both')
             if i + 1<= rows[0]:
                 base_height += tile.size(2) - (overlap*2)
        f_rows.append(f_row)    
        if c == rows[0]:
             c = 0
        c+=1

    base_height += overlap           
    base_tensor = torch.zeros(3, base_height, row_list[0].size(3)).unsqueeze(0)

    num_rows = 0
    l_max = row_list[0].size(3)
    for y in range(rows[0]):        
            if num_rows > 0:
                l_max += (f_rows[num_rows].size(2)-overlap)*y
                l_min = l_max - f_rows[num_rows].size(2)
                base_tensor[:, :, l_min:l_max, :] = base_tensor[:, :, l_min:l_max, :] + f_rows[num_rows]
            else:
                base_tensor[:, :, :f_rows[num_rows].size(2), :] = f_rows[num_rows]  
                l_max = row_list[0].size(2)
            num_rows+=1   
    return base_tensor


def rebuild_image(tensor_list, rows, overlap_hw):
    row_tensors = overlay_tiles(tensor_list, rows, overlap_hw[1])
    full_tensor = overlay_rows(row_tensors, rows, overlap_hw[0])
    return full_tensor

test_tensor_1 = preprocess('brad_pitt.jpg', (1080,1080))
test_tensor_2 = preprocess('starry_night_google.jpg', (1080,1080))


tensor_list = [test_tensor_1.clone(), test_tensor_2.clone(), test_tensor_1.clone(), test_tensor_2.clone()]
rows = [2, 2]
overlap = [540, 540]
complete_tensor = rebuild_image(tensor_list, rows, overlap)

ft = deprocess(complete_tensor.clone())
ft.save('complete_tensor_2x2.png')


tensor_list = [test_tensor_1.clone(), test_tensor_2.clone(),test_tensor_1.clone(), \
test_tensor_1.clone(), test_tensor_2.clone(),test_tensor_1.clone(), \
test_tensor_1.clone(), test_tensor_2.clone(),test_tensor_1.clone(),]
rows = [3, 3]
overlap = [540, 540]
complete_tensor = rebuild_image(tensor_list, rows, overlap)

ft = deprocess(complete_tensor.clone())
ft.save('complete_tensor_3x3.png')



tensor_list = [test_tensor_1.clone(), test_tensor_2.clone(), test_tensor_1.clone(), test_tensor_2.clone(), \
test_tensor_1.clone(), test_tensor_2.clone(), test_tensor_1.clone(), test_tensor_2.clone(), \
test_tensor_1.clone(), test_tensor_2.clone(), test_tensor_1.clone(), test_tensor_2.clone(), \
test_tensor_1.clone(), test_tensor_2.clone(), test_tensor_1.clone(), test_tensor_2.clone()]
rows = [4, 4]
overlap = [540, 540]
complete_tensor = rebuild_image(tensor_list, rows, overlap)

ft = deprocess(complete_tensor.clone())
ft.save('complete_tensor_4x4.png')

尝试创建3x3输出时,运行上面的代码将导致此错误消息:

Traceback (most recent call last):
  File "t0.py", line 148, in <module>
    complete_tensor = rebuild_image(tensor_list, rows, overlap)
  File "t0.py", line 126, in rebuild_image
    row_tensors = overlay_tiles(tensor_list, rows, overlap_hw[1])
  File "t0.py", line 68, in overlay_tiles
    row_list[row_val][:, :, :, l_min:l_max] = row_list[row_val][:, :, :, l_min:l_max] + f_tiles[num_tiles]
RuntimeError: The size of tensor a (0) must match the size of tensor b (1080) at non-singleton dimension 3

这是2x2输出的示例:

enter image description here

这是一个可视化图表,其中包含我正在做的两个示例:

“”

python image-processing pytorch mask tensor
1个回答
0
投票

代码现在可以进行一些更改:

import torch
from PIL import Image
import torchvision.transforms as transforms


def preprocess(image_name, image_size):
    image = Image.open(image_name).convert('RGB')
    if type(image_size) is not tuple:
        image_size = tuple([int((float(image_size) / max(image.size))*x) for x in (image.height, image.width)])
    Loader = transforms.Compose([transforms.Resize(image_size), transforms.ToTensor()])
    tensor = (Loader(image) * 256).unsqueeze(0)
    return tensor

def deprocess(output_tensor):
    output_tensor = output_tensor.clone().squeeze(0).cpu() / 256
    output_tensor.clamp_(0, 1)
    Image2PIL = transforms.ToPILImage()
    image = Image2PIL(output_tensor.cpu())
    return image




def prepare_tile(tile, overlap, side='both'):
    h, w = tile.size(2), tile.size(3)
    lin_mask_left = torch.linspace(0,1,overlap).repeat(h,1).repeat(3,1,1).unsqueeze(0)
    lin_mask_right = torch.linspace(1,0,overlap).repeat(h,1).repeat(3,1,1).unsqueeze(0)
    if side == 'both' or side == 'right':
        tile[:,:,:,w-overlap:] = tile[:,:,:,w-overlap:] * lin_mask_right
    if side == 'both' or side == 'left':
        tile[:,:,:,:overlap] = tile[:,:,:,:overlap] * lin_mask_left 
    return tile


def calc_length(w, overlap, rows):
    count = 0
    l_max = w
    for y in range(rows[0]):       
        for x in range(rows[1]):        
            if count % rows[1] != 0:
                l_max += w-overlap
                l_min = l_max - w
            else:  
                l_max = w
            count+=1 
    return l_max

def overlay_tiles(tile_list, rows, overlap):        
    c = 1
    f_tiles = []
    base_length = 0
    for i, tile in enumerate(tile_list):
        if c == 1:    
             f_tile = prepare_tile(tile.clone(), overlap, side='right')
        elif c == rows[1]:
             f_tile = prepare_tile(tile.clone(), overlap, side='left')
        elif c > 0 and c < rows[1]:
             f_tile = prepare_tile(tile.clone(), overlap, side='both')
        f_tiles.append(f_tile)  
        if c == rows[1]:
             c = 0
        c+=1

    w = tile_list[0].size(3)
    base_length = calc_length(w, overlap, rows)
    base_tensor = torch.zeros(3, tile_list[0].size(2), base_length).unsqueeze(0)

    row_list = []
    for row in range(rows[0]):
        row_list.append(base_tensor.clone())    

    row_num, num_tiles = 0, 0
    l_max = w
    for y in range(rows[0]):       
        for x in range(rows[1]):        
            if num_tiles % rows[1] != 0:
                l_max += w-overlap
                l_min = l_max - w
                print(num_tiles, l_max, l_min)
                row_list[row_num][:, :, :, l_min:l_max] = row_list[row_num][:, :, :, l_min:l_max] + f_tiles[num_tiles]
            else:
                row_list[row_num][:, :, :, :w] = f_tiles[num_tiles]  
                l_max = w
            num_tiles+=1 
        row_num+=1  
    return row_list


def prepare_row(row_tensor, overlap, side='both'):
    lin_mask_top = torch.linspace(0,1,overlap).repeat(row_tensor.size(3),1).rot90(3).repeat(3,1,1).unsqueeze(0)
    lin_mask_bottom = torch.linspace(1,0,overlap).repeat(row_tensor.size(3),1).rot90(3).repeat(3,1,1).unsqueeze(0)
    if side == 'both' or side == 'top':
        row_tensor[:,:,:overlap,:] = row_tensor[:,:,:overlap,:]  * lin_mask_top
    if side == 'both' or side == 'bottom':
        row_tensor[:,:,overlap:,:] = row_tensor[:,:,overlap:,:] * lin_mask_bottom
    return row_tensor


def calc_height(h, overlap, rows):
    num_rows = 0
    l_max = h
    for y in range(rows[0]):        
            if num_rows > 0:
                l_max += (h-overlap)
                l_min = l_max - h
            else: 
                l_max = h
            num_rows+=1  
    return l_max

def overlay_rows(row_list, rows, overlap):
    c = 1
    f_rows = []
    base_height = 0
    for i, row_tensor in enumerate(row_list):
        if c == 1:    
             f_row = prepare_row(row_tensor.clone(), overlap, side='bottom')
        elif c == rows[0]:
             f_row = prepare_row(row_tensor.clone(), overlap, side='top')
        elif c > 0 and c < rows[0]:
             f_row = prepare_row(row_tensor.clone(), overlap, side='both')
        f_rows.append(f_row)    
        if c == rows[0]:
             c = 0
        c+=1


    h = row_list[0].size(2)
    base_height = calc_height(h, overlap, rows)          
    base_tensor = torch.zeros(3, base_height, row_list[0].size(3)).unsqueeze(0)

    num_rows = 0
    l_max = row_list[0].size(3)
    for y in range(rows[0]):        
            if num_rows > 0:
                l_max += (h-overlap)
                l_min = l_max - h
                base_tensor[:, :, l_min:l_max, :] = base_tensor[:, :, l_min:l_max, :] + f_rows[num_rows]
            else:
                base_tensor[:, :, :h, :] = f_rows[num_rows]  
                l_max = h
            num_rows+=1   
    return base_tensor


def rebuild_image(tensor_list, rows, overlap_hw):
    row_tensors = overlay_tiles(tensor_list, rows, overlap_hw[1])
    full_tensor = overlay_rows(row_tensors, rows, overlap_hw[0])
    return full_tensor

test_tensor_1 = preprocess('brad_pitt.jpg', (1080,720))
test_tensor_2 = preprocess('starry_night_google.jpg', (1080,720))



print("2x2 Test")
tensor_list = [test_tensor_1.clone(), test_tensor_2.clone(), test_tensor_1.clone(), test_tensor_2.clone()]
rows = [2, 2]
overlap = [540, 260]
complete_tensor = rebuild_image(tensor_list, rows, overlap)

ft = deprocess(complete_tensor.clone())
ft.save('complete_tensor_2x2.png')


print("3x3 Test")
tensor_list = [test_tensor_1.clone(), test_tensor_2.clone(),test_tensor_1.clone(), \
test_tensor_1.clone(), test_tensor_2.clone(),test_tensor_1.clone(), \
test_tensor_1.clone(), test_tensor_2.clone(),test_tensor_1.clone(),]
rows = [3, 3]
overlap = [540, 540]
complete_tensor = rebuild_image(tensor_list, rows, overlap)

ft = deprocess(complete_tensor.clone())
ft.save('complete_tensor_3x3.png')

print("3x4 Test")
tensor_list = [test_tensor_1.clone(), test_tensor_2.clone(), test_tensor_1.clone(), test_tensor_2.clone(), \
test_tensor_1.clone(), test_tensor_2.clone(), test_tensor_1.clone(), test_tensor_2.clone(), \
test_tensor_1.clone(), test_tensor_2.clone(), test_tensor_1.clone(), test_tensor_2.clone()]
rows = [3, 4]
overlap = [540, 260]
complete_tensor = rebuild_image(tensor_list, rows, overlap)

ft = deprocess(complete_tensor.clone())
ft.save('complete_tensor_3x4.png')


print("4x3 Test")
tensor_list = [test_tensor_1.clone(), test_tensor_2.clone(), test_tensor_1.clone(), \
test_tensor_1.clone(), test_tensor_2.clone(), test_tensor_1.clone(), \
test_tensor_1.clone(), test_tensor_2.clone(), test_tensor_1.clone(), \
test_tensor_1.clone(), test_tensor_2.clone(), test_tensor_1.clone()]
rows = [4, 3]
overlap = [540, 260]
complete_tensor = rebuild_image(tensor_list, rows, overlap)

ft = deprocess(complete_tensor.clone())
ft.save('complete_tensor_4x3.png')

print("4x4 Test")
tensor_list = [test_tensor_1.clone(), test_tensor_2.clone(), test_tensor_1.clone(), test_tensor_2.clone(), \
test_tensor_1.clone(), test_tensor_2.clone(), test_tensor_1.clone(), test_tensor_2.clone(), \
test_tensor_1.clone(), test_tensor_2.clone(), test_tensor_1.clone(), test_tensor_2.clone(), \
test_tensor_1.clone(), test_tensor_2.clone(), test_tensor_1.clone(), test_tensor_2.clone()]
rows = [4, 4]
overlap = [540, 260]
complete_tensor = rebuild_image(tensor_list, rows, overlap)

ft = deprocess(complete_tensor.clone())
ft.save('complete_tensor_4x4.png')
© www.soinside.com 2019 - 2024. All rights reserved.