Installing ZoeDepth in a Google Colab Notebook


I am currently trying to run ZoeDepth in Google Colab, following the instructions at https://github.com/isl-org/ZoeDepth. These are the lines I have run successfully:

!pip install torch
!pip install timm
import torch
torch.hub.help("intel-isl/MiDaS", "DPT_BEiT_L_384", force_reload=True)  # Triggers fresh download of MiDaS repo
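Because `!pip install timm` installs whatever release is current, the versions you end up with may differ from the ones the ZoeDepth code was written against. A minimal sketch for recording what was actually installed, using only the standard library's `importlib.metadata`; the helper name `installed_version` is hypothetical and not part of ZoeDepth or PyTorch:

```python
from importlib import metadata
from typing import Optional


def installed_version(package: str) -> Optional[str]:
    """Return the installed version of `package`, or None if it is not installed.

    Hypothetical helper for illustration: it only reads package metadata
    and does not import the package itself.
    """
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


if __name__ == "__main__":
    # Print the versions that the pip cells above actually installed.
    for name in ("torch", "timm"):
        print(f"{name}: {installed_version(name) or 'not installed'}")
```

Pasting this output into a bug report makes it much easier for others to tell whether a version mismatch is involved.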

The next step is to load a pretrained model from the GitHub repository:

import torch

repo = "isl-org/ZoeDepth"
# Zoe_N
model_zoe_n = torch.hub.load(repo, "ZoeD_N", pretrained=True)

However, this cell fails. Here is the full log printed below it:

Using cache found in /root/.cache/torch/hub/isl-org_ZoeDepth_main

img_size [384, 512]

Using cache found in /root/.cache/torch/hub/intel-isl_MiDaS_master
/usr/local/lib/python3.10/dist-packages/torch/functional.py:504: UserWarning: torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. (Triggered internally at ../aten/src/ATen/native/TensorShape.cpp:3483.)
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]

Params passed to Resize transform:
    width:  512
    height:  384
    resize_target:  True
    keep_aspect_ratio:  True
    ensure_multiple_of:  32
    resize_method:  minimal
Using pretrained resource url::https://github.com/isl-org/ZoeDepth/releases/download/v1.0/ZoeD_M12_N.pt

---------------------------------------------------------------------------

RuntimeError                              Traceback (most recent call last)

<ipython-input-1-9357dea39a86> in <cell line: 5>()
      3 repo = "isl-org/ZoeDepth"
      4 # Zoe_N
----> 5 model_zoe_n = torch.hub.load(repo, "ZoeD_N", pretrained=True)
      6 
      7 # Zoe_K

/usr/local/lib/python3.10/dist-packages/torch/hub.py in load(repo_or_dir, model, source, trust_repo, force_reload, verbose, skip_validation, *args, **kwargs)
    556                                            verbose=verbose, skip_validation=skip_validation)
    557 
--> 558     model = _load_local(repo_or_dir, model, *args, **kwargs)
    559     return model
    560 

/usr/local/lib/python3.10/dist-packages/torch/hub.py in _load_local(hubconf_dir, model, *args, **kwargs)
    585 
    586         entry = _load_entry_from_hubconf(hub_module, model)
--> 587         model = entry(*args, **kwargs)
    588 
    589     return model

~/.cache/torch/hub/isl-org_ZoeDepth_main/hubconf.py in ZoeD_N(pretrained, midas_model_type, config_mode, **kwargs)
     67 
     68     config = get_config("zoedepth", config_mode, pretrained_resource=pretrained_resource, **kwargs)
---> 69     model = build_model(config)
     70     return model
     71 

~/.cache/torch/hub/isl-org_ZoeDepth_main/zoedepth/models/builder.py in build_model(config)
     49         raise ValueError(
     50             f"Model {config.model} has no get_version function.") from e
---> 51     return get_version(config.version_name).build_from_config(config)

~/.cache/torch/hub/isl-org_ZoeDepth_main/zoedepth/models/zoedepth/zoedepth_v1.py in build_from_config(config)
    248     @staticmethod
    249     def build_from_config(config):
--> 250         return ZoeDepth.build(**config)

~/.cache/torch/hub/isl-org_ZoeDepth_main/zoedepth/models/zoedepth/zoedepth_v1.py in build(midas_model_type, pretrained_resource, use_pretrained_midas, train_midas, freeze_midas_bn, **kwargs)
    243         if pretrained_resource:
    244             assert isinstance(pretrained_resource, str), "pretrained_resource must be a string"
--> 245             model = load_state_from_resource(model, pretrained_resource)
    246         return model
    247 

~/.cache/torch/hub/isl-org_ZoeDepth_main/zoedepth/models/model_io.py in load_state_from_resource(model, resource)
     82     if resource.startswith('url::'):
     83         url = resource.split('url::')[1]
---> 84         return load_state_dict_from_url(model, url, progress=True)
     85 
     86     elif resource.startswith('local::'):

~/.cache/torch/hub/isl-org_ZoeDepth_main/zoedepth/models/model_io.py in load_state_dict_from_url(model, url, **kwargs)
     59 def load_state_dict_from_url(model, url, **kwargs):
     60     state_dict = torch.hub.load_state_dict_from_url(url, map_location='cpu', **kwargs)
---> 61     return load_state_dict(model, state_dict)
     62 
     63 

~/.cache/torch/hub/isl-org_ZoeDepth_main/zoedepth/models/model_io.py in load_state_dict(model, state_dict)
     47         state[k] = v
     48 
---> 49     model.load_state_dict(state)
     50     print("Loaded successfully")
     51     return model

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in load_state_dict(self, state_dict, strict)
   2039 
   2040         if len(error_msgs) > 0:
-> 2041             raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
   2042                                self.__class__.__name__, "\n\t".join(error_msgs)))
   2043         return _IncompatibleKeys(missing_keys, unexpected_keys)

RuntimeError: Error(s) in loading state_dict for ZoeDepth:
    Unexpected key(s) in state_dict: "core.core.pretrained.model.blocks.0.attn.relative_position_index", "core.core.pretrained.model.blocks.1.attn.relative_position_index", "core.core.pretrained.model.blocks.2.attn.relative_position_index", "core.core.pretrained.model.blocks.3.attn.relative_position_index", "core.core.pretrained.model.blocks.4.attn.relative_position_index", "core.core.pretrained.model.blocks.5.attn.relative_position_index", "core.core.pretrained.model.blocks.6.attn.relative_position_index", "core.core.pretrained.model.blocks.7.attn.relative_position_index", "core.core.pretrained.model.blocks.8.attn.relative_position_index", "core.core.pretrained.model.blocks.9.attn.relative_position_index", "core.core.pretrained.model.blocks.10.attn.relative_position_index", "core.core.pretrained.model.blocks.11.attn.relative_position_index", "core.core.pretrained.model.blocks.12.attn.relative_position_index", "core.core.pretrained.model.blocks.13.attn.relative_position_index", "core.core.pretrained.model.blocks.14.attn.relative_position_index", "core.core.pretrained.model.blocks.15.attn.relative_position_index", "core.core.pretrained.model.blocks.16.attn.relative_position_index", "core.core.pretrained.model.blocks.17.attn.relative_position_index", "core.core.pretrained.model.blocks.18.attn.relative_position_index", "core.core.pretrained.model.blocks.19.attn.relative_position_index", "core.core.pretrained.model.blocks.20.attn.relative_position_index", "core.core.pretrained.mo...
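The error comes from `load_state_dict`, which by default runs in strict mode: the keys in the checkpoint must match the keys the model expects, and keys present only in the checkpoint are reported as "Unexpected key(s)". One plausible reading (an assumption here, not confirmed in the thread) is that the installed `timm` version no longer includes `relative_position_index` in the BEiT backbone's state dict, while the released `ZoeD_M12_N.pt` checkpoint still contains it. The sketch below only mimics that key-set comparison on plain dictionaries with made-up key names modeled on the log; it is not PyTorch's actual implementation.

```python
from typing import Dict, List, Tuple


def compare_keys(model_keys: List[str], checkpoint: Dict[str, object]) -> Tuple[List[str], List[str]]:
    """Return (missing, unexpected) keys, mirroring the idea behind strict loading.

    missing:    keys the model expects but the checkpoint lacks
    unexpected: keys in the checkpoint that the model does not expect
    """
    expected = set(model_keys)
    provided = set(checkpoint)
    missing = sorted(expected - provided)
    unexpected = sorted(provided - expected)
    return missing, unexpected


def strict_check(model_keys: List[str], checkpoint: Dict[str, object]) -> None:
    """Raise RuntimeError if the key sets differ, as strict loading would."""
    missing, unexpected = compare_keys(model_keys, checkpoint)
    errors = []
    if unexpected:
        errors.append("Unexpected key(s) in state_dict: " + ", ".join(unexpected))
    if missing:
        errors.append("Missing key(s) in state_dict: " + ", ".join(missing))
    if errors:
        raise RuntimeError("\n\t".join(errors))


if __name__ == "__main__":
    # Illustrative keys only: the model side does not list the index buffer,
    # while the (older) checkpoint still contains it.
    model_keys = ["blocks.0.attn.qkv.weight"]
    checkpoint = {
        "blocks.0.attn.qkv.weight": 0.0,
        "blocks.0.attn.relative_position_index": 0,
    }
    try:
        strict_check(model_keys, checkpoint)
    except RuntimeError as err:
        print(err)
```

Note that only unexpected keys are reported in the log, which fits a checkpoint that carries extra entries rather than one that lacks weights.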

How can I fix this?

Update: I found that someone reported the same issue on ZoeDepth's GitHub: https://github.com/isl-org/ZoeDepth/issues/70 (so the problem does not appear to be caused by my using Google Colab).

python machine-learning computer-vision torch
1 Answer

It worked for me after fixing the dependencies. The solution is in https://github.com/isl-org/ZoeDepth/issues/70
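The answer does not say which dependency was fixed. A minimal sketch, assuming the culprit is `timm` (the package that builds the BEiT backbone whose keys appear in the error) and assuming a pinned version of `0.6.7`; both are assumptions, so check the linked issue for the version that actually worked. The helper names `numeric_version` and `timm_matches_pin` are hypothetical. The check lets a notebook stop early with a clear message instead of failing inside `torch.hub.load`. In Colab the pin itself would be installed with a cell such as `!pip install timm==0.6.7`, followed by a runtime restart so that the newly installed version is the one imported.

```python
import re
from importlib import metadata
from typing import Optional, Tuple

# Assumption: the timm version reported to work; confirm it in ZoeDepth issue #70.
PINNED_TIMM = "0.6.7"


def numeric_version(version: str) -> Tuple[int, ...]:
    """Turn '0.9.16' into (0, 9, 16).

    Only the leading numeric part is kept, so suffixes such as '.dev0' or 'rc1'
    are ignored (a simplification that is good enough for an equality check).
    """
    match = re.match(r"\d+(?:\.\d+)*", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    return tuple(int(part) for part in match.group(0).split("."))


def timm_matches_pin(installed: Optional[str], pinned: str = PINNED_TIMM) -> bool:
    """True if the installed timm version numerically equals the pinned one."""
    if installed is None:
        return False
    return numeric_version(installed) == numeric_version(pinned)


if __name__ == "__main__":
    try:
        current = metadata.version("timm")
    except metadata.PackageNotFoundError:
        current = None
    if not timm_matches_pin(current):
        print(f"timm is {current}, expected {PINNED_TIMM}; "
              f"run '!pip install timm=={PINNED_TIMM}' and restart the runtime.")
```

Pinning a version, rather than patching the checkpoint or the ZoeDepth code, keeps the cached hub repository untouched, which matches the "fix the dependencies" approach the answer describes.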
