Improving the performance of code for real-time data acquisition and understanding a strange bug

Problem description

I am working with a physics experiment whose evolution I track in real time with a camera. Specifically, I trigger the camera to retrieve images at a rate of 2.5 Hz, so the efficiency of the system is critical for keeping up with the progress of the experiment.

My current program takes the first image in a dedicated folder, path, and asks the user to select a region of interest on which to perform the desired operations. It then computes the per-pixel mean brightness of the acquired image and of each batch of n images, img_round, and the mean brightness of each batch is plotted in real time against the img_round iteration number.

At the moment the program works well when I run it on static data. However, when I try to run it in the intended experimental setting, where images are actively being added to the folder, I get wrong brightness values in the plot.

More generally, I am worried that my code is not efficient, and I would like to optimize it as much as possible.

Please find the code below:

import os
import cv2
import numpy as np
import pyqtgraph as pg
from scipy.optimize import minimize_scalar
from pyqtgraph.Qt import QtCore, QtGui, QtWidgets
import time 
 
center = (0, 0)
radius = (0)
is_dragging_center = False
is_dragging_radius = False
global avg_brightness_per_img_round
avg_brightness_per_img_round = 0
 
img_round = 5
run_count = 0
 
brightness_history = []
std_history = []
func_history = []
global scatter_item
scatter_item = None
 
def update_display_image():
    global resized_image
    if resized_image is not None:
        display_image = resized_image.copy()
        cv2.circle(display_image, center, radius, (0, 255, 0), 2)
        cv2.circle(display_image, center, 5, (0, 0, 255), thickness=cv2.FILLED)
 
class UpdateDisplaySignal(QtCore.QObject):
    update_display_signal = QtCore.pyqtSignal()
 
update_display_signal_obj = UpdateDisplaySignal()
update_display_signal_obj.update_display_signal.connect(update_display_image)
 
 
def on_mouse(event, x, y, flags, param):
    global center, radius, is_dragging_center, is_dragging_radius
 
    if event == cv2.EVENT_LBUTTONDOWN:
        if np.sqrt((x - center[0]) ** 2 + (y - center[1]) ** 2) < 20:
            is_dragging_center = True
        else:
            is_dragging_radius = True
 
    elif event == cv2.EVENT_LBUTTONUP:
        is_dragging_center = False
        is_dragging_radius = False
 
    elif event == cv2.EVENT_MOUSEMOVE:
        if is_dragging_center:
            center = (x, y)
        elif is_dragging_radius:
            radius = int(np.sqrt((x - center[0]) ** 2 + (y - center[1]) ** 2))
 
app = QtWidgets.QApplication([])
pw = pg.PlotWidget(title='Mean Brightness vs image round')
pw.setLabel('left', 'Mean Brightness')
pw.setLabel('bottom', 'Image round')
scatter = pg.ScatterPlotItem(size=10, pen=pg.mkPen(None), brush=pg.mkBrush(255, 0, 0, 120))
line = pg.PlotDataItem(pen=pg.mkPen(color=(0,0,255), width=2))
pw.addItem(line)
pw.addItem(scatter)
 
def update_scatter():
    global scatter_item
    indices, values = zip(*enumerate(brightness_history, start=1))
    x = list(indices)
    y = list(values)
 
    if scatter_item is None:
        scatter_item = pg.ScatterPlotItem(size=10, pen=pg.mkPen(None), brush=pg.mkBrush(255, 0, 0, 120))
        pw.addItem(scatter_item)
 
    if isinstance(x, int):  
        x = [x]  
 
    if len(x) > 1:
        line.setData(x=x, y=y)
        scatter.setData(x=x, y=y, symbol='o', size=10, pen=pg.mkPen(None), brush=pg.mkBrush(255, 0, 0, 120))
 
        for i, (xi, yi) in enumerate(zip(x, y)):
            label = pg.TextItem(text=f'{yi:.2f}', anchor=(0, 0))
            label.setPos(xi, yi)
            pw.addItem(label)
            
win = QtWidgets.QMainWindow()
win.setCentralWidget(pw)
win.show()
 
path = r'C:\Users\blehe\Desktop\Betatron\images'
 
def calc_xray_count(image_path, center, radius):
    original_image = cv2.imread(image_path, cv2.IMREAD_ANYDEPTH)
 
    median_filtered_image = cv2.medianBlur(original_image, 5)
 
    mask = np.zeros(original_image.shape, dtype=np.uint8)
    cv2.circle(mask, center, radius, 255, thickness=cv2.FILLED)
 
    median_filtered_image += 1  # Avoid not counting black pixels in image
    result = cv2.bitwise_and(median_filtered_image, median_filtered_image, mask=mask)
 
    pixel_count = np.count_nonzero(result)
 
    img_brightness_sum = np.sum(result)
    img_var = np.var(result)
 
    if (pixel_count > 0):
        img_avg_brightness = (img_brightness_sum/pixel_count) -1 # Subtract back to real data
    else:
        img_avg_brightness = 0
 
    return img_avg_brightness, img_var
 
#-----------------------------------------------------------------------

image_files = []
for filename in os.listdir(path):
    if filename.endswith('.TIF'):
        image_files.append(os.path.join(path, filename))
 
first_image_path = image_files[0]
image = cv2.imread(first_image_path)
 
scale_percent = 60 
width = int(image.shape[1] * scale_percent / 100)
height = int(image.shape[0] * scale_percent / 100)
dim = (width, height)
gray_img = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
colored_image = cv2.applyColorMap(gray_img, cv2.COLORMAP_PINK)
resized_image = cv2.resize(colored_image, dim, interpolation=cv2.INTER_AREA)
 
center = (resized_image.shape[1] // 2, resized_image.shape[0] // 2)
radius = min(resized_image.shape[1] // 3, resized_image.shape[0] // 3)
 
cv2.namedWindow("Adjust the circle (press 'Enter' to proceed)")
cv2.setMouseCallback("Adjust the circle (press 'Enter' to proceed)", on_mouse)
 
while True:
    display_image = resized_image.copy()
 
    cv2.circle(display_image, center, radius, (0, 255, 0), 2)
    cv2.circle(display_image, center, 5, (0, 0, 255), thickness=cv2.FILLED)
    cv2.imshow("Adjust the circle (press 'Enter' to proceed)", display_image)
 
    key = cv2.waitKey(1) & 0xFF
    if key == 13: 
        break
 
cv2.destroyAllWindows()
            
center = (int(center[0] / scale_percent * 100), int(center[1] / scale_percent * 100))
radius = int(radius / scale_percent * 100)
 
img_round_brightness_sum = 0
img_round_var_sum = 0
 
def process_images():
    global run_count, img_round_brightness_sum, img_round_var_sum
 
    while run_count < len(os.listdir(path)):   
        for i, image_path in enumerate(image_files, start=1):
            img_avg_brightness, img_var = calc_xray_count(image_path, center, radius)
            img_round_brightness_sum += img_avg_brightness
            img_round_var_sum += img_var
 
            run_count += 1
 
            if run_count % img_round == 0:
                avg_brightness_per_img_round = (img_round_brightness_sum/img_round)
                deviation_per_img_round = np.sqrt(img_round_var_sum/img_round)
 
                brightness_history.append(avg_brightness_per_img_round)
                std_history.append(deviation_per_img_round)
 
                update_scatter()

                img_round_brightness_sum = 0
                img_round_var_sum = 0
 
                img_avg_brightness = 0
                img_var = 0
 
                QtCore.QCoreApplication.processEvents()
                QtCore.QThread.msleep(100)
 
if __name__ == "__main__":
    timer = QtCore.QTimer() 
    timer.timeout.connect(process_images)
    timer.start(100)  
    app.exec_()

I would really appreciate any help. Thank you.

python performance image-processing optimization pyqtgraph
1 Answer

It is hard to tell what you are trying to achieve from the way the code is written, so please clean it up. Here is an example and a few suggestions for improvement, but I quickly gave up on refactoring the whole thing. There is a lot of cleanup and refactoring you could do on this code, for example breaking the functions into smaller units and then chaining those functions, and using a class to hold the processor's state instead of relying on global variables everywhere.

import numpy as np
import matplotlib.pyplot as plt
import cv2

class Processor:
    def __init__(self,img):
        self.img = img # Self.img can act as your image buffer
        self.center = (75,75) # I hard coded values but you can calculate and initialize center here
        self.radius = 20
        self.is_dragging_center = False # Initialize all of your (former) global variables here
        self.is_dragging_radius = False
        self.is_dragging = False
        self.dragging_center = False
        
    def draw_circles(self, img): #Either explicitly pass the image or use self.img
        cv2.circle(img, self.center, self.radius, (0, 255, 0), 2)
        cv2.circle(img, self.center, 5, (0, 0, 255), thickness=cv2.FILLED)
    
    def resize_img(self,img, scale_percent=60):
        width = img.shape[1] * scale_percent // 100
        height = img.shape[0] * scale_percent // 100
        return cv2.resize(img, (width, height), interpolation=cv2.INTER_AREA)
    
    def transform_color(self, img, cmap=cv2.COLORMAP_PINK):
        return cv2.applyColorMap(cv2.cvtColor(img, cv2.COLOR_BGR2GRAY), cmap)
    
    def on_mouse(self, event, x, y, flags, param):
        cx,cy = self.center
        if event == cv2.EVENT_LBUTTONDOWN:
            # <20 should probably not be hard coded?
            # if you want performance do you actually care about dragging in a perfectly circular radius?
            # For example using "if abs(x-cx) + abs(y-cy) < 20" creates a diamond pattern but it's 3 to 4x faster
            # you can also simplify this to ((x - cx) ** 2 + (y - cy) ** 2) < 400
            # "is_dragging_radius" looks to be : "not is_dragging_center", do you really need two variables here?          
#             self.is_dragging = True
#             self.dragging_center = ((x - cx) ** 2 + (y - cy) ** 2) < 400
            if np.sqrt((x - cx) ** 2 + (y - cy) ** 2) < 20:
                self.is_dragging_center = True
            else:
                self.is_dragging_radius = True
        # All of the "elif"s can be simplified to just "if"s
        elif event == cv2.EVENT_LBUTTONUP:
#             self.is_dragging = False
            self.is_dragging_center = False
            self.is_dragging_radius = False

#         if event == cv2.EVENT_MOUSEMOVE and self.is_dragging:
#             if dragging_center:
#                 center = (x, y)
#             else:
#                 radius = int(np.sqrt((x - center[0]) ** 2 + (y - center[1]) ** 2))
        elif event == cv2.EVENT_MOUSEMOVE:
            if self.is_dragging_center:
                self.center = (x, y)
            elif self.is_dragging_radius:
                self.radius = int(np.sqrt((x - self.center[0]) ** 2 + (y - self.center[1]) ** 2))
                
p = Processor(np.zeros((150,150,3), dtype=np.uint8))
p.draw_circles(p.img)
plt.figure()
plt.imshow(p.img)
plt.show()
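
Following the same idea, the calc_xray_count function from your code could live on the class too, so the circle state is read from self instead of from globals. This is just a sketch of that move, placed on a subclass of the Processor above (the name XRayProcessor is arbitrary; the per-pixel logic is your own, unchanged):

class XRayProcessor(Processor):
    def calc_xray_count(self, image_path):
        # Read the image at its native bit depth and median-filter it
        original_image = cv2.imread(image_path, cv2.IMREAD_ANYDEPTH)
        median_filtered_image = cv2.medianBlur(original_image, 5)

        # Build a filled circular mask from the state held on the class
        mask = np.zeros(original_image.shape, dtype=np.uint8)
        cv2.circle(mask, self.center, self.radius, 255, thickness=cv2.FILLED)

        # Shift by +1 so genuinely black pixels inside the circle still count as pixels
        median_filtered_image += 1
        result = cv2.bitwise_and(median_filtered_image, median_filtered_image, mask=mask)

        pixel_count = np.count_nonzero(result)
        img_var = np.var(result)
        if pixel_count > 0:
            img_avg_brightness = np.sum(result) / pixel_count - 1  # subtract the shift back out
        else:
            img_avg_brightness = 0
        return img_avg_brightness, img_var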

Regarding the performance question and how to process images while they are being saved to disk, you can do something like the following:

from concurrent.futures import ThreadPoolExecutor as TPE, wait
# below imports are only used in the second example
from collections import deque
from threading import Thread, Event
# below imports are just used for this example, you can ignore them
import time
from random import random, choice
from string import ascii_letters, digits
from threading import Lock

def generator(max_delay): # Acts as the data generator, i.e: a camera in your case
    alphabet = list(ascii_letters)+list(digits)
    while True:
        yield ''.join(choice(alphabet) for i in range(20))
        time.sleep(random()*max_delay)
        
        
def save(x, lock):
    time.sleep(random()*0.5)
    with lock:
        print(f'saved {x}')

def process(x, lock):
    time.sleep(random()*0.8)
    with lock:
        print(f'Processed {x} to {x[:10].swapcase()}')
        
def producer(generator, queue, event, lock):
    for x in generator:
        if event.is_set():
            break
        queue.append(x)
    with lock:
        print(f'{"*"*20} Done producing. Shutting down. {"*"*20}')


lock = Lock() # Lock is used to make sure multiple threads don't print at the same time
############################## 
##############################
###     FIRST EXAMPLE      ###
##############################
##############################
# In this example we don't mind if we miss captured frames since we don't have a large enough memory to store
# all of the incoming data and our processing might be too slow to process in real time
g = generator(0.3)
with TPE() as executor:
    start = time.time()
    while (time.time()-start)<5: # Run experiment for 5 seconds
        x = next(g) # Get the next input from the "camera"
        # Save the image and process it in parallel (you may want to pass a copy to be processed)
        futures = [executor.submit(save,x,lock), executor.submit(process,x,lock)]
        # Wait for saving and processing to finish before processing the next frame
        wait(futures)
        print('#'*50)
    print('Done processing and saving. Shutting down.')


print('_'*100)
print('_'*100)
print('_'*100)
print('_'*100)
############################## 
##############################
###     SECOND EXAMPLE     ###
##############################
##############################
# This second example deals with the case where your input rate is higher than your processing throughput.
# If you can process everything in real time OR if you don't mind missing frames, then you can ignore this part.
# This is only advised if you have enough RAM available to hold all the captured data.


q = deque()
shutdown_event = Event()
t = Thread(target=producer, args=(generator(0.1), q, shutdown_event, lock))
t.start()

with TPE() as executor:
    start = time.time()
    while (time.time()-start)<5: # Run experiment for 5 seconds
        # If the queue is empty, wait for it to be populated (in case the input rate is lower than the processing throughput)
        while len(q)==0:
            pass
        x = q.popleft() # Get the next "frame" from the queue
        # Save the image and process it in parallel (you may want to pass a copy to be processed)
        futures = [executor.submit(save,x,lock), executor.submit(process,x,lock)]
        # Wait for saving and processing to finish before processing the next frame
        wait(futures)
        print('#'*50)
    # now that your experiment is done, shutdown the producer
    shutdown_event.set()
    # Keep processing until the queue is empty
    while len(q)!=0:
        x = q.popleft() # Get the next input from the "camera"
        # Save the image and process it in parallel (you may want to pass a copy to be processed)
        futures = [executor.submit(save,x,lock), executor.submit(process,x,lock)]
        # Wait for saving and processing to finish before processing the next frame
        wait(futures)
        print('#'*50)
    print('Done processing and saving. Shutting down.')
t.join()

Note that the second example keeps processing even after the experiment has ended and you have stopped capturing frames. It also outputs more data, because no input is dropped (and max_delay is 0.1 instead of 0.3 seconds).
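
In your setup the producer would not be a random-string generator but something that watches the image folder. Below is a minimal sketch of that, assuming the camera software drops finished .TIF files into the folder and that a file is safe to read as soon as it appears (in practice you may want to wait until a file has stopped growing before reading it); folder_producer and poll_interval are names made up for this example:

import os
import time
from collections import deque
from threading import Thread, Event

def folder_producer(path, queue, event, poll_interval=0.1):
    # Poll the folder and enqueue the paths of .TIF files we have not seen before
    seen = set()
    while not event.is_set():
        for filename in sorted(os.listdir(path)):
            if filename.endswith('.TIF') and filename not in seen:
                seen.add(filename)
                queue.append(os.path.join(path, filename))
        time.sleep(poll_interval)

q = deque()
shutdown_event = Event()
t = Thread(target=folder_producer,
           args=(r'C:\Users\blehe\Desktop\Betatron\images', q, shutdown_event))
t.start()
# The consumer loop from the second example then popleft()s image paths from q and
# hands each one to calc_xray_count (or a Processor method) via the executor, exactly
# as the random strings are handled above.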
