我正在尝试创建一个自定义变换来调整 X 射线的旋转以训练神经网络,如此处所示 - 旋转 X 射线的链接。该函数适用于测试图像,无论是 PNG 还是 JPG,但是,当我将其添加到神经网络处理中时,从 PNG 到 np.array 的转换无法正常工作。它最终生成一个全黑图像,然后(显然)不能用于边缘检测。单步执行代码时,它确实将 PIL.Image 转换为 np.array,但出现了问题。这是代码:
调整图像方法:
class AdjustImage(object):
THRESHOLD = 240
def subimage(self, image, center, theta, width, height):
if 45 < theta <= 90:
theta = theta - 90
width, height = height, width
theta *= math.pi / 180 # convert to rad
v_x = (math.cos(theta), math.sin(theta))
v_y = (-math.sin(theta), math.cos(theta))
s_x = center[0] - v_x[0] * (width / 2) - v_y[0] * (height / 2)
s_y = center[1] - v_x[1] * (width / 2) - v_y[1] * (height / 2)
mapping = np.array([[v_x[0],v_y[0], s_x], [v_x[1],v_y[1], s_y]])
return cv2.warpAffine(image, mapping, (width, height), flags=cv2.WARP_INVERSE_MAP, borderMode=cv2.BORDER_REPLICATE)
def __call__(self, image_source):
# First slightly crop edge - some images had a rogue 2 pixel black edge on one side
#https://github.com/python-pillow/Pillow/issues/6765 - pil issue with pngs
type_var = image_source.format
image_source.show(title="Initial Image")
if isinstance(image_source, np.ndarray):
print("working")
elif image_source.format == 'JPEG':
image_source = np.array(image_source)
image_source = image_source[:,:,::-1].copy()
else:
image_source = image_source.point(lambda x: x / 256)
image_source = image_source.convert('RGB')
image_source.show(title="After Conversion")
image_source = np.array(image_source)
image_source = image_source[:,:,::-1].copy()
#This method was designed and tested on using cv2.imread - these steps change a PIL PNG read into that.
#Otherwise, it will destroy the image
init_crop = 5
h, w = image_source.shape[:2]
image_source = image_source[init_crop:init_crop+(h-init_crop*2), init_crop:init_crop+(w-init_crop*2)]
# Add back a white border
image_source = cv2.copyMakeBorder(image_source, 5,5,5,5, cv2.BORDER_CONSTANT, value=(255,255,255))
image_gray = cv2.cvtColor(image_source, cv2.COLOR_BGR2GRAY)
_, image_thresh = cv2.threshold(image_gray, self.THRESHOLD, 255, cv2.THRESH_TOZERO_INV)
image_thresh2 = image_thresh.copy()
image_thresh2 = cv2.Canny(image_thresh2, 100, 100, apertureSize=3)
points = cv2.findNonZero(image_thresh2)
centre, dimensions, theta = cv2.minAreaRect(points)
rect = cv2.minAreaRect(points)
width = int(dimensions[0])
height = int(dimensions[1])
box = cv2.boxPoints(rect)
box = np.int0(box)
temp = image_source.copy()
cv2.drawContours(temp, [box], 0, (255,0,0), 2)
M = cv2.moments(box)
cx = int(M['m10']/M['m00'])
cy = int(M['m01']/M['m00'])
image_patch = self.subimage(image_source, (cx, cy), (theta+90), height, width)
# add back a small border
image_patch = cv2.copyMakeBorder(image_patch, 1,1,1,1, cv2.BORDER_CONSTANT, value=(255,255,255))
# Convert image to binary, edge is black. Do edge detection and convert edges to a list of points.
# Then calculate a minimum set of points that can enclose the points.
_, image_thresh = cv2.threshold(image_patch, self.THRESHOLD, 255, 1)
image_thresh = cv2.Canny(image_thresh, 100, 100, 3)
points = cv2.findNonZero(image_thresh)
hull = cv2.convexHull(points)
# Find min epsilon resulting in exactly 4 points, typically between 7 and 21
# This is the smallest set of 4 points to enclose the image.
for epsilon in range(3, 50):
hull_simple = cv2.approxPolyDP(hull, epsilon, 1)
if len(hull_simple) == 4:
break
hull = hull_simple
# Find closest fitting image size and warp/crop to fit
# (ie reduce scaling to a minimum)
x,y,w,h = cv2.boundingRect(hull)
target_corners = np.array([[0,0],[w,0],[w,h],[0,h]], np.float32)
# Sort hull into tl,tr,br,bl order.
# n.b. hull is already sorted in clockwise order, we just need to know where top left is.
source_corners = hull.reshape(-1,2).astype('float32')
min_dist = 100000
index = 0
for n in range(len(source_corners)):
x,y = source_corners[n]
dist = math.hypot(x,y)
if dist < min_dist:
index = n
min_dist = dist
# Rotate the array so tl is first
source_corners = np.roll(source_corners , -(2*index))
try:
transform = cv2.getPerspectiveTransform(source_corners, target_corners)
return cv2.warpPerspective(image_patch, transform, (w,h))
except:
print ("Warp failure", image_source)
return image_patch
def __repr__(self):
return self.__class__.__name__+'()'
设置神经网络:
#Constructing the ResNeXt model:
transforms = v2.Compose([AdjustImage(),v2.Resize([256,256]), v2.PILToTensor()])
train_ds = datasets.ImageFolder("C:\\Users\\jenni\\Desktop\\Diss_Work\\X-ray_Images\\train", transform = transforms)
test_ds = datasets.ImageFolder("C:\\Users\\jenni\\Desktop\\Diss_Work\\X-ray_Images\\test", transform = transforms)
val_ds = datasets.ImageFolder("C:\\Users\\jenni\\Desktop\\Diss_Work\\X-ray_Images\\val", transform = transforms)
batch_size = 16
train_dataloader = DataLoader(train_ds, batch_size=batch_size)
test_dataloader = DataLoader(test_ds, batch_size=batch_size)
val_dataloader = DataLoader(val_ds, batch_size = batch_size)
model = torch.hub.load('pytorch/vision:v0.10.0', 'resnext50_32x4d', pretrained=True)
print(model)
model.type(torch.LongTensor)
model.to(device)
loss_fn = nn.CrossEntropyLoss()
optimiser = torch.optim.Adam(model.parameters(), lr = 1e-3)
火车方法:
def train(dataloader, model, loss_fn, optimizer):
size = len(dataloader.dataset)
model.train()
for batch, (X, y) in enumerate(dataloader):
X, y = X.type(torch.FloatTensor), y.type(torch.FloatTensor)
X,y = X.to(device), y.to(device)
# Compute prediction error
pred = model(X)
loss = loss_fn(pred, y.type(torch.LongTensor).to(device))
# Backpropagation
loss.backward()
optimizer.step()
optimizer.zero_grad()
if batch % 100 == 0:
loss, current = loss.item(), (batch + 1) * len(X)
print(f"loss: {loss:>7f} [{current:>5d}/{size:>5d}]")
运行神经网络:
epochs = 5
model.to(device)
for t in range(epochs):
print(f"Epoch {t+1}\n-------------------------------")
train(train_dataloader, model, loss_fn, optimiser)
test(test_dataloader, model, loss_fn)
print("Done!")
当它试图寻找空无一物的外壳时,它会出错。我正在使用 pytorch 的 resnext-50,并且没有对其进行任何调整。
我认为下面这行对你有帮助:
image_source = image_source.point(lambda x: x / 256)
我假设
image_source
是一个 PIL Image
,是从 PNG 加载的。如果是这样,那么至少存在一个问题。如果它的像素范围为 0..255,并且除以 256 并将结果存储在 8 位无符号字符中,则每个像素将得到 0 或 1 的结果。我猜你实际上想要一个 0..1 范围内的浮点数,所以你可能想转换为 Numpy 数组,然后浮点数,然后沿着这些线划分:
myFloatArray = np.asarray(image_source).astype(np.float)/256.0
另一个潜在的问题是 PNG 图像可能是 16 位的,这会导致更令人头痛的问题,因为首先范围将是 0..65535,所以你的缩放会是错误的,其次因为 PIL 不会读取 16 位 PNG如果它们是颜色的,请注意这一点。