如何使用 jax/flax 恢复 orbax 检查点?

问题描述 投票:0回答:1

我用下面的代码保存了一个 orbax 检查点:

check_options = ocp.CheckpointManagerOptions(max_to_keep=5, create=True)
check_path = Path(os.getcwd(), out_dir, 'checkpoint')
checkpoint_manager = ocp.CheckpointManager(check_path, options=check_options, item_names=('state', 'metadata'))
checkpoint_manager.save(
                    step=iter_num,
                    args=ocp.args.Composite(
                        state=ocp.args.StandardSave(state),
                        metadata=ocp.args.JsonSave((model_args, iter_num, best_val_loss, losses['val'].item(), config))))

当我尝试从保存的检查点恢复时,我使用下面的代码来恢复

state
变量:

state, lr_schedule = init_train_state(model, params['params'], learning_rate, weight_decay, beta1, beta2, decay_lr, warmup_iters, 
                     lr_decay_iters, min_lr)  # Here state is the initialied state variable with type Train_state.
state = checkpoint_manager.restore(checkpoint_manager.latest_step(), items={'state': state})

但是当我尝试在训练循环中使用恢复的状态时,我得到了这个错误:

---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
File /opt/conda/envs/py_3.10/lib/python3.10/site-packages/jax/_src/api_util.py:584, in shaped_abstractify(x)
    583 try:
--> 584   return _shaped_abstractify_handlers[type(x)](x)
    585 except KeyError:

KeyError: <class 'orbax.checkpoint.composite_checkpoint_handler.CompositeArgs'>

During handling of the above exception, another exception occurred:

TypeError                                 Traceback (most recent call last)
Cell In[40], line 37
     34 if iter_num == 0 and eval_only:
     35     break
---> 37 state, loss = train_step(state, get_batch('train'))
     39 # timing and logging
     40 t1 = time.time()

    [... skipping hidden 6 frame]

File /opt/conda/envs/py_3.10/lib/python3.10/site-packages/jax/_src/api_util.py:575, in _shaped_abstractify_slow(x)
    573   dtype = dtypes.canonicalize_dtype(x.dtype, allow_extended_dtype=True)
    574 else:
--> 575   raise TypeError(
    576       f"Cannot interpret value of type {type(x)} as an abstract array; it "
    577       "does not have a dtype attribute")
    578 return core.ShapedArray(np.shape(x), dtype, weak_type=weak_type,
    579                         named_shape=named_shape)

TypeError: Cannot interpret value of type <class 'orbax.checkpoint.composite_checkpoint_handler.CompositeArgs'> as an abstract array; it does not have a dtype attribute

那么,我应该如何正确恢复

state
检查点并在训练循环中使用它?

谢谢!

python deep-learning jax flax
1个回答
0
投票

您正在以不允许的方式混合新旧 API。抱歉,没有提出与此相关的错误,我可以对此进行调查。

您的保存是正确的,但我建议它看起来更像以下内容:

with ocp.CheckpointManager(path, options=options, item_names=('state', 'metadata')) as mngr:
  mngr.save(
      step, 
      args=ocp.args.Composite(
          state=ocp.args.StandardSave(state),
          metadata=ocp.args.JsonSave(...),
      )
  )

恢复时,您当前使用的是旧API中的

items
,其用法与
CheckpointManager
的定义不一致,这是基于新API完成的。

item_names
args
是新 API 的标志。

你应该这样做:

with ocp.CheckpointManager(...) as mngr:
  mngr.restore(
      mngr.latest_step(), 
      args=ocp.args.Composite(
          state=ocp.args.StandardSave(abstract_state),
      )
  )

如果有任何意外问题,请告诉我。

© www.soinside.com 2019 - 2024. All rights reserved.