到目前为止,我在这里提供的代码运行顺利并且不需要很长时间,但只返回新图像的一个补丁,而我的实际目标是获取整个图像,就像我使用插值技术所做的那样,或者,如果这需要太多的精力和时间,至少要提供相关图像的可识别片段。尽管我正在关注一篇使用决策树的科学论文,但在计算和质量方面获得了良好的结果(使用 PSNR 和 SSIM 等指标)。


from sklearn.ensemble import RandomForestRegressor
from matplotlib import pyplot as plt
from skimage.transform import downscale_local_mean, resize
from skimage.metrics import peak_signal_noise_ratio as compare_psnr
from skimage import io, color
import numpy as np
import time

def safely_convert_to_gray(image):
    if len(image.shape) == 3 and image.shape[2] == 3:  # Imagen RGB
        return color.rgb2gray(image)
    elif len(image.shape) == 2:  # Imagen en escala de grises
        return image
        raise ValueError("La imagen no es RGB ni escala de grises en formato reconocido")

def create_patches(lr_images, hr_images, patch_size, scale):
    lr_patches = []
    hr_patches = []
    for lr_img, hr_img in zip(lr_images, hr_images):
        # Asegúrate de que las imágenes estén en escala de grises y reescaladas adecuadamente
        for i in range(0, lr_img.shape[0] - patch_size + 1, patch_size):
            for j in range(0, lr_img.shape[1] - patch_size + 1, patch_size):
                # Extrae parches de la imagen de baja resolución
                lr_patch = lr_img[i:i + patch_size, j:j + patch_size]
                # Asegura que el parche de HR tenga el tamaño correcto, teniendo en cuenta el factor de escala
                hr_patch = hr_img[i*scale:(i+patch_size)*scale, j*scale:(j+patch_size)*scale]
                if lr_patch.shape == (patch_size, patch_size) and hr_patch.shape == (patch_size*scale, patch_size*scale):
    return np.array(lr_patches), np.array(hr_patches)

def load_images_and_features(patch_size, scale):
    lr_image_url = 'https://raw.githubusercontent.com/scikit-image/scikit-image/main/skimage/data/camera.png'
    hr_image_url = 'https://raw.githubusercontent.com/scikit-image/scikit-image/main/skimage/data/astronaut.png'
    lr_image = safely_convert_to_gray(io.imread(lr_image_url))
    hr_image = safely_convert_to_gray(io.imread(hr_image_url))
    lr_images = [lr_image]
    hr_images = [hr_image]
    features, labels = create_patches(lr_images, hr_images, patch_size, scale)
    return features, labels

patch_size = 8  # El tamaño de los parches extraídos de la imagen de baja resolución
scale = 2       # El factor de escala entre las imágenes de baja y alta resolución

features, labels = load_images_and_features(patch_size, scale)

# Entrenamiento del modelo
rf = RandomForestRegressor(n_estimators=10, random_state=42)
rf.fit(features, labels)

