在Python上调整图像大小的随机森林

问题描述 投票:0回答:1

我目前正在尝试一个实验项目,其中我使用随机森林回归模型将低分辨率图像调整为与高分辨率图像相同的大小,然后将其与主食图像增强之外的其他调整大小技术(基本上是插值技术)进行比较).

到目前为止,我在这里提供的代码运行顺利并且不需要很长时间,但只返回新图像的一个补丁,而我的实际目标是获取整个图像,就像我使用插值技术所做的那样,或者,如果这需要太多的精力和时间,至少要提供相关图像的可识别片段。尽管我正在关注一篇使用决策树的科学论文,但在计算和质量方面获得了良好的结果(使用 PSNR 和 SSIM 等指标)。

如何修改提供的代码以实现我之前公开的内容?

from sklearn.ensemble import RandomForestRegressor
from matplotlib import pyplot as plt
from skimage.transform import downscale_local_mean, resize
from skimage.metrics import peak_signal_noise_ratio as compare_psnr
from skimage import io, color
import numpy as np
import time


def safely_convert_to_gray(image):
    if len(image.shape) == 3 and image.shape[2] == 3:  # Imagen RGB
        return color.rgb2gray(image)
    elif len(image.shape) == 2:  # Imagen en escala de grises
        return image
    else:
        raise ValueError("La imagen no es RGB ni escala de grises en formato reconocido")

def create_patches(lr_images, hr_images, patch_size, scale):
    lr_patches = []
    hr_patches = []
    for lr_img, hr_img in zip(lr_images, hr_images):
        # Asegúrate de que las imágenes estén en escala de grises y reescaladas adecuadamente
        for i in range(0, lr_img.shape[0] - patch_size + 1, patch_size):
            for j in range(0, lr_img.shape[1] - patch_size + 1, patch_size):
                # Extrae parches de la imagen de baja resolución
                lr_patch = lr_img[i:i + patch_size, j:j + patch_size]
                # Asegura que el parche de HR tenga el tamaño correcto, teniendo en cuenta el factor de escala
                hr_patch = hr_img[i*scale:(i+patch_size)*scale, j*scale:(j+patch_size)*scale]
                if lr_patch.shape == (patch_size, patch_size) and hr_patch.shape == (patch_size*scale, patch_size*scale):
                    lr_patches.append(lr_patch.flatten())
                    hr_patches.append(hr_patch.flatten())
    return np.array(lr_patches), np.array(hr_patches)


def load_images_and_features(patch_size, scale):
    lr_image_url = 'https://raw.githubusercontent.com/scikit-image/scikit-image/main/skimage/data/camera.png'
    hr_image_url = 'https://raw.githubusercontent.com/scikit-image/scikit-image/main/skimage/data/astronaut.png'
    lr_image = safely_convert_to_gray(io.imread(lr_image_url))
    hr_image = safely_convert_to_gray(io.imread(hr_image_url))
    lr_images = [lr_image]
    hr_images = [hr_image]
    features, labels = create_patches(lr_images, hr_images, patch_size, scale)
    return features, labels

patch_size = 8  # El tamaño de los parches extraídos de la imagen de baja resolución
scale = 2       # El factor de escala entre las imágenes de baja y alta resolución

features, labels = load_images_and_features(patch_size, scale)

# Entrenamiento del modelo
rf = RandomForestRegressor(n_estimators=10, random_state=42)
rf.fit(features, labels)

