无法在 Google Collab 中使用 RAFT 光流

问题描述 投票:0回答:1

第一次访问 StackOverflow。我是一名学生,目前正在尝试使用普林斯顿视觉实验室的筏光流模型开展学校项目,但我遇到了困难。我想使用它并在 google collab 中运行它,但我很难设置 conda 环境来安装并正确加载预先训练的模型(位于他们的 github 存储库中)

这是我第一次做这样的事情,所以我感到非常失落和出乎意料。任何帮助,将不胜感激!我有一些我想要在下面做的事情的代码! `# conda 安装

这是我试图遵循的自述文件和存储库:https://github.com/princeton-vl/RAFT#readme

!wget -c https://repo.continuum.io/archive/Anaconda3-5.1.0-Linux-x86_64.sh !chmod +x Anaconda3-5.1.0-Linux-x86_64.sh !bash ./Anaconda3-5.1.0-Linux-x86_64.sh -b -f -p /usr/local

#conda update to avoid inconsistencies

!conda install anaconda

!conda update --all

!conda update conda -y -q;;

!source /usr/local/etc/profile.d/conda.sh

!conda init bash

!conda install -n root _license -y -q

# raft setup

!conda create --name raft

!conda activate raft

!conda install pytorch=1.6.0 torchvision=0.7.0 cudatoolkit=10.1 matplotlib tensorboard scipy opencv -c pytorch`

谢谢你!

我尝试通过谷歌搜索找到在 Google Collab 上使用 conda 安装的其他模型,并尝试使用他们的代码,但用 RAFT 而不是 FlowNet 或 GMA 替换光流模型,但我遇到了困难。我也尝试正常地安装 cuda 并从那里开始,但是 conda 安装和 inconsistnet 环境出现 hac 问题(这就是为什么我在那里也有更新 conda 行)

github conda google-colaboratory opticalflow
1个回答
0
投票

以下是我在Google Colab中使用RAFT的方法,必须启用CUDA。您可以查看这个笔记本了解详细信息,但这里是基本步骤。

  1. 克隆存储库:

    !git clone https://github.com/princeton-vl/RAFT.git

  2. 将RAFT添加到核心路径:

    sys.path.append('RAFT/core')

  3. 下载 RAFT 模型:

    %cd RAFT
    !./download_models.sh
    %cd ..

  4. 添加与 RAFT 接口的方法,要在命令行之外使用 RAFT 代码,需要对其进行修改,或者我们可以创建一个特殊的类与其接口,如下所示”

class Args():
  def __init__(self, model='', path='', small=False, mixed_precision=True, alternate_corr=False):
    self.model = model
    self.path = path
    self.small = small
    self.mixed_precision = mixed_precision
    self.alternate_corr = alternate_corr

  """ Sketchy hack to pretend to iterate through the class objects """
  def __iter__(self):
    return self

  def __next__(self):
    raise StopIteration
  1. 加载模型:
def load_model(weights_path, args):
    model = RAFT(args)
    pretrained_weights = torch.load(weights_path, map_location=torch.device("cpu"))
    model = torch.nn.DataParallel(model)
    model.load_state_dict(pretrained_weights)
    model.to("cuda")
    return model

model = load_model("RAFT/models/raft-sintel.pth", args=Args())

现在您可以在一对图像帧上运行推理:

flow_iters = model(frame1, frame2, iters=12, test_mode=False)
© www.soinside.com 2019 - 2024. All rights reserved.