def upscale_image(lr_img, model, patch_size, scale):
    upscaled_img = resize(lr_img, (lr_img.shape[0] * scale, lr_img.shape[1] * scale), anti_aliasing=True)
    for i in range(0, upscaled_img.shape[0] - patch_size * scale + 1, scale):
        for j in range(0, upscaled_img.shape[1] - patch_size * scale + 1, scale):
            lr_patch = upscaled_img[i//scale:(i//scale)+patch_size, j//scale:(j//scale)+patch_size].flatten()
            hr_patch_predicted = model.predict([lr_patch])
            upscaled_img[i:i + patch_size * scale, j:j + patch_size * scale] = hr_patch_predicted.reshape(patch_size * scale, patch_size * scale)
    return upscaled_img

# Prueba de escalamiento de imagen
lr_test_image = features[0].reshape(patch_size, patch_size)
upscaled_image = upscale_image(lr_test_image, rf, patch_size, scale)

# tiempo
print("tiempo: ", round(tiempo,2))

# Visualización de la imagen original y la escalada
plt.title("Upscaled Image")

如前所述,当前代码确实可以运行,但不能满足我的需求,这将是新调整大小的图像的完整图像,或者至少是图像的可识别片段,而不仅仅是补丁。所以我可以稍后将其与原始 HR 版本进行比较。


from sklearn.ensemble import RandomForestRegressor
from sklearn.feature_extraction.image import extract_patches_2d, reconstruct_from_patches_2d
from matplotlib import pyplot as plt
from skimage.transform import downscale_local_mean, resize, rescale
from skimage.metrics import peak_signal_noise_ratio as compare_psnr
from skimage import io, color
import numpy as np
import time

def safely_convert_to_gray(image):
    if len(image.shape) == 3 and image.shape[2] == 3:  # Imagen RGB
        return color.rgb2gray(image)
    elif len(image.shape) == 2:  # Imagen en escala de grises
        return image
        raise ValueError("La imagen no es RGB ni escala de grises en formato reconocido")

def image_to_Xy(image, patch_size, scale):
    Returns: (X, y) tuple
    where X: (n_patches, patch_size**2 / scale**2) array of low-res patches
          y: (n_patches, patch_size**2) array of hi-res patches
    hires_patches = extract_patches_2d(image, [patch_size] * 2)    
    lowres_patches = np.array(
        [rescale(patch, 1 / scale, anti_aliasing=True) for patch in hires_patches]
    hires_patches = hires_patches.reshape(-1, patch_size ** 2)
    lowres_patches = lowres_patches.reshape(-1, (patch_size // scale)**2)
    return lowres_patches.astype(np.float32), hires_patches.astype(np.float32)

def images_to_Xy(images, patch_size, scale, shuffle=True, random_state=None):
    Returns (X, y)
    where X: (patches of all images, patch_size**2 / scale**2) array of low-res patches
    where y: (patches of all images, patch_size**2) array of hi-res patches
    Xy_perimage = [image_to_Xy(image, patch_size, scale) for image in images]
    X_arr = np.concatenate([X for X, y in Xy_perimage], axis=0)
    y_arr = np.concatenate([y for X, y in Xy_perimage], axis=0)
    if shuffle:
        ixs = np.random.default_rng(random_state).permutation(range(len(X_arr)))
        X_arr, y_arr = [arr[ixs] for arr in (X_arr, y_arr)]
    return X_arr, y_arr

def patches_to_image(patches, image_size):
    Reconstructs image from patches.
    Only for the hi-res original & hi-res prediction.
    Does not work for the downsampled images.
    psize = int(patches.shape[-1]**0.5)
    patches_unflat = patches.reshape(-1, psize, psize)
    return reconstruct_from_patches_2d(patches_unflat, image_size)

# Load train and validation images
image_urls = [
images = [safely_convert_to_gray(io.imread(url)) for url in image_urls]
images = [rescale(image, 0.25) for image in images] #let's work with smaller images due to RAM

train_images = images[0:1]
val_image = images[1]

patch_size = 16
scale = 4

trn_features, trn_labels = images_to_Xy(train_images, patch_size, scale, shuffle=True, random_state=0)
val_features, val_labels = image_to_Xy(val_image, patch_size, scale)

#View some samples
f, axs = plt.subplots(2, 10, figsize=(9, 2.3))
axs = axs.flatten()
for i in range(0, 20, 2):
    ax = axs[i]
    ax.imshow(trn_features[i].reshape(patch_size // scale, patch_size // scale), cmap='gray')
    ax.set_title(' ' * 20 + f'X[{i}], y[{i}]', fontsize=8)
    ax = axs[i + 1]
    ax.imshow(trn_labels[i].reshape(patch_size, patch_size), cmap='gray')

#Fit model on train data
# Can limit number of samples for RAM
trn_limit = None
rf = RandomForestRegressor(n_estimators=15, random_state=np.random.RandomState(0), n_jobs=-1)
rf.fit(trn_features[:trn_limit], trn_labels[:trn_limit])

#Assess on validation set
val_predictions = rf.predict(val_features).astype(np.float32)
pred_image = patches_to_image(val_predictions, val_image.shape)

f, axs = plt.subplots(1, 3, figsize=(9, 3))
ax = axs[0]
ax.imshow(val_image, cmap='gray')

ax = axs[1]
ax.imshow(rescale(val_image, 1/scale, anti_aliasing=True), cmap='gray')
ax.set_title(f'Rescaled 1/{scale}')

ax = axs[2]
ax.imshow(pred_image, cmap='gray')

[ax.axis('off') for ax in axs]