def upscale_image(lr_img, model, patch_size, scale):
    upscaled_img = resize(lr_img, (lr_img.shape[0] * scale, lr_img.shape[1] * scale), anti_aliasing=True)
    for i in range(0, upscaled_img.shape[0] - patch_size * scale + 1, scale):
        for j in range(0, upscaled_img.shape[1] - patch_size * scale + 1, scale):
            lr_patch = upscaled_img[i//scale:(i//scale)+patch_size, j//scale:(j//scale)+patch_size].flatten()
            hr_patch_predicted = model.predict([lr_patch])
            upscaled_img[i:i + patch_size * scale, j:j + patch_size * scale] = hr_patch_predicted.reshape(patch_size * scale, patch_size * scale)
    return upscaled_img

# Prueba de escalamiento de imagen
lr_test_image = features[0].reshape(patch_size, patch_size)
upscaled_image = upscale_image(lr_test_image, rf, patch_size, scale)


# tiempo
tiempo=time.process_time()
print("tiempo: ", round(tiempo,2))

# Visualización de la imagen original y la escalada
plt.title("Upscaled Image")
plt.imshow(upscaled_image)
plt.show()

如前所述,当前代码确实可以运行,但不能满足我的需求,这将是新调整大小的图像的完整图像,或者至少是图像的可识别片段,而不仅仅是补丁。所以我可以稍后将其与原始 HR 版本进行比较。

图片来自当前结果。

python image-processing interpolation random-forest
1个回答
0
投票

下面的代码是对该任务的尝试。选择图像作为训练集和验证集(在制作补丁之前)。图像被转换为补丁(为了训练而打乱),其中输入数据

X
是低分辨率补丁,目标
y
是原始补丁:

模型根据数据进行训练,并根据验证图像补丁进行评估:

请注意,上面重新缩放的图像

rescale()
应用于整个图像。这与应用于训练所用的各个补丁的
rescale()
类似(但不相同)。

from sklearn.ensemble import RandomForestRegressor
from sklearn.feature_extraction.image import extract_patches_2d, reconstruct_from_patches_2d
from matplotlib import pyplot as plt
from skimage.transform import downscale_local_mean, resize, rescale
from skimage.metrics import peak_signal_noise_ratio as compare_psnr
from skimage import io, color
import numpy as np
import time

def safely_convert_to_gray(image):
    if len(image.shape) == 3 and image.shape[2] == 3:  # Imagen RGB
        return color.rgb2gray(image)
    elif len(image.shape) == 2:  # Imagen en escala de grises
        return image
    else:
        raise ValueError("La imagen no es RGB ni escala de grises en formato reconocido")


def image_to_Xy(image, patch_size, scale):
    """
    Returns: (X, y) tuple
    where X: (n_patches, patch_size**2 / scale**2) array of low-res patches
          y: (n_patches, patch_size**2) array of hi-res patches
    """
    
    hires_patches = extract_patches_2d(image, [patch_size] * 2)    
    lowres_patches = np.array(
        [rescale(patch, 1 / scale, anti_aliasing=True) for patch in hires_patches]
    )
    
    hires_patches = hires_patches.reshape(-1, patch_size ** 2)
    lowres_patches = lowres_patches.reshape(-1, (patch_size // scale)**2)
    return lowres_patches.astype(np.float32), hires_patches.astype(np.float32)

def images_to_Xy(images, patch_size, scale, shuffle=True, random_state=None):
    """
    Returns (X, y)
    where X: (patches of all images, patch_size**2 / scale**2) array of low-res patches
    where y: (patches of all images, patch_size**2) array of hi-res patches
    """
    Xy_perimage = [image_to_Xy(image, patch_size, scale) for image in images]
    
    X_arr = np.concatenate([X for X, y in Xy_perimage], axis=0)
    y_arr = np.concatenate([y for X, y in Xy_perimage], axis=0)
    
    if shuffle:
        ixs = np.random.default_rng(random_state).permutation(range(len(X_arr)))
        X_arr, y_arr = [arr[ixs] for arr in (X_arr, y_arr)]
        
    return X_arr, y_arr

def patches_to_image(patches, image_size):
    """
    Reconstructs image from patches.
    Only for the hi-res original & hi-res prediction.
    Does not work for the downsampled images.
    """
    psize = int(patches.shape[-1]**0.5)
    patches_unflat = patches.reshape(-1, psize, psize)
    return reconstruct_from_patches_2d(patches_unflat, image_size)

#
# Load train and validation images
#
image_urls = [
    'https://raw.githubusercontent.com/scikit-image/scikit-image/main/skimage/data/camera.png',
    'https://raw.githubusercontent.com/scikit-image/scikit-image/main/skimage/data/coffee.png'
]
images = [safely_convert_to_gray(io.imread(url)) for url in image_urls]
images = [rescale(image, 0.25) for image in images] #let's work with smaller images due to RAM

train_images = images[0:1]
val_image = images[1]

patch_size = 16
scale = 4

trn_features, trn_labels = images_to_Xy(train_images, patch_size, scale, shuffle=True, random_state=0)
val_features, val_labels = image_to_Xy(val_image, patch_size, scale)

#
#View some samples
#
f, axs = plt.subplots(2, 10, figsize=(9, 2.3))
axs = axs.flatten()
for i in range(0, 20, 2):
    ax = axs[i]
    ax.imshow(trn_features[i].reshape(patch_size // scale, patch_size // scale), cmap='gray')
    ax.axis('off')
    ax.set_title(' ' * 20 + f'X[{i}], y[{i}]', fontsize=8)
    
    ax = axs[i + 1]
    ax.imshow(trn_labels[i].reshape(patch_size, patch_size), cmap='gray')
    ax.axis('off')
plt.show()

#
#Fit model on train data
# Can limit number of samples for RAM
#
trn_limit = None
rf = RandomForestRegressor(n_estimators=15, random_state=np.random.RandomState(0), n_jobs=-1)
rf.fit(trn_features[:trn_limit], trn_labels[:trn_limit])

#
#Assess on validation set
#
val_predictions = rf.predict(val_features).astype(np.float32)
pred_image = patches_to_image(val_predictions, val_image.shape)

f, axs = plt.subplots(1, 3, figsize=(9, 3))
ax = axs[0]
ax.imshow(val_image, cmap='gray')
ax.set_title('Original')

ax = axs[1]
ax.imshow(rescale(val_image, 1/scale, anti_aliasing=True), cmap='gray')
ax.set_title(f'Rescaled 1/{scale}')

ax = axs[2]
ax.imshow(pred_image, cmap='gray')
ax.set_title('Upsampled')

[ax.axis('off') for ax in axs]
© www.soinside.com 2019 - 2024. All rights reserved.