第一次访问 StackOverflow。我是一名学生,目前正在尝试使用普林斯顿视觉实验室的筏光流模型开展学校项目,但我遇到了困难。我想使用它并在 google collab 中运行它,但我很难设置 conda 环境来安装并正确加载预先训练的模型(位于他们的 github 存储库中)
这是我第一次做这样的事情,所以我感到非常失落和出乎意料。任何帮助,将不胜感激!我有一些我想要在下面做的事情的代码! `# conda 安装
这是我试图遵循的自述文件和存储库:https://github.com/princeton-vl/RAFT#readme
!wget -c https://repo.continuum.io/archive/Anaconda3-5.1.0-Linux-x86_64.sh !chmod +x Anaconda3-5.1.0-Linux-x86_64.sh !bash ./Anaconda3-5.1.0-Linux-x86_64.sh -b -f -p /usr/local
#conda update to avoid inconsistencies
!conda install anaconda
!conda update --all
!conda update conda -y -q;;
!source /usr/local/etc/profile.d/conda.sh
!conda init bash
!conda install -n root _license -y -q
# raft setup
!conda create --name raft
!conda activate raft
!conda install pytorch=1.6.0 torchvision=0.7.0 cudatoolkit=10.1 matplotlib tensorboard scipy opencv -c pytorch`
谢谢你!
我尝试通过谷歌搜索找到在 Google Collab 上使用 conda 安装的其他模型,并尝试使用他们的代码,但用 RAFT 而不是 FlowNet 或 GMA 替换光流模型,但我遇到了困难。我也尝试正常地安装 cuda 并从那里开始,但是 conda 安装和 inconsistnet 环境出现 hac 问题(这就是为什么我在那里也有更新 conda 行)
以下是我在Google Colab中使用RAFT的方法,必须启用CUDA。您可以查看这个笔记本了解详细信息,但这里是基本步骤。
克隆存储库:
!git clone https://github.com/princeton-vl/RAFT.git
将RAFT添加到核心路径:
sys.path.append('RAFT/core')
下载 RAFT 模型:
%cd RAFT
!./download_models.sh
%cd ..
添加与 RAFT 接口的方法,要在命令行之外使用 RAFT 代码,需要对其进行修改,或者我们可以创建一个特殊的类与其接口,如下所示”
class Args():
def __init__(self, model='', path='', small=False, mixed_precision=True, alternate_corr=False):
self.model = model
self.path = path
self.small = small
self.mixed_precision = mixed_precision
self.alternate_corr = alternate_corr
""" Sketchy hack to pretend to iterate through the class objects """
def __iter__(self):
return self
def __next__(self):
raise StopIteration
def load_model(weights_path, args):
model = RAFT(args)
pretrained_weights = torch.load(weights_path, map_location=torch.device("cpu"))
model = torch.nn.DataParallel(model)
model.load_state_dict(pretrained_weights)
model.to("cuda")
return model
model = load_model("RAFT/models/raft-sintel.pth", args=Args())
现在您可以在一对图像帧上运行推理:
flow_iters = model(frame1, frame2, iters=12, test_mode=False)