我有多种 RL 算法的实现,我试图从中提取参数、它们的数据类型和值。
但是,这些实现因脚本而异,有时参数定义如下:
parser.add_argument("--env-id", type=str, default="pong_v3",
help="the id of the environment") \\
parser.add_argument("--total-timesteps", type=int, default=20000000,
help="total timesteps of the experiments")
parser.add_argument("--learning-rate", type=float, default=2.5e-4,
help="the learning rate of the optimizer")
还有这样的:
class MAA2C(Agent):
def __init__(self, env, n_agents, state_dim, action_dim,
memory_capacity=10000, max_steps=None,
roll_out_n_steps=10,
reward_gamma=0.99, reward_scale=1., done_penalty=None,
actor_hidden_size=32, critic_hidden_size=32,
actor_output_act=nn.functional.log_softmax,
use_cuda=True, training_strategy="cocurrent",
actor_parameter_sharing=False, critic_parameter_sharing=False):
我对如何集成所有此类格式的提取感到有点迷失,现在我在文本文件中定义参数,并且提取基于当前格式:
env_id: str = "CartPole-v1"
total_timesteps: int = 500000
learning_rate: float = 2.5e-4
num_envs: int = 4
num_steps: int = 128
anneal_lr: bool = True
这是我目前的代码:
import ast
import sys
from tabulate import tabulate
def extract_parameters_with_values_from_file(file_path: str) -> dict:
with open(file_path, 'r') as file:
source_code = file.read()
parameter_values = {}
tree = ast.parse(source_code)
for node in ast.walk(tree):
if isinstance(node, ast.AnnAssign):
parameter_name = node.target.id
if isinstance(node.annotation, ast.Name):
data_type = node.annotation.id
else:
data_type = str(ast.dump(node.annotation))
if isinstance(node.value, ast.Call) and isinstance(node.value.func, ast.Name) and node.value.func.id == 'int':
parameter_value = node.value.args[0].n
elif isinstance(node.value, ast.Constant):
parameter_value = node.value.value
elif isinstance(node.value, ast.NameConstant):
parameter_value = node.value.value
else:
try:
parameter_value = ast.literal_eval(node.value)
except (ValueError, SyntaxError):
parameter_value = ast.dump(node.value)
parameter_values[parameter_name] = (data_type, parameter_value)
return parameter_values
def read_parameters_from_txt(file_path: str) -> list:
with open(file_path, 'r') as file:
parameters = file.read().splitlines()
return parameters
def extract_parameters_with_values(parameter_names: list, python_file_path: str) -> list:
parameter_values = extract_parameters_with_values_from_file(python_file_path)
extracted_parameters = []
for parameter_name in parameter_names:
if parameter_name in parameter_values:
data_type, value = parameter_values[parameter_name]
extracted_parameters.append([parameter_name, data_type, value])
return extracted_parameters
if __name__ == "__main__":
if len(sys.argv) != 3:
sys.exit(1)
parameter_txt_path = sys.argv[1]
python_file_path = sys.argv[2]
parameter_names = read_parameters_from_txt(parameter_txt_path)
extracted_parameters = extract_parameters_with_values(parameter_names, python_file_path)
if not extracted_parameters:
print("No parameters found in the source code")
else:
print(tabulate(extracted_parameters, headers=['Parameter', 'Data Type', 'Value'], tablefmt="github"))
我的
params.txt
文件如下所示:
device
policy_noise
ent_coef
vf_coef
clip_coef
gamma
batch_size
stack_size
frame_size
max_cycles
total_episodes
env_id
learning_rate
total_timesteps
buffer_size
nums_envs
我对此有点困惑,任何建议或想法将不胜感激。
请看一下
Adapter
设计模式。基本上,您有多个检索数据的入口点。每个入口点都有不同形式的数据。所以每个数据点的数据首先应该转换为一致的结构(预处理)。从那里您可以将其转换为所需的输出。
为来自终端的用户输入创建一个适配器,并为类参数创建一个适配器。您可以从类构造函数派生签名。请检查下面的链接以获取更多信息。
我建议不要使用文本文件来读取参数。您可能会使其工作,但如果在更新库时在代码中调整参数,您的代码将会中断,并且您必须每次都手动更新文本文件。这是不可持续的。尝试动态(自动)获取参数。
祝你好运!