如何从 PyTorch 模型中获取特定层的输出?

问题描述 投票:0回答:3

如何从预训练的 PyTorch 模型(例如 ResNet 或 VGG)中提取特定层的特征,而无需再次进行前向传递?

python pytorch
3个回答
14
投票

新答案

编辑: torchvision v0.11.0 中有一个新功能,允许提取特征

例如,如果你想从图层

layer4.2.relu_2
中提取特征,你可以这样做:

import torch
from torchvision.models import resnet50
from torchvision.models.feature_extraction import create_feature_extractor

x = torch.rand(1, 3, 224, 224)

model = resnet50()

return_nodes = {
    "layer4.2.relu_2": "layer4"
}
model2 = create_feature_extractor(model, return_nodes=return_nodes)
intermediate_outputs = model2(x)

旧答案

您可以在您想要的特定层上注册forward hook。比如:

def some_specific_layer_hook(module, input_, output):
    pass  # the value is in 'output'

model.some_specific_layer.register_forward_hook(some_specific_layer_hook)
    
model(some_input)

例如,要获取 ResNet 中的

res5c
输出,您可能需要使用
nonlocal
变量(或 Python 2 中的
global
):

res5c_output = None

def res5c_hook(module, input_, output):
    nonlocal res5c_output
    res5c_output = output

resnet.layer4.register_forward_hook(res5c_hook)

resnet(some_input)
    
# Then, use `res5c_output`.

3
投票

接受的答案非常有帮助!我在这里发布了一个完整的示例(使用 @bryant1410 所描述的注册钩子),供那些寻找可行解决方案的懒人使用:

import torch 
import torchvision.models as models
from torchvision import transforms
from PIL import Image

def get_feat_vector(path_img, model):
    '''
    Input: 
        path_img: string, /path/to/image
        model: a pretrained torch model
    Output:
        my_output: torch.tensor, output of avgpool layer
    '''
    input_image = Image.open(path_img)
    
    preprocess = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    
    input_tensor = preprocess(input_image)
    input_batch = input_tensor.unsqueeze(0)

    with torch.no_grad():
        my_output = None
        
        def my_hook(module_, input_, output_):
            nonlocal my_output
            my_output = output_

        a_hook = model.avgpool.register_forward_hook(my_hook)        
        model(input_batch)
        a_hook.remove()
        return my_output

您就有了特征提取函数,只需使用下面的代码片段调用它即可从

resnet18.avgpool

获取特征
model = models.resnet18(pretrained=True)
model.eval()
path_ = '/path/to/image'
my_feature = get_feat_vector(path_, model)

0
投票

使用

register_forward_hook
的替代方案,但使用类而不是全局变量。

简单的例子:

class FeatureExtractor:
    def __init__(self):
        self.extracted_features = None

    def __call__(self, module, input_, output):
        self.extracted_features = output

extractor = FeatureExtractor()
model.some_specific_layer.register_forward_hook(extractor)
model(some_input)
extractor.extracted_features

从多层中提取(存储在字典中):

class FeatureExtractor:
    def __init__(self):
        self.extracted_features = dict()

    def extract_features(self, module, input_, output, name):
        self.extracted_features[name] = output

    def get_forward_hook(self, name):
        return functools.partial(self.extract_features, name=name)

model.some_specific_layer.register_forward_hook(extractor.get_forward_hook(layer_name))
model(some_input)
extractor.extracted_features[layer_name]

functools.partial
允许我们创建一个 callable ,它映射到
FeatureExtractor.extract_features
,并且特定参数已传递给 name 参数。

© www.soinside.com 2019 - 2024. All rights reserved.