我已经创建了将张量混合在一起的代码,首先将它们混合在一起成行,然后将这些行混合在一起成最终输出。它适用于2x2模式中的4个张量,但不能执行2x3(6张量),3x3(9张量),4x4(16张量)模式。
张量的形式为(B x C x H x W),其中B是批处理大小,C是通道,H是高度,W是宽度。
对于平铺到行(tile_overlay()
)和行到最终图像(row_overlay()
),我创建一个基本张量,向其添加平铺/行。我怀疑代码的问题在于如何获取基本张量的尺寸,如何跟踪在基本张量上的行/小块的放置位置或这两个方面的问题。
import torch
from PIL import Image
import torchvision.transforms as transforms
def preprocess(image_name, image_size):
image = Image.open(image_name).convert('RGB')
if type(image_size) is not tuple:
image_size = tuple([int((float(image_size) / max(image.size))*x) for x in (image.height, image.width)])
Loader = transforms.Compose([transforms.Resize(image_size), transforms.ToTensor()])
tensor = (Loader(image) * 256).unsqueeze(0)
return tensor
def deprocess(output_tensor):
output_tensor = output_tensor.clone().squeeze(0).cpu() / 256
output_tensor.clamp_(0, 1)
Image2PIL = transforms.ToPILImage()
image = Image2PIL(output_tensor.cpu())
return image
def prepare_tile(tile, overlap, side='both'):
lin_mask_left = torch.linspace(0,1,overlap).repeat(tile.size(3),1).repeat(3,1,1).unsqueeze(0)
lin_mask_right = torch.linspace(1,0,overlap).repeat(tile.size(3),1).repeat(3,1,1).unsqueeze(0)
if side == 'both' or side == 'right':
tile[:,:,:,overlap:] = tile[:,:,:,overlap:] * lin_mask_right
if side == 'both' or side == 'left':
tile[:,:,:,:overlap] = tile[:,:,:,:overlap] * lin_mask_left
return tile
def overlay_tiles(tile_list, rows, overlap):
c = 1
f_tiles = []
base_length = 0
for i, tile in enumerate(tile_list):
if c == 1:
f_tile = prepare_tile(tile.clone(), overlap, side='right')
if i + 1<= rows[1]:
base_length += tile.clone().size(3) - overlap
elif c == rows[1]:
f_tile = prepare_tile(tile.clone(), overlap, side='left')
if i + 1<= rows[1]:
base_length += tile.size(3) - overlap
elif c > 0 and c < rows[1]:
f_tile = prepare_tile(tile.clone(), overlap, side='both')
if i + 1<= rows[1]:
base_length += tile.size(3) - (overlap*2)
f_tiles.append(f_tile)
if c == rows[1]:
c = 0
c+=1
base_length += overlap
base_tensor = torch.zeros(3, tile_list[0].size(2), base_length).unsqueeze(0)
row_list = []
for row in range(rows[1]):
row_list.append(base_tensor.clone())
row_val, num_tiles = 0, 0
l_max = tile_list[0].size(3)
for y in range(rows[0]):
for x in range(rows[1]):
if num_tiles % rows[1] != 0:
l_max += (f_tiles[num_tiles].size(3)-overlap)*x
l_min = l_max - f_tiles[num_tiles].size(3)
row_list[row_val][:, :, :, l_min:l_max] = row_list[row_val][:, :, :, l_min:l_max] + f_tiles[num_tiles]
else:
row_list[row_val][:, :, :, :f_tiles[num_tiles].size(3)] = f_tiles[num_tiles]
l_max = tile_list[0].size(3)
num_tiles+=1
row_val+=1
return row_list
def prepare_row(row_tensor, overlap, side='both'):
lin_mask_top = torch.linspace(0,1,overlap).repeat(row_tensor.size(3),1).rot90(3).repeat(3,1,1).unsqueeze(0)
lin_mask_bottom = torch.linspace(1,0,overlap).repeat(row_tensor.size(3),1).rot90(3).repeat(3,1,1).unsqueeze(0)
if side == 'both' or side == 'top':
row_tensor[:,:,:overlap,:] = row_tensor[:,:,:overlap,:] * lin_mask_top
if side == 'both' or side == 'bottom':
row_tensor[:,:,overlap:,:] = row_tensor[:,:,overlap:,:] * lin_mask_bottom
return row_tensor
def overlay_rows(row_list, rows, overlap):
c = 1
f_rows = []
base_height = 0
for i, row_tensor in enumerate(row_list):
if c == 1:
f_row = prepare_row(row_tensor.clone(), overlap, side='bottom')
if i + 1<= rows[0]:
base_height += row_tensor.size(2) - overlap
elif c == rows[1]:
f_row = prepare_row(row_tensor.clone(), overlap, side='top')
if i + 1<= rows[0]:
base_height += row_tensor.size(2) - overlap
elif c > 0 and c < rows[0]:
f_row = prepare_row(row_tensor.clone(), overlap, side='both')
if i + 1<= rows[0]:
base_height += tile.size(2) - (overlap*2)
f_rows.append(f_row)
if c == rows[0]:
c = 0
c+=1
base_height += overlap
base_tensor = torch.zeros(3, base_height, row_list[0].size(3)).unsqueeze(0)
num_rows = 0
l_max = row_list[0].size(3)
for y in range(rows[0]):
if num_rows > 0:
l_max += (f_rows[num_rows].size(2)-overlap)*y
l_min = l_max - f_rows[num_rows].size(2)
base_tensor[:, :, l_min:l_max, :] = base_tensor[:, :, l_min:l_max, :] + f_rows[num_rows]
else:
base_tensor[:, :, :f_rows[num_rows].size(2), :] = f_rows[num_rows]
l_max = row_list[0].size(2)
num_rows+=1
return base_tensor
def rebuild_image(tensor_list, rows, overlap_hw):
row_tensors = overlay_tiles(tensor_list, rows, overlap_hw[1])
full_tensor = overlay_rows(row_tensors, rows, overlap_hw[0])
return full_tensor
test_tensor_1 = preprocess('brad_pitt.jpg', (1080,1080))
test_tensor_2 = preprocess('starry_night_google.jpg', (1080,1080))
tensor_list = [test_tensor_1.clone(), test_tensor_2.clone(), test_tensor_1.clone(), test_tensor_2.clone()]
rows = [2, 2]
overlap = [540, 540]
complete_tensor = rebuild_image(tensor_list, rows, overlap)
ft = deprocess(complete_tensor.clone())
ft.save('complete_tensor_2x2.png')
tensor_list = [test_tensor_1.clone(), test_tensor_2.clone(),test_tensor_1.clone(), \
test_tensor_1.clone(), test_tensor_2.clone(),test_tensor_1.clone(), \
test_tensor_1.clone(), test_tensor_2.clone(),test_tensor_1.clone(),]
rows = [3, 3]
overlap = [540, 540]
complete_tensor = rebuild_image(tensor_list, rows, overlap)
ft = deprocess(complete_tensor.clone())
ft.save('complete_tensor_3x3.png')
tensor_list = [test_tensor_1.clone(), test_tensor_2.clone(), test_tensor_1.clone(), test_tensor_2.clone(), \
test_tensor_1.clone(), test_tensor_2.clone(), test_tensor_1.clone(), test_tensor_2.clone(), \
test_tensor_1.clone(), test_tensor_2.clone(), test_tensor_1.clone(), test_tensor_2.clone(), \
test_tensor_1.clone(), test_tensor_2.clone(), test_tensor_1.clone(), test_tensor_2.clone()]
rows = [4, 4]
overlap = [540, 540]
complete_tensor = rebuild_image(tensor_list, rows, overlap)
ft = deprocess(complete_tensor.clone())
ft.save('complete_tensor_4x4.png')
尝试创建3x3输出时,运行上面的代码将导致此错误消息:
Traceback (most recent call last):
File "t0.py", line 148, in <module>
complete_tensor = rebuild_image(tensor_list, rows, overlap)
File "t0.py", line 126, in rebuild_image
row_tensors = overlay_tiles(tensor_list, rows, overlap_hw[1])
File "t0.py", line 68, in overlay_tiles
row_list[row_val][:, :, :, l_min:l_max] = row_list[row_val][:, :, :, l_min:l_max] + f_tiles[num_tiles]
RuntimeError: The size of tensor a (0) must match the size of tensor b (1080) at non-singleton dimension 3
这是2x2输出的示例:
这是一个可视化图表,其中包含我正在做的两个示例:
代码现在可以进行一些更改:
import torch
from PIL import Image
import torchvision.transforms as transforms
def preprocess(image_name, image_size):
image = Image.open(image_name).convert('RGB')
if type(image_size) is not tuple:
image_size = tuple([int((float(image_size) / max(image.size))*x) for x in (image.height, image.width)])
Loader = transforms.Compose([transforms.Resize(image_size), transforms.ToTensor()])
tensor = (Loader(image) * 256).unsqueeze(0)
return tensor
def deprocess(output_tensor):
output_tensor = output_tensor.clone().squeeze(0).cpu() / 256
output_tensor.clamp_(0, 1)
Image2PIL = transforms.ToPILImage()
image = Image2PIL(output_tensor.cpu())
return image
def prepare_tile(tile, overlap, side='both'):
h, w = tile.size(2), tile.size(3)
lin_mask_left = torch.linspace(0,1,overlap).repeat(h,1).repeat(3,1,1).unsqueeze(0)
lin_mask_right = torch.linspace(1,0,overlap).repeat(h,1).repeat(3,1,1).unsqueeze(0)
if side == 'both' or side == 'right':
tile[:,:,:,w-overlap:] = tile[:,:,:,w-overlap:] * lin_mask_right
if side == 'both' or side == 'left':
tile[:,:,:,:overlap] = tile[:,:,:,:overlap] * lin_mask_left
return tile
def calc_length(w, overlap, rows):
count = 0
l_max = w
for y in range(rows[0]):
for x in range(rows[1]):
if count % rows[1] != 0:
l_max += w-overlap
l_min = l_max - w
else:
l_max = w
count+=1
return l_max
def overlay_tiles(tile_list, rows, overlap):
c = 1
f_tiles = []
base_length = 0
for i, tile in enumerate(tile_list):
if c == 1:
f_tile = prepare_tile(tile.clone(), overlap, side='right')
elif c == rows[1]:
f_tile = prepare_tile(tile.clone(), overlap, side='left')
elif c > 0 and c < rows[1]:
f_tile = prepare_tile(tile.clone(), overlap, side='both')
f_tiles.append(f_tile)
if c == rows[1]:
c = 0
c+=1
w = tile_list[0].size(3)
base_length = calc_length(w, overlap, rows)
base_tensor = torch.zeros(3, tile_list[0].size(2), base_length).unsqueeze(0)
row_list = []
for row in range(rows[0]):
row_list.append(base_tensor.clone())
row_num, num_tiles = 0, 0
l_max = w
for y in range(rows[0]):
for x in range(rows[1]):
if num_tiles % rows[1] != 0:
l_max += w-overlap
l_min = l_max - w
print(num_tiles, l_max, l_min)
row_list[row_num][:, :, :, l_min:l_max] = row_list[row_num][:, :, :, l_min:l_max] + f_tiles[num_tiles]
else:
row_list[row_num][:, :, :, :w] = f_tiles[num_tiles]
l_max = w
num_tiles+=1
row_num+=1
return row_list
def prepare_row(row_tensor, overlap, side='both'):
lin_mask_top = torch.linspace(0,1,overlap).repeat(row_tensor.size(3),1).rot90(3).repeat(3,1,1).unsqueeze(0)
lin_mask_bottom = torch.linspace(1,0,overlap).repeat(row_tensor.size(3),1).rot90(3).repeat(3,1,1).unsqueeze(0)
if side == 'both' or side == 'top':
row_tensor[:,:,:overlap,:] = row_tensor[:,:,:overlap,:] * lin_mask_top
if side == 'both' or side == 'bottom':
row_tensor[:,:,overlap:,:] = row_tensor[:,:,overlap:,:] * lin_mask_bottom
return row_tensor
def calc_height(h, overlap, rows):
num_rows = 0
l_max = h
for y in range(rows[0]):
if num_rows > 0:
l_max += (h-overlap)
l_min = l_max - h
else:
l_max = h
num_rows+=1
return l_max
def overlay_rows(row_list, rows, overlap):
c = 1
f_rows = []
base_height = 0
for i, row_tensor in enumerate(row_list):
if c == 1:
f_row = prepare_row(row_tensor.clone(), overlap, side='bottom')
elif c == rows[0]:
f_row = prepare_row(row_tensor.clone(), overlap, side='top')
elif c > 0 and c < rows[0]:
f_row = prepare_row(row_tensor.clone(), overlap, side='both')
f_rows.append(f_row)
if c == rows[0]:
c = 0
c+=1
h = row_list[0].size(2)
base_height = calc_height(h, overlap, rows)
base_tensor = torch.zeros(3, base_height, row_list[0].size(3)).unsqueeze(0)
num_rows = 0
l_max = row_list[0].size(3)
for y in range(rows[0]):
if num_rows > 0:
l_max += (h-overlap)
l_min = l_max - h
base_tensor[:, :, l_min:l_max, :] = base_tensor[:, :, l_min:l_max, :] + f_rows[num_rows]
else:
base_tensor[:, :, :h, :] = f_rows[num_rows]
l_max = h
num_rows+=1
return base_tensor
def rebuild_image(tensor_list, rows, overlap_hw):
row_tensors = overlay_tiles(tensor_list, rows, overlap_hw[1])
full_tensor = overlay_rows(row_tensors, rows, overlap_hw[0])
return full_tensor
test_tensor_1 = preprocess('brad_pitt.jpg', (1080,720))
test_tensor_2 = preprocess('starry_night_google.jpg', (1080,720))
print("2x2 Test")
tensor_list = [test_tensor_1.clone(), test_tensor_2.clone(), test_tensor_1.clone(), test_tensor_2.clone()]
rows = [2, 2]
overlap = [540, 260]
complete_tensor = rebuild_image(tensor_list, rows, overlap)
ft = deprocess(complete_tensor.clone())
ft.save('complete_tensor_2x2.png')
print("3x3 Test")
tensor_list = [test_tensor_1.clone(), test_tensor_2.clone(),test_tensor_1.clone(), \
test_tensor_1.clone(), test_tensor_2.clone(),test_tensor_1.clone(), \
test_tensor_1.clone(), test_tensor_2.clone(),test_tensor_1.clone(),]
rows = [3, 3]
overlap = [540, 540]
complete_tensor = rebuild_image(tensor_list, rows, overlap)
ft = deprocess(complete_tensor.clone())
ft.save('complete_tensor_3x3.png')
print("3x4 Test")
tensor_list = [test_tensor_1.clone(), test_tensor_2.clone(), test_tensor_1.clone(), test_tensor_2.clone(), \
test_tensor_1.clone(), test_tensor_2.clone(), test_tensor_1.clone(), test_tensor_2.clone(), \
test_tensor_1.clone(), test_tensor_2.clone(), test_tensor_1.clone(), test_tensor_2.clone()]
rows = [3, 4]
overlap = [540, 260]
complete_tensor = rebuild_image(tensor_list, rows, overlap)
ft = deprocess(complete_tensor.clone())
ft.save('complete_tensor_3x4.png')
print("4x3 Test")
tensor_list = [test_tensor_1.clone(), test_tensor_2.clone(), test_tensor_1.clone(), \
test_tensor_1.clone(), test_tensor_2.clone(), test_tensor_1.clone(), \
test_tensor_1.clone(), test_tensor_2.clone(), test_tensor_1.clone(), \
test_tensor_1.clone(), test_tensor_2.clone(), test_tensor_1.clone()]
rows = [4, 3]
overlap = [540, 260]
complete_tensor = rebuild_image(tensor_list, rows, overlap)
ft = deprocess(complete_tensor.clone())
ft.save('complete_tensor_4x3.png')
print("4x4 Test")
tensor_list = [test_tensor_1.clone(), test_tensor_2.clone(), test_tensor_1.clone(), test_tensor_2.clone(), \
test_tensor_1.clone(), test_tensor_2.clone(), test_tensor_1.clone(), test_tensor_2.clone(), \
test_tensor_1.clone(), test_tensor_2.clone(), test_tensor_1.clone(), test_tensor_2.clone(), \
test_tensor_1.clone(), test_tensor_2.clone(), test_tensor_1.clone(), test_tensor_2.clone()]
rows = [4, 4]
overlap = [540, 260]
complete_tensor = rebuild_image(tensor_list, rows, overlap)
ft = deprocess(complete_tensor.clone())
ft.save('complete_tensor_4x4.png')