Inconsistent OCR results: predictions differ between training and testing

My custom OCR (optical character recognition) model has a problem: it produces different predictions during the training and testing phases.


Training started...
Epoch [1/50], Step [100/133], Loss: 1.4840
Epoch [2/50], Step [100/133], Loss: 0.1713
Epoch [3/50], Step [100/133], Loss: 0.1087
Epoch [4/50], Step [100/133], Loss: 0.0793
Epoch [5/50], Step [100/133], Loss: 0.0793
Epoch [6/50], Step [100/133], Loss: 0.0552
Epoch [7/50], Step [100/133], Loss: 0.0501
Epoch [8/50], Step [100/133], Loss: 0.0484
Epoch [9/50], Step [100/133], Loss: 0.0595
Epoch [10/50], Step [100/133], Loss: 0.0437
Epoch [11/50], Step [100/133], Loss: 0.0351
Epoch [12/50], Step [100/133], Loss: 0.0914
Epoch [13/50], Step [100/133], Loss: 0.0304
Epoch [14/50], Step [100/133], Loss: 0.0406
Epoch [15/50], Step [100/133], Loss: 0.0315
Epoch [16/50], Step [100/133], Loss: 0.0331
Epoch [17/50], Step [100/133], Loss: 0.0220
Epoch [18/50], Step [100/133], Loss: 0.0238
Epoch [19/50], Step [100/133], Loss: 0.0272
Epoch [20/50], Step [100/133], Loss: 0.0259
Epoch [21/50], Step [100/133], Loss: 0.0210
Epoch [22/50], Step [100/133], Loss: 0.0826
Epoch [23/50], Step [100/133], Loss: 0.0673
Epoch [24/50], Step [100/133], Loss: 0.0240
Epoch [25/50], Step [100/133], Loss: 0.0198
Epoch [26/50], Step [100/133], Loss: 0.0250
Epoch [27/50], Step [100/133], Loss: 0.0203
Epoch [28/50], Step [100/133], Loss: 0.0170
Epoch [29/50], Step [100/133], Loss: 0.0204
Epoch [30/50], Step [100/133], Loss: 0.0177
Epoch [31/50], Step [100/133], Loss: 0.0208
Epoch [32/50], Step [100/133], Loss: 0.0231
Epoch [33/50], Step [100/133], Loss: 0.0156
Epoch [34/50], Step [100/133], Loss: 0.0117
Epoch [35/50], Step [100/133], Loss: 0.0171
Epoch [36/50], Step [100/133], Loss: 0.0138
Epoch [37/50], Step [100/133], Loss: 0.0196
Epoch [38/50], Step [100/133], Loss: 0.0158
Epoch [39/50], Step [100/133], Loss: 0.0183
Epoch [40/50], Step [100/133], Loss: 0.0163
Epoch [41/50], Step [100/133], Loss: 0.0305
Epoch [42/50], Step [100/133], Loss: 0.0504
Epoch [43/50], Step [100/133], Loss: 0.0404
Epoch [44/50], Step [100/133], Loss: 0.0176
Epoch [45/50], Step [100/133], Loss: 0.0140
Epoch [46/50], Step [100/133], Loss: 0.0099
Epoch [47/50], Step [100/133], Loss: 0.0123
Epoch [48/50], Step [100/133], Loss: 0.0121
Epoch [49/50], Step [100/133], Loss: 0.0118
Epoch [50/50], Step [100/133], Loss: 0.0140
Training finished.
Testing started...
Predicted: A, Actual: A, Confidence: 100.00%
Predicted: B, Actual: B, Confidence: 98.97%
Predicted: C, Actual: C, Confidence: 100.00%
Predicted: D, Actual: D, Confidence: 99.46%
Predicted: E, Actual: E, Confidence: 99.63%
Predicted: F, Actual: F, Confidence: 99.32%
Predicted: G, Actual: G, Confidence: 99.92%
Predicted: H, Actual: H, Confidence: 99.99%
Predicted: I, Actual: I, Confidence: 90.64%
Predicted: J, Actual: J, Confidence: 97.04%
Predicted: K, Actual: K, Confidence: 100.00%
Predicted: L, Actual: L, Confidence: 99.08%
Predicted: M, Actual: M, Confidence: 100.00%
Predicted: O, Actual: A, Confidence: 67.16%
Predicted: P, Actual: P, Confidence: 96.75%
Predicted: Q, Actual: Q, Confidence: 99.71%
Predicted: R, Actual: R, Confidence: 99.98%
Predicted: S, Actual: S, Confidence: 99.25%
Predicted: T, Actual: T, Confidence: 91.48%
Predicted: U, Actual: U, Confidence: 100.00%
Predicted: V, Actual: V, Confidence: 100.00%
Predicted: W, Actual: W, Confidence: 100.00%
Predicted: X, Actual: X, Confidence: 100.00%
Predicted: Y, Actual: Y, Confidence: 100.00%
Predicted: Z, Actual: Z, Confidence: 100.00%
Predicted: 0, Actual: 0, Confidence: 100.00%
Predicted: 1, Actual: 1, Confidence: 100.00%
Predicted: 2, Actual: 2, Confidence: 100.00%
Predicted: 3, Actual: 3, Confidence: 100.00%
Predicted: 4, Actual: 4, Confidence: 100.00%
Predicted: 5, Actual: 5, Confidence: 100.00%
Predicted: 6, Actual: 6, Confidence: 100.00%
Predicted: 7, Actual: 7, Confidence: 100.00%
Predicted: 8, Actual: 8, Confidence: 100.00%
Predicted: 9, Actual: 9, Confidence: 100.00%
Accuracy on test dataset: 99.69%

Process finished with exit code 0

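I built a custom OCR model with PyTorch. Here is a brief overview of the relevant components:

  • OCRModel: defines the architecture of the OCR neural network.
  • OCRDataset: custom dataset class that handles OCR data loading and preprocessing.
  • OCRHandler: handles training and testing of the OCR model.
  • OCR: inherits from OCRHandler and adds image-to-text conversion.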


import string
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.datasets import ImageFolder
from pathlib import Path
from PIL import UnidentifiedImageError, Image
from typing import Optional, Union

class OCRModel(nn.Module):
    def __init__(self, num_classes: int):
        """
        Initialize the OCRModel.

        Args:
            num_classes (int): Number of classes for classification.
        """
        super(OCRModel, self).__init__()
        # Define convolutional layers
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        # Define pooling layer
        self.pool = nn.MaxPool2d(2, 2)
        # Define fully connected layers
        self.fc1 = nn.Linear(64 * 16 * 16, 128)
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass through the network.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: Output tensor.
        """
        x = self.pool(torch.relu(self.conv1(x)))
        x = self.pool(torch.relu(self.conv2(x)))
        x = torch.flatten(x, 1)
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

class OCRDataset(Dataset):
    def __init__(self, data_path: Union[str, Path], transform: Optional[transforms.Compose] = None):
        """
        Initialize the OCRDataset.

        Args:
            data_path (Union[str, Path]): Path to the dataset.
            transform (Optional[transforms.Compose]): Transformations to apply to the data.
        """
        if not isinstance(data_path, (str, Path)):
            raise TypeError("data_path must be a string or a Path object.")
        self.data_path = Path(data_path)
        self.transform = transform
        self.dataset = ImageFolder(root=str(self.data_path), transform=transform)

    def __len__(self) -> int:
        return len(self.dataset)

    def __getitem__(self, idx: int) -> tuple[Optional[torch.Tensor], Optional[int]]:
        while True:
            try:
                image, label = self.dataset[idx]
                # Convert image to PIL Image to handle errors
                image = transforms.ToPILImage()(image)
                if self.transform:
                    image = self.transform(image)
                # Sample loaded successfully: leave the retry loop
                break
            except (UnidentifiedImageError, OSError) as error:
                print(f"Error opening image at index {idx}: {error}")
                idx += 1
                if idx >= len(self):
                    print("Reached end of dataset.")
                    return None, None
        return image, label

class OCRHandler:
    def __init__(self, model: OCRModel):
        """
        Initialize the OCRHandler.

        Args:
            model (OCRModel): OCR model instance.
        """
        if not isinstance(model, OCRModel):
            raise TypeError("model must be an instance of OCRModel.")
        self.model = model

    def train(self, train_data_path: Union[str, Path], num_epochs: int = 10, batch_size: int = 32,
              learning_rate: float = 0.001) -> None:
        """
        Train the OCR model.

        Args:
            train_data_path (Union[str, Path]): Path to the training dataset.
            num_epochs (int): Number of epochs for training.
            batch_size (int): Batch size for training.
            learning_rate (float): Learning rate for optimization.
        """
        # Print a message to indicate the start of training
        print("Training started...")

        # Prepare the training dataset with transformations
        train_dataset = OCRDataset(train_data_path, transform=transforms.Compose([
            transforms.Resize((64, 64)),
            transforms.ToTensor(),
        ]))

        # Create a data loader for the training dataset
        train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

        # Define loss criterion and optimizer
        criterion = nn.CrossEntropyLoss()
        optimizer = torch.optim.Adam(self.model.parameters(), lr=learning_rate)

        # Loop through epochs
        for epoch in range(num_epochs):
            # Initialize running loss for each epoch
            running_loss = 0.0
            # Loop through batches in the training dataloader
            for i, (images, labels) in enumerate(train_dataloader):
                # Clear previous gradients
                optimizer.zero_grad()
                # Forward pass
                outputs = self.model(images)
                # Calculate loss
                loss = criterion(outputs, labels)
                # Backpropagation
                loss.backward()
                # Update weights
                optimizer.step()
                # Accumulate loss
                running_loss += loss.item()
                # Print loss statistics every 100 steps
                if (i + 1) % 100 == 0:
                    print(
                        f"Epoch [{epoch + 1}/{num_epochs}], Step [{i + 1}/{len(train_dataloader)}], Loss: {running_loss / 100:.4f}")
                    running_loss = 0.0
        print("Training finished.")

    def test(self, testing_path: Union[str, Path]) -> None:
        """
        Test the OCR model on the test dataset.

        Args:
            testing_path (Union[str, Path]): Path to the test dataset.
        """
        # Print a message to indicate testing has started
        print("Testing started...")

        # Prepare the test dataset
        test_dataset = ImageFolder(root=str(testing_path), transform=transforms.Compose([
            transforms.Resize((64, 64)),
            transforms.ToTensor(),
        ]))

        # Create a data loader for the test dataset
        test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=False)

        # Initialize counters for correct predictions and total samples
        correct = 0
        total = 0

        # Disable gradient calculation for inference
        with torch.no_grad():
            # Iterate over the test dataset
            for images, labels in test_dataloader:
                # Forward pass through the model
                outputs = self.model(images)
                # Get the predicted classes
                _, predicted = torch.max(outputs, 1)
                # Update correct predictions count
                correct += (predicted == labels).sum().item()
                # Update total count
                total += labels.size(0)
                # Calculate confidence of the prediction
                confidence = torch.softmax(outputs, 1)[0][predicted.item()].item() * 100
                # Convert predicted and actual labels to uppercase letters or digits
                predicted_label = string.ascii_uppercase[predicted.item()] if predicted.item() < 26 else str(
                    predicted.item() - 26)
                actual_label = string.ascii_uppercase[labels.item()] if labels.item() < 26 else str(labels.item() - 26)
                # Print prediction details
                print(f"Predicted: {predicted_label}, Actual: {actual_label}, Confidence: {confidence:.2f}%")

        # Calculate and print accuracy
        print(f"Accuracy on test dataset: {(correct / total) * 100:.2f}%")

    def save(self) -> None:
        """Save the trained model."""
        save_path = Path(__file__).resolve().parent
        save_path.mkdir(parents=True, exist_ok=True)
        model_path = save_path / "ocr_model.pt"
        torch.save(self.model.state_dict(), model_path)
        print(f"Model saved successfully at: {model_path}")

class OCR(OCRHandler):
    def __init__(self, debug=False):
        """
        Initialize the OCR class.

        Args:
            debug (bool, optional): Whether to enable debug mode. Defaults to False.
        """
        # Load the OCR model
        ocr_model = OCRModel(num_classes=36)
        super().__init__(ocr_model)
        self.debug = debug

    def image_to_text(self, image_path: Union[str, Path]) -> Union[str, tuple[str, dict | None]]:
        """
        Convert an image to text using the OCR model.

        Args:
            image_path (Union[str, Path]): Path to the input image.

        Returns:
            Union[str, tuple[str, dict]]: The predicted text. If debug mode is enabled,
                returns a tuple containing the predicted text and a dictionary with debug information.
        """
        # Prepare the input image
        image = Image.open(image_path).convert("RGB")
        transform = transforms.Compose([
            transforms.Resize((64, 64)),
            transforms.ToTensor(),
        ])
        # Add batch dimension
        image = transform(image).unsqueeze(0)

        # Disable gradient calculation for inference
        with torch.no_grad():
            # Forward pass through the model
            outputs = self.model(image)
            # Get the predicted class
            _, predicted = torch.max(outputs, 1)
            # Convert predicted label to uppercase letter or digit
            predicted_text = string.ascii_uppercase[predicted.item()] if predicted.item() < 26 else str(
                predicted.item() - 26)

            if self.debug:
                # Calculate confidence scores
                confidence_scores = torch.softmax(outputs, 1)[0].tolist()
                # Convert confidence scores to percentages
                confidence_percentages = [score * 100 for score in confidence_scores]
                # Create a dictionary with debug information
                debug_info = {
                    "predicted_text": predicted_text,
                    "confidence_scores": {string.ascii_uppercase[i] if i < 26 else str(i - 26): percentage
                                          for i, percentage in enumerate(confidence_percentages)},
                    "top_predictions": [
                        {
                            "class": string.ascii_uppercase[i] if i < 26 else str(i - 26),
                            "confidence": percentage
                        }
                        for i, percentage in enumerate(confidence_percentages)
                    ]
                }
                return predicted_text, debug_info
            else:
                return predicted_text, None

if __name__ == "__main__":
    ocr_model = OCRModel(num_classes=36)
    trainer = OCRHandler(ocr_model)
    train_data_path = Path("dataset/text_identification/train")
    trainer.train(train_data_path, num_epochs=50, batch_size=32, learning_rate=0.001)
    test_data_path = Path("dataset/text_identification/test")
    trainer.test(test_data_path)

When I tried testing the OCR on the character "0", it came out as G? It makes no sense to me how something trained on hundreds of images could arrive at G.

{'predicted_text': 'G', 'confidence_scores': {'A': 2.5511952117085457, 'B': 2.9084114357829094, 'C': 2.9245806857943535, 'D': 3.019385412335396, 'E': 2.867276221513748, 'F': 2.7050500735640526, 'G': 3.0951984226703644, 'H': 3.0260657891631126, 'I': 2.691103331744671, 'J': 2.631979249417782, 'K': 2.5525128468871117, 'L': 2.7190934866666794, 'M': 2.6693686842918396, 'N': 2.9030684381723404, 'O': 2.5663597509264946, 'P': 2.941553108394146, 'Q': 2.8950219973921776, 'R': 2.789396792650223, 'S': 2.6342585682868958, 'T': 2.5583021342754364, 'U': 2.799813263118267, 'V': 2.574686147272587, 'W': 2.713942527770996, 'X': 2.824728935956955, 'Y': 2.85495538264513, 'Z': 2.7191564440727234, '0': 2.6622315868735313, '1': 2.6157714426517487, '2': 2.8603684157133102, '3': 2.5942767038941383, '4': 2.733685076236725, '5': 2.891615778207779, '6': 3.0571456998586655, '7': 2.551230974495411, '8': 2.9105449095368385, '9': 2.986663021147251}, 'top_predictions': [{'class': 'A', 'confidence': 2.5511952117085457}, {'class': 'B', 'confidence': 2.9084114357829094}, {'class': 'C', 'confidence': 2.9245806857943535}, {'class': 'D', 'confidence': 3.019385412335396}, {'class': 'E', 'confidence': 2.867276221513748}, {'class': 'F', 'confidence': 2.7050500735640526}, {'class': 'G', 'confidence': 3.0951984226703644}, {'class': 'H', 'confidence': 3.0260657891631126}, {'class': 'I', 'confidence': 2.691103331744671}, {'class': 'J', 'confidence': 2.631979249417782}, {'class': 'K', 'confidence': 2.5525128468871117}, {'class': 'L', 'confidence': 2.7190934866666794}, {'class': 'M', 'confidence': 2.6693686842918396}, {'class': 'N', 'confidence': 2.9030684381723404}, {'class': 'O', 'confidence': 2.5663597509264946}, {'class': 'P', 'confidence': 2.941553108394146}, {'class': 'Q', 'confidence': 2.8950219973921776}, {'class': 'R', 'confidence': 2.789396792650223}, {'class': 'S', 'confidence': 2.6342585682868958}, {'class': 'T', 'confidence': 2.5583021342754364}, {'class': 'U', 'confidence': 2.799813263118267}, 
{'class': 'V', 'confidence': 2.574686147272587}, {'class': 'W', 'confidence': 2.713942527770996}, {'class': 'X', 'confidence': 2.824728935956955}, {'class': 'Y', 'confidence': 2.85495538264513}, {'class': 'Z', 'confidence': 2.7191564440727234}, {'class': '0', 'confidence': 2.6622315868735313}, {'class': '1', 'confidence': 2.6157714426517487}, {'class': '2', 'confidence': 2.8603684157133102}, {'class': '3', 'confidence': 2.5942767038941383}, {'class': '4', 'confidence': 2.733685076236725}, {'class': '5', 'confidence': 2.891615778207779}, {'class': '6', 'confidence': 3.0571456998586655}, {'class': '7', 'confidence': 2.551230974495411}, {'class': '8', 'confidence': 2.9105449095368385}, {'class': '9', 'confidence': 2.986663021147251}]}
[The model] produces different predictions during the training and testing phases.

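Based on what you've shared, that is not a claim you can make. From your logs I can see that your final loss is 0.0140 (if that were an accuracy metric, it would correspond to 98.6%), and during testing your accuracy even rises to 99.69%. So, according to the data we see here, your model performs even better on the test set than during training.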
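For context on the 98.6% figure: the number printed in the training log is the average `nn.CrossEntropyLoss` over the last 100 batches, not an accuracy. Cross-entropy averages -log p(true class), so exp(-loss) is the geometric mean of the probability the model assigned to the correct label on those training batches. The helper below uses a hypothetical name and only does this arithmetic.

```python
import math


def mean_true_class_probability(cross_entropy_loss: float) -> float:
    """Geometric mean of p(true class) implied by a mean cross-entropy loss.

    Cross-entropy is mean(-log p_true), so exp(-loss) = exp(mean(log p_true)).
    """
    return math.exp(-cross_entropy_loss)


# Loss values copied from the training log in the question.
for epoch, loss in [(1, 1.4840), (50, 0.0140)]:
    prob = mean_true_class_probability(loss)
    print(f"epoch {epoch}: loss {loss:.4f} -> mean p(true) {prob:.2%}")  # epoch 50 -> 98.61%
```

Either way it is a training-set quantity, so on its own it cannot show whether predictions on unseen images will agree with it.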

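During training, the model seems to perform remarkably well, achieving high accuracy on the training dataset. However, when I use the model to predict text from images during testing, the results are inconsistent.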


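a) Define a validation set that you run after every epoch, so you can actually see whether there is an inconsistency.
b) Increase the size of your test set. The smaller the test set, the higher the variance of the resulting accuracy (or whatever metric you use). Judging from the output above, your test set is too small and may be misleading you.
c) While you're at it, build a validation set per character, so you can see the per-class loss and find out where the model's real weaknesses are.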
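Points b) and c) can be made concrete with a small, framework-agnostic sketch. It assumes you have already collected the predicted and true class indices from a validation pass (for example `predicted.item()` and `labels.item()` inside the `torch.no_grad()` loop of `test`); all helper names are hypothetical. `accuracy_std_error` uses the usual binomial approximation sqrt(p(1 - p) / n) to show how noisy an accuracy estimate is for a given test-set size, and `per_class_accuracy` plus `most_confused_pairs` break the result down by character.

```python
import math
import string
from collections import Counter

# Same index-to-character mapping as the question's code: 0-25 -> A-Z, 26-35 -> 0-9.
CLASS_NAMES = list(string.ascii_uppercase) + [str(d) for d in range(10)]


def accuracy_std_error(accuracy: float, n_samples: int) -> float:
    """Binomial standard error of an accuracy measured on n_samples examples."""
    return math.sqrt(accuracy * (1.0 - accuracy) / n_samples)


def per_class_accuracy(predictions: list[int], labels: list[int]) -> dict[str, float]:
    """Accuracy for each true class that appears in `labels`."""
    totals = Counter(labels)
    correct = Counter(t for p, t in zip(predictions, labels) if p == t)
    return {CLASS_NAMES[c]: correct[c] / totals[c] for c in sorted(totals)}


def most_confused_pairs(predictions: list[int], labels: list[int], top_k: int = 5) -> list:
    """Most frequent (true, predicted) mistakes, e.g. (('0', 'G'), count)."""
    mistakes = Counter(
        (CLASS_NAMES[t], CLASS_NAMES[p]) for p, t in zip(predictions, labels) if p != t
    )
    return mistakes.most_common(top_k)


if __name__ == "__main__":
    # Spread of a 97% accuracy estimate for different test-set sizes.
    for n in (36, 360, 3600):
        print(f"n={n}: 97% +/- {accuracy_std_error(0.97, n):.1%}")

    # Made-up validation results: index 26 is '0', index 6 is 'G', index 0 is 'A'.
    labels = [26, 26, 26, 6, 6, 0]
    predictions = [26, 6, 26, 6, 6, 0]
    print(per_class_accuracy(predictions, labels))
    print(most_confused_pairs(predictions, labels))
```

Run on a held-out validation split after every epoch (point a), the per-class numbers and the confusion pairs point directly at characters such as 0, O and G that the network mixes up.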


The confidence scores for each class also seem to fluctuate significantly, leading to unexpected predictions.

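"Confidence" is another misleading term. While it may reflect the probability of the predicted class, when your class set contains two similar classes they will always compete with each other and therefore end up sharing a lower probability (e.g. G and 0, or O and 0). To see what I mean, you can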
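A small pure-Python softmax makes the point about shared probability concrete, and it also helps when reading the debug output in the question. The logits below are invented for illustration. When two classes (say '0' and 'G') get equally high logits, each receives well under half of the probability mass. When all 36 logits are nearly equal, every class lands near 100/36 ≈ 2.78%, which is exactly the range every score in the question's debug dictionary falls into (roughly 2.55% to 3.10%). That second pattern may be worth checking separately, as a hypothesis rather than a confirmed diagnosis: it is what an untrained network tends to output, and as posted, `OCR.__init__` builds a fresh `OCRModel(num_classes=36)` without loading any saved weights (for instance via `model.load_state_dict(torch.load(path))`), while the `__main__` block never calls `trainer.save()`.

```python
import math


def softmax(logits: list[float]) -> list[float]:
    """Numerically stable softmax over a list of logits."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


NUM_CLASSES = 36  # A-Z plus 0-9, as in the question

# Case 1 (made-up logits): two similar classes, '0' and 'G', tie for the top logit.
two_similar = [0.0] * NUM_CLASSES
two_similar[26] = two_similar[6] = 5.0  # indices of '0' and 'G'
probs = softmax(two_similar)
print(f"'0': {probs[26]:.1%}, 'G': {probs[6]:.1%}")  # each about 44.9%

# Case 2: all logits equal, an idealized stand-in for an untrained network.
uniform = softmax([0.0] * NUM_CLASSES)
print(f"every class: {uniform[0]:.2%}")  # 100/36 = 2.78%
```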


When I try to test my OCR on the character "0", it comes out as G?

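Yes, as mentioned above, similar classes make the prediction task harder. It can genuinely be difficult to tell 0 and G apart, depending on how they are written.

If you want a better picture of the problem at hand, plotting some samples is always a good idea. Especially with handwritten characters there is a semantic gap: without context, you will not reach 100% prediction accuracy.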
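One way to follow this advice, sketched under the assumption that the data uses the `ImageFolder` layout from the question (`root/<class_name>/<image files>`) with one sub-folder named after each character: collect a few file paths per class, then paste them into a single grid image with Pillow so confusable classes such as 0, O and G can be compared side by side. The function names, the file-extension filter and the grid layout are illustrative choices, not part of the original code.

```python
from pathlib import Path

IMAGE_SUFFIXES = {".png", ".jpg", ".jpeg", ".bmp"}  # assumption: adjust to your data


def sample_paths_per_class(root: Path, classes: list[str], per_class: int = 5) -> dict[str, list[Path]]:
    """Return up to `per_class` image paths (sorted by name) for each class folder."""
    samples = {}
    for name in classes:
        folder = root / name
        if not folder.is_dir():
            samples[name] = []  # missing class folder: nothing to show
            continue
        files = sorted(p for p in folder.iterdir() if p.suffix.lower() in IMAGE_SUFFIXES)
        samples[name] = files[:per_class]
    return samples


def save_contact_sheet(samples: dict[str, list[Path]], out_path: Path, cell: int = 64) -> None:
    """Paste samples into a grid: one row per class, one column per sample."""
    from PIL import Image  # imported lazily so the path helper works without Pillow

    rows = max(len(samples), 1)
    cols = max((len(paths) for paths in samples.values()), default=0) or 1
    sheet = Image.new("RGB", (cols * cell, rows * cell), "white")
    for r, paths in enumerate(samples.values()):
        for c, path in enumerate(paths):
            # Same 64x64 size the model sees after transforms.Resize((64, 64))
            img = Image.open(path).convert("RGB").resize((cell, cell))
            sheet.paste(img, (c * cell, r * cell))
    sheet.save(out_path)
```

For example, `save_contact_sheet(sample_paths_per_class(Path("dataset/text_identification/train"), ["0", "O", "G"], per_class=8), Path("confusable_samples.png"))` would write one row each for 0, O and G, provided folders with those names exist.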

I don't know what to do now; I know very little about neural networks.



