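How can I extract features from specific layers of a pretrained PyTorch model (such as ResNet or VGG) without doing a forward pass again?
Edit: torchvision v0.11.0 added a new feature that allows extracting features.
For example, to extract features from the layer layer4.2.relu_2, you can do this: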
import torch
from torchvision.models import resnet50
from torchvision.models.feature_extraction import create_feature_extractor

x = torch.rand(1, 3, 224, 224)
model = resnet50()
# Map each graph node to extract to the key it gets in the output dict.
return_nodes = {
    "layer4.2.relu_2": "layer4"
}
model2 = create_feature_extractor(model, return_nodes=return_nodes)
# The result is a dict; intermediate_outputs["layer4"] holds the features.
intermediate_outputs = model2(x)
You can register a forward hook on the specific layer you want. For example:
def some_specific_layer_hook(module, input_, output):
    pass  # the value is in 'output'

model.some_specific_layer.register_forward_hook(some_specific_layer_hook)
model(some_input)
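For example, to get the res5c output in ResNet, you may want to use a nonlocal variable (or global if the variable lives at module level, as well as in Python 2):
res5c_output = None

def res5c_hook(module, input_, output):
    # `nonlocal` only compiles when this snippet sits inside an enclosing
    # function; at module level, write `global res5c_output` instead.
    nonlocal res5c_output
    res5c_output = output

resnet.layer4.register_forward_hook(res5c_hook)
resnet(some_input)
# Then, use `res5c_output`.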
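The difference between the two declarations is easy to trip over, so here is a small pure-Python sketch of it. No PyTorch is involved: the hooks are called by hand as a stand-in for a forward pass, and the names module_level_hook and run_with_closure are hypothetical.

```python
# Module-level variable: a hook that rebinds it must declare it `global`.
res5c_output = None

def module_level_hook(module, input_, output):
    global res5c_output
    res5c_output = output

def run_with_closure(fake_output):
    # Inside a function, `nonlocal` binds to the enclosing function's variable.
    captured = None

    def closure_hook(module, input_, output):
        nonlocal captured
        captured = output

    closure_hook(None, None, fake_output)  # stand-in for PyTorch calling the hook
    return captured

module_level_hook(None, None, "layer output")  # stand-in for a forward pass
print(res5c_output)               # layer output
print(run_with_closure("other"))  # other
```

Using nonlocal for a name that only exists at module level is rejected at compile time with a SyntaxError, which is why the declaration has to match where the variable is defined.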
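The accepted answer is very helpful! I am posting a complete example here (registering a hook as described by @bryant1410) for the lazy ones looking for a working solution: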
import torch
import torchvision.models as models
from torchvision import transforms
from PIL import Image

def get_feat_vector(path_img, model):
    '''
    Input:
        path_img: string, /path/to/image
        model: a pretrained torch model
    Output:
        my_output: torch.Tensor, output of avgpool layer
    '''
    # Convert to RGB so grayscale or RGBA images also yield 3 channels.
    input_image = Image.open(path_img).convert('RGB')
    preprocess = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    input_tensor = preprocess(input_image)
    input_batch = input_tensor.unsqueeze(0)  # add the batch dimension

    with torch.no_grad():
        my_output = None

        def my_hook(module_, input_, output_):
            nonlocal my_output
            my_output = output_

        a_hook = model.avgpool.register_forward_hook(my_hook)
        model(input_batch)
        a_hook.remove()  # detach the hook so later calls are unaffected
    return my_output
Now you have your feature extraction function. Simply call it with the snippet below to get the features from the resnet18.avgpool layer:
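# Note: torchvision >= 0.13 deprecates `pretrained=True` in favor of
# `weights=models.ResNet18_Weights.DEFAULT`.
model = models.resnet18(pretrained=True)
model.eval()
path_ = '/path/to/image'
my_feature = get_feat_vector(path_, model)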
An alternative that also uses register_forward_hook, but with a class instead of a global variable.
Simple example:
class FeatureExtractor:
    def __init__(self):
        self.extracted_features = None

    def __call__(self, module, input_, output):
        self.extracted_features = output

extractor = FeatureExtractor()
model.some_specific_layer.register_forward_hook(extractor)
model(some_input)
extractor.extracted_features
Extracting from multiple layers (stored in a dictionary):
import functools

class FeatureExtractor:
    def __init__(self):
        self.extracted_features = dict()

    def extract_features(self, module, input_, output, name):
        self.extracted_features[name] = output

    def get_forward_hook(self, name):
        # Bind `name` so the hook keeps the (module, input, output) signature.
        return functools.partial(self.extract_features, name=name)

extractor = FeatureExtractor()
model.some_specific_layer.register_forward_hook(extractor.get_forward_hook(layer_name))
model(some_input)
extractor.extracted_features[layer_name]
functools.partial lets us create a callable that maps to FeatureExtractor.extract_features with a specific value already bound to its name parameter.
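To see the multi-layer version run end to end, here is a minimal sketch on a toy nn.Sequential model. The toy model and the wanted list are assumptions for illustration only; with a real network you would pick layer names from model.named_modules() in the same way.

```python
import functools

import torch
from torch import nn

class FeatureExtractor:
    def __init__(self):
        self.extracted_features = {}

    def extract_features(self, module, input_, output, name):
        self.extracted_features[name] = output

    def get_forward_hook(self, name):
        # partial pre-fills `name`; PyTorch supplies module, input and output.
        return functools.partial(self.extract_features, name=name)

# Toy stand-in for a real network; its submodules are named "0", "1", "2".
model = nn.Sequential(nn.Linear(8, 4), nn.ReLU(), nn.Linear(4, 2))
extractor = FeatureExtractor()

wanted = ["0", "2"]  # hypothetical choice of layers to capture
handles = [
    module.register_forward_hook(extractor.get_forward_hook(name))
    for name, module in model.named_modules()
    if name in wanted
]

with torch.no_grad():
    model(torch.rand(3, 8))

for handle in handles:
    handle.remove()  # detach the hooks once the features are captured

for name, feature in extractor.extracted_features.items():
    print(name, tuple(feature.shape))
```

Keeping the handles returned by register_forward_hook makes it possible to remove the hooks afterwards, so later forward passes do not keep overwriting the dictionary.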