具有Keras模型的PyQt5 GUI在单独的线程中运行,再次按下“运行”按钮时卡住

问题描述 投票:0回答:1

我有一个应用程序接收.pickle文件与标记的火车数据,并应建立一个神经网络(与Keras)。它应该对数据进行训练,并使用matplotlib在画布上显示训练/验证错误,并显示QprogressBar的进度。

我有一个自定义回调函数,它将pyqtSignal发送到每个纪元结束的主GUI,发送当前纪元以及累积列车和验证错误。然后在主程序中有一个函数接收信号并触发更新方法。

一切正常,直到我按下GUI窗口 - 然后应用程序卡住(但网络仍在shell中运行)。我猜点击中断触发一些循环,使整个程序冻结,但我可以“弄清楚哪个。

我已经搜索了有关使用线程时PyQt5 GUI被卡住的其他问题,但没有找到答案 - here

我已经尝试使用Qthread.start()而不是Qthread.run()用于Qthread - 但在这种情况下,情节根本不会更新。

我已经编写了一个完整的示例来演示问题(数据文件应该是.pickle格式并包含X的列表[X,y] - 样本作为numpy ndarray,y - 相应的标签为numpy ndarray和can在https://www.kaggle.com/luciferadmin/heart-disease-uci-in-pickle-format找到):

import sys
import os
import pickle as pkl
from keras.layers import Input, Dense
from keras.callbacks import Callback
from keras.models import Model
from PyQt5.QtWidgets import QApplication, QProgressBar, QWidget, QVBoxLayout, QPushButton, QLineEdit, QFileDialog
from PyQt5.QtCore import QThread, pyqtSignal
import matplotlib
from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas
from matplotlib.figure import Figure
from PyQt5.QtWidgets import (QSizePolicy)
matplotlib.use('Qt5Agg')


class Plot(FigureCanvas):
    def __init__(self, x_label, y_label, parent=None, width=5, height=4, dpi=100):
        fig = Figure(figsize=(width, height), dpi=dpi)

        self.axes = fig.add_subplot(111)

        self.compute_initial_figure()
        self.axes.set_xlabel(x_label)
        self.axes.set_ylabel(y_label)

        FigureCanvas.__init__(self, fig)
        self.setParent(parent)

        FigureCanvas.setSizePolicy(self, QSizePolicy.Expanding, QSizePolicy.Expanding)

        FigureCanvas.updateGeometry(self)

    def compute_initial_figure(self):
        self.axes.set_xticks(range(1, 100, 10))


class MultiPlot(Plot):
    def __init__(self, parent=None, x_axis_name='X', y_axis_name='Y', width=5, height=4, dpi=100):
        super().__init__(x_axis_name, y_axis_name, parent, width, height, dpi)
        self.compute_initial_figure()

    def compute_initial_figure(self):
        self.axes.set_xticks(range(0, 100, 10))

    def plot_multi_data(self, x_axis_name='X', y_axis_name='Y', plot_labels=None, y_list=None):
        if y_list is not None:
            self.axes.clear()

            graph_handles = []

            markers = ['b:', 'r']
            y_index = 0

            for y in y_list:

                x = range(1, len(y) + 1)
                label = plot_labels[y_index]

                new_plot, = self.axes.plot(x, y, markers[y_index], markersize=2, label=label)
                graph_handles.append(new_plot)
                y_index += 1

                self.axes.set_xticks(x, int(len(list(x))/10))
            self.axes.legend(handles=graph_handles, loc=0, fontsize=8, shadow=True)

        self.axes.set_xlabel(x_axis_name)
        self.axes.set_ylabel(y_axis_name)

        self.draw()


class TrainPlotCallback(Callback):
    def __init__(self, signal):
        Callback.__init__(self)
        self.train_err = []
        self.val_err = []
        self.signal = signal

    def on_epoch_end(self, epoch, logs={}):
        self.train_err.append(1 - logs.get('acc'))
        self.val_err.append(1 - logs.get('val_acc'))

        self.signal.emit(epoch, [self.train_err, self.val_err])


def classification_model(data_input_path, on_epoch_end_signal):

    # ///////////////////// TEST /////////////////////
    if os.path.exists(data_input_path):
        plot_losses = TrainPlotCallback(on_epoch_end_signal)
        with open(data_input_path, 'rb') as pickle_in:
            data = pkl.load(pickle_in)
            X = data[0]
            y = data[1]

        input_size = X.shape[1]

        # MODEL CREATION
        # ///////////////////// INPUT LAYER /////////////////////
        inputs = Input(shape=(input_size,))
        # ///////////////////// INPUT LAYER /////////////////////
        # ///////////////////// HIDDEN LAYER /////////////////////
        x = Dense(10, activation='relu', kernel_initializer='normal')(inputs)   # THE FIRST LAYER
        # ///////////////////// HIDDEN LAYER /////////////////////
        # ///////////////////// OUTPUT LAYERS /////////////////////
        predictions = Dense(len(y[0]), activation='softmax')(x)  # the length of the output layer is as the length of the classes being predicted.
        # ///////////////////// OUTPUT LAYERS /////////////////////
        # MODEL CREATION

        # ///////////////////// MODEL DEFINITION /////////////////////
        model = Model(inputs=inputs, outputs=predictions)
        model.compile(optimizer='Adam',
                      loss='categorical_crossentropy',
                      metrics=['acc'])
        # ///////////////////// MODEL DEFINITION /////////////////////

        # ///////////////////// MODEL TRAINING /////////////////////
        model.fit(X, y, validation_split=0.2, batch_size=100, epochs=100, callbacks=[plot_losses])
        # ///////////////////// MODEL TRAINING /////////////////////


class ModelThread(QThread):
    epoch_end_signal = pyqtSignal(int, list)  # signal that has epoch # as the first parameter, and a list that contains the error values for the train and validation.

    def __init__(self, data_input_path):
        QThread.__init__(self)
        self.data_input_path = data_input_path

    def __del__(self):
        self.wait()

    def run(self):
        classification_model(data_input_path=self.data_input_path,
                             on_epoch_end_signal=self.epoch_end_signal
                             )


class DashBoard(QWidget):
    def __init__(self):
        super().__init__()
        self.main_v_box = QVBoxLayout(self)

        #  <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<< STRINGS >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
        self.input_data_path_str = ''
        #  <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<< STRINGS >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
        #  <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<< PROGRESS BAR >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
        self.progress_bar = QProgressBar()
        #  <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<< PROGRESS BAR >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
        #  <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<< BUTTONS >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
        self.run_model_btn = QPushButton('Run')
        self.browse_train_data_file_path_btn = QPushButton('Browse')
        #  <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<< BUTTONS >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
        #  <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<< MULTI PLOTS >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
        self.in_training_plot = MultiPlot(x_axis_name='Epoch Number', y_axis_name='Error')
        #  <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<< MULTI PLOTS >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
        #  <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<< LINE EDITS >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
        self.train_data_file_path_le = QLineEdit()
        #  <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<< LINE EDITS >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
        self.init()
        self.pack()
        self.showMaximized()

    def init(self):
        self.progress_bar.hide()
        self.browse_train_data_file_path_btn.clicked.connect(self.on_btn_click)
        self.run_model_btn.clicked.connect(self.on_btn_click)

    def pack(self):
        self.main_v_box.addWidget(self.train_data_file_path_le)
        self.main_v_box.addWidget(self.browse_train_data_file_path_btn)
        self.main_v_box.addWidget(self.in_training_plot)
        self.main_v_box.addWidget(self.run_model_btn)
        self.main_v_box.addWidget(self.progress_bar)

    def on_btn_click(self):
        btn_index = self.sender()

        if btn_index == self.browse_train_data_file_path_btn:
            self.input_data_path_str = QFileDialog.getOpenFileName(self, '.pickle files', os.getenv('HOME'), '*.pickle')[0]
            self.train_data_file_path_le.setText(self.input_data_path_str)
        elif btn_index == self.run_model_btn:
            model_thread = ModelThread(data_input_path=self.input_data_path_str)
            model_thread.epoch_end_signal.connect(self.update_ui_on_epoch_end)
            self.progress_bar.show()
            model_thread.run()
            self.progress_bar.hide()

    def update_ui_on_epoch_end(self, current_epoch_num, error_lists):
        if current_epoch_num < 100:
            self.progress_bar.setValue(current_epoch_num)
        else:
            self.progress_bar.setValue(100)
        self.in_training_plot.plot_multi_data(x_axis_name='Epoch', y_axis_name='Error', plot_labels=['Train Accuracy', 'Validation Accuracy'], y_list=[error_lists[0], error_lists[1]])

    def run_model(self):
        if os.path.exists(self.train_data_file_path_str) and os.path.exists(self.output_data_path_str):
            train_thread = ModelThread(data_input_path='')
            train_thread.epoch_end_signal.connect(self.update_ui_on_epoch_end)
            # train_thread.start()
            self.progress_bar.show()
            train_thread.run()
            self.progress_bar.hide()


if __name__ == '__main__':
    app = QApplication(sys.argv)
    main_menu = DashBoard()
    sys.exit(app.exec_())
matplotlib keras pyqt5 qthread
1个回答
1
投票

您的初始错误是您不应该直接调用run而是启动,但您的线程是一个局部变量,稍后将被删除。

更好的解决方案是创建一个存在于另一个线程中的QObject,并使用QTimer.singleShot调用该函数,而不是创建自定义QThread。

import os
import sys
from functools import partial
import pickle as pkl
from keras.layers import Input, Dense
from keras.callbacks import Callback
from keras.models import Model

from PyQt5 import QtCore, QtWidgets

import matplotlib
matplotlib.use('Qt5Agg')

from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas
from matplotlib.figure import Figure

class Plot(FigureCanvas):
    def __init__(self, x_label, y_label, parent=None, width=5, height=4, dpi=100):
        fig = Figure(figsize=(width, height), dpi=dpi)

        self.axes = fig.add_subplot(111)

        self.compute_initial_figure()
        self.axes.set_xlabel(x_label)
        self.axes.set_ylabel(y_label)

        FigureCanvas.__init__(self, fig)
        self.setParent(parent)

        FigureCanvas.setSizePolicy(self, QtWidgets.QSizePolicy.Expanding, QtWidgets.QSizePolicy.Expanding)

        FigureCanvas.updateGeometry(self)

    def compute_initial_figure(self):
        self.axes.set_xticks(range(1, 100, 10))


class MultiPlot(Plot):
    def __init__(self, parent=None, x_axis_name='X', y_axis_name='Y', width=5, height=4, dpi=100):
        super().__init__(x_axis_name, y_axis_name, parent, width, height, dpi)
        self.compute_initial_figure()

    def compute_initial_figure(self):
        self.axes.set_xticks(range(0, 100, 10))

    def plot_multi_data(self, x_axis_name='X', y_axis_name='Y', plot_labels=None, y_list=None):
        if y_list is not None:
            self.axes.clear()

            graph_handles = []

            markers = ['b:', 'r']
            y_index = 0

            for y in y_list:

                x = range(1, len(y) + 1)
                label = plot_labels[y_index]

                new_plot, = self.axes.plot(x, y, markers[y_index], markersize=2, label=label)
                graph_handles.append(new_plot)
                y_index += 1

                self.axes.set_xticks(x, int(len(list(x))/10))
            self.axes.legend(handles=graph_handles, loc=0, fontsize=8, shadow=True)

        self.axes.set_xlabel(x_axis_name)
        self.axes.set_ylabel(y_axis_name)

        self.draw()


class TrainPlotCallback(Callback):
    def __init__(self, signal):
        Callback.__init__(self)
        self.train_err = []
        self.val_err = []
        self.signal = signal

    def on_epoch_end(self, epoch, logs={}):
        self.train_err.append(1 - logs.get('acc'))
        self.val_err.append(1 - logs.get('val_acc'))
        self.signal.emit(epoch, [self.train_err, self.val_err])

def classification_model(data_input_path, on_epoch_end_signal):

    # ///////////////////// TEST /////////////////////
    if os.path.exists(data_input_path):
        plot_losses = TrainPlotCallback(on_epoch_end_signal)
        with open(data_input_path, 'rb') as pickle_in:
            data = pkl.load(pickle_in)
            X = data[0]
            y = data[1]

        input_size = X.shape[1]

        # MODEL CREATION
        # ///////////////////// INPUT LAYER /////////////////////
        inputs = Input(shape=(input_size,))
        # ///////////////////// INPUT LAYER /////////////////////
        # ///////////////////// HIDDEN LAYER /////////////////////
        x = Dense(10, activation='relu', kernel_initializer='normal')(inputs)   # THE FIRST LAYER
        # ///////////////////// HIDDEN LAYER /////////////////////
        # ///////////////////// OUTPUT LAYERS /////////////////////
        predictions = Dense(len(y[0]), activation='softmax')(x)  # the length of the output layer is as the length of the classes being predicted.
        # ///////////////////// OUTPUT LAYERS /////////////////////
        # MODEL CREATION

        # ///////////////////// MODEL DEFINITION /////////////////////
        model = Model(inputs=inputs, outputs=predictions)
        model.compile(optimizer='Adam',
                      loss='categorical_crossentropy',
                      metrics=['acc'])
        # ///////////////////// MODEL DEFINITION /////////////////////

        # ///////////////////// MODEL TRAINING /////////////////////
        model.fit(X, y, validation_split=0.2, batch_size=100, epochs=100, callbacks=[plot_losses])
        # ///////////////////// MODEL TRAINING /////////////////////


class Worker(QtCore.QObject):
    started = QtCore.pyqtSignal()
    finished = QtCore.pyqtSignal()
    epoch_end_signal = QtCore.pyqtSignal(int, list)  # signal that has epoch # as the first parameter, and a list that contains the error values for the train and validation.

    @QtCore.pyqtSlot(str)
    def start_task(self, input_path):
        self.started.emit()
        classification_model(data_input_path=input_path,
                             on_epoch_end_signal=self.epoch_end_signal)
        self.finished.emit()


class DashBoard(QtWidgets.QWidget):
    def __init__(self):
        super().__init__()
        self.main_v_box = QtWidgets.QVBoxLayout(self)
        self.input_data_path_str = ''
        self.progress_bar = QtWidgets.QProgressBar()
        self.run_model_btn = QtWidgets.QPushButton('Run')
        self.browse_train_data_file_path_btn = QtWidgets.QPushButton('Browse')
        self.in_training_plot = MultiPlot(x_axis_name='Epoch Number', y_axis_name='Error')
        self.train_data_file_path_le = QtWidgets.QLineEdit()
        self.init()
        self.pack()
        self.showMaximized()

    def init(self):
        self.worker = Worker()
        thread = QtCore.QThread(self)
        thread.start()
        self.worker.moveToThread(thread)
        self.progress_bar.hide()
        self.browse_train_data_file_path_btn.clicked.connect(self.on_btn_click)
        self.run_model_btn.clicked.connect(self.on_btn_click)
        self.worker.epoch_end_signal.connect(self.update_ui_on_epoch_end)
        self.worker.started.connect(self.progress_bar.show)
        self.worker.finished.connect(self.progress_bar.hide)
        self.worker.started.connect(partial(self.run_model_btn.setEnabled, False))
        self.worker.finished.connect(partial(self.run_model_btn.setEnabled, True))

    def pack(self):
        self.main_v_box.addWidget(self.train_data_file_path_le)
        self.main_v_box.addWidget(self.browse_train_data_file_path_btn)
        self.main_v_box.addWidget(self.in_training_plot)
        self.main_v_box.addWidget(self.run_model_btn)
        self.main_v_box.addWidget(self.progress_bar)

    def on_btn_click(self):
        btn_index = self.sender()

        if btn_index == self.browse_train_data_file_path_btn:
            self.input_data_path_str, _ = QtWidgets.QFileDialog.getOpenFileName(self, '.pickle files', os.getenv('HOME'), '*.pickle')
            self.train_data_file_path_le.setText(self.input_data_path_str)
        elif btn_index == self.run_model_btn:
            QtCore.QTimer.singleShot(0, partial(self.worker.start_task, self.input_data_path_str))

    def update_ui_on_epoch_end(self, current_epoch_num, error_lists):
        if current_epoch_num < 100:
            self.progress_bar.setValue(current_epoch_num)
        else:
            self.progress_bar.setValue(100)
        self.in_training_plot.plot_multi_data(x_axis_name='Epoch', y_axis_name='Error', plot_labels=['Train Accuracy', 'Validation Accuracy'], y_list=[error_lists[0], error_lists[1]])


if __name__ == '__main__':
    app = QtWidgets.QApplication(sys.argv)
    main_menu = DashBoard()
    sys.exit(app.exec_())
© www.soinside.com 2019 - 2024. All rights reserved.