I'm currently trying to run ZoeDepth in Google Colab, following the instructions at https://github.com/isl-org/ZoeDepth. These are the lines I have run successfully:
!pip install torch
!pip install timm
import torch
torch.hub.help("intel-isl/MiDaS", "DPT_BEiT_L_384", force_reload=True) # Triggers fresh download of MiDaS repo
The next step is to load the pretrained model from the GitHub repository:
import torch
repo = "isl-org/ZoeDepth"
# Zoe_N
model_zoe_n = torch.hub.load(repo, "ZoeD_N", pretrained=True)
However, this cell fails. Here is the full log it prints:
Using cache found in /root/.cache/torch/hub/isl-org_ZoeDepth_main
img_size [384, 512]
Using cache found in /root/.cache/torch/hub/intel-isl_MiDaS_master
/usr/local/lib/python3.10/dist-packages/torch/functional.py:504: UserWarning: torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. (Triggered internally at ../aten/src/ATen/native/TensorShape.cpp:3483.)
return _VF.meshgrid(tensors, **kwargs) # type: ignore[attr-defined]
Params passed to Resize transform:
width: 512
height: 384
resize_target: True
keep_aspect_ratio: True
ensure_multiple_of: 32
resize_method: minimal
Using pretrained resource url::https://github.com/isl-org/ZoeDepth/releases/download/v1.0/ZoeD_M12_N.pt
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-1-9357dea39a86> in <cell line: 5>()
3 repo = "isl-org/ZoeDepth"
4 # Zoe_N
----> 5 model_zoe_n = torch.hub.load(repo, "ZoeD_N", pretrained=True)
6
7 # Zoe_K
/usr/local/lib/python3.10/dist-packages/torch/hub.py in load(repo_or_dir, model, source, trust_repo, force_reload, verbose, skip_validation, *args, **kwargs)
556 verbose=verbose, skip_validation=skip_validation)
557
--> 558 model = _load_local(repo_or_dir, model, *args, **kwargs)
559 return model
560
/usr/local/lib/python3.10/dist-packages/torch/hub.py in _load_local(hubconf_dir, model, *args, **kwargs)
585
586 entry = _load_entry_from_hubconf(hub_module, model)
--> 587 model = entry(*args, **kwargs)
588
589 return model
~/.cache/torch/hub/isl-org_ZoeDepth_main/hubconf.py in ZoeD_N(pretrained, midas_model_type, config_mode, **kwargs)
67
68 config = get_config("zoedepth", config_mode, pretrained_resource=pretrained_resource, **kwargs)
---> 69 model = build_model(config)
70 return model
71
~/.cache/torch/hub/isl-org_ZoeDepth_main/zoedepth/models/builder.py in build_model(config)
49 raise ValueError(
50 f"Model {config.model} has no get_version function.") from e
---> 51 return get_version(config.version_name).build_from_config(config)
~/.cache/torch/hub/isl-org_ZoeDepth_main/zoedepth/models/zoedepth/zoedepth_v1.py in build_from_config(config)
248 @staticmethod
249 def build_from_config(config):
--> 250 return ZoeDepth.build(**config)
~/.cache/torch/hub/isl-org_ZoeDepth_main/zoedepth/models/zoedepth/zoedepth_v1.py in build(midas_model_type, pretrained_resource, use_pretrained_midas, train_midas, freeze_midas_bn, **kwargs)
243 if pretrained_resource:
244 assert isinstance(pretrained_resource, str), "pretrained_resource must be a string"
--> 245 model = load_state_from_resource(model, pretrained_resource)
246 return model
247
~/.cache/torch/hub/isl-org_ZoeDepth_main/zoedepth/models/model_io.py in load_state_from_resource(model, resource)
82 if resource.startswith('url::'):
83 url = resource.split('url::')[1]
---> 84 return load_state_dict_from_url(model, url, progress=True)
85
86 elif resource.startswith('local::'):
~/.cache/torch/hub/isl-org_ZoeDepth_main/zoedepth/models/model_io.py in load_state_dict_from_url(model, url, **kwargs)
59 def load_state_dict_from_url(model, url, **kwargs):
60 state_dict = torch.hub.load_state_dict_from_url(url, map_location='cpu', **kwargs)
---> 61 return load_state_dict(model, state_dict)
62
63
~/.cache/torch/hub/isl-org_ZoeDepth_main/zoedepth/models/model_io.py in load_state_dict(model, state_dict)
47 state[k] = v
48
---> 49 model.load_state_dict(state)
50 print("Loaded successfully")
51 return model
/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in load_state_dict(self, state_dict, strict)
2039
2040 if len(error_msgs) > 0:
-> 2041 raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
2042 self.__class__.__name__, "\n\t".join(error_msgs)))
2043 return _IncompatibleKeys(missing_keys, unexpected_keys)
RuntimeError: Error(s) in loading state_dict for ZoeDepth:
Unexpected key(s) in state_dict: "core.core.pretrained.model.blocks.0.attn.relative_position_index", "core.core.pretrained.model.blocks.1.attn.relative_position_index", "core.core.pretrained.model.blocks.2.attn.relative_position_index", "core.core.pretrained.model.blocks.3.attn.relative_position_index", "core.core.pretrained.model.blocks.4.attn.relative_position_index", "core.core.pretrained.model.blocks.5.attn.relative_position_index", "core.core.pretrained.model.blocks.6.attn.relative_position_index", "core.core.pretrained.model.blocks.7.attn.relative_position_index", "core.core.pretrained.model.blocks.8.attn.relative_position_index", "core.core.pretrained.model.blocks.9.attn.relative_position_index", "core.core.pretrained.model.blocks.10.attn.relative_position_index", "core.core.pretrained.model.blocks.11.attn.relative_position_index", "core.core.pretrained.model.blocks.12.attn.relative_position_index", "core.core.pretrained.model.blocks.13.attn.relative_position_index", "core.core.pretrained.model.blocks.14.attn.relative_position_index", "core.core.pretrained.model.blocks.15.attn.relative_position_index", "core.core.pretrained.model.blocks.16.attn.relative_position_index", "core.core.pretrained.model.blocks.17.attn.relative_position_index", "core.core.pretrained.model.blocks.18.attn.relative_position_index", "core.core.pretrained.model.blocks.19.attn.relative_position_index", "core.core.pretrained.model.blocks.20.attn.relative_position_index", "core.core.pretrained.mo...
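Every unexpected key in the error is a `relative_position_index` entry of a BEiT attention block. A plausible explanation (an assumption, not confirmed by the log) is that the installed `timm` version recomputes these index buffers at construction time and no longer expects them in the state dict, while the checkpoint was saved with an older `timm` that stored them. If you load the checkpoint yourself (for example with `torch.hub.load_state_dict_from_url`, as `model_io.py` does), a minimal workaround sketch is to drop those keys before calling `model.load_state_dict`. The helper name `drop_unexpected_keys` is hypothetical; a plain dict stands in for the real checkpoint so the sketch runs without downloading anything.

```python
from typing import Any, Dict, List, Mapping, Tuple


def drop_unexpected_keys(
    state_dict: Mapping[str, Any], suffix: str = "relative_position_index"
) -> Tuple[Dict[str, Any], List[str]]:
    """Return a copy of state_dict without keys ending in `suffix`, plus the dropped keys."""
    kept: Dict[str, Any] = {}
    dropped: List[str] = []
    for key, value in state_dict.items():
        if key.endswith(suffix):
            dropped.append(key)  # buffer the current model does not expect
        else:
            kept[key] = value    # everything else is passed through unchanged
    return kept, dropped


# Toy stand-in for the downloaded checkpoint (hypothetical values).
ckpt = {
    "core.core.pretrained.model.blocks.0.attn.relative_position_index": 0,
    "core.core.pretrained.model.blocks.0.attn.qkv.weight": 1,
}
clean, removed = drop_unexpected_keys(ckpt)
print(removed)
```

With a real checkpoint you would pass `clean` to `model.load_state_dict(clean)` after whatever unwrapping `model_io.py` already does. `model.load_state_dict(state, strict=False)` is a blunter alternative that ignores all mismatched keys. Either way this only hides the version mismatch; installing the dependency versions the repository was tested with is the cleaner fix.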
How can I fix this?
Update: I found that someone reported the same issue on ZoeDepth's GitHub: https://github.com/isl-org/ZoeDepth/issues/70 (so the problem does not seem to be caused by Google Colab).
Fixing the dependencies solved it for me. The answer is in https://github.com/isl-org/ZoeDepth/issues/70
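Since the fix is a dependency change, it helps to confirm which versions Colab actually has before and after reinstalling. Note that the plain `!pip install timm` above installs whatever release is current, which may well differ from the version ZoeDepth was tested with (an assumption; compare against the versions listed in the repository's environment file). A small sketch using the standard-library `importlib.metadata`; the helper name `installed_version` is hypothetical:

```python
from importlib import metadata
from typing import Optional


def installed_version(package: str) -> Optional[str]:
    """Return the installed version string of `package`, or None if it is not installed."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


# Print the versions relevant to this error.
for pkg in ("torch", "timm"):
    print(pkg, installed_version(pkg))
```

After installing a different version with pip in Colab, restart the runtime before re-running `torch.hub.load`, because modules that were already imported stay loaded in the running Python process.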