从多种格式的RL脚本中AST提取参数

问题描述 投票:0回答:1

我有多种 RL 算法的实现,我试图从中提取参数、它们的数据类型和值。

但是,这些实现因脚本而异,有时参数定义如下:

parser.add_argument("--env-id", type=str, default="pong_v3",
        help="the id of the environment") \\

    parser.add_argument("--total-timesteps", type=int, default=20000000,
        help="total timesteps of the experiments") 

    parser.add_argument("--learning-rate", type=float, default=2.5e-4,
        help="the learning rate of the optimizer")

还有这样的:

class MAA2C(Agent):
def __init__(self, env, n_agents, state_dim, action_dim,
                 memory_capacity=10000, max_steps=None,
                 roll_out_n_steps=10,
                 reward_gamma=0.99, reward_scale=1., done_penalty=None,
                 actor_hidden_size=32, critic_hidden_size=32,
                 actor_output_act=nn.functional.log_softmax,
                 use_cuda=True, training_strategy="cocurrent",
                 actor_parameter_sharing=False, critic_parameter_sharing=False):

我对如何集成所有此类格式的提取感到有点迷失,现在我在文本文件中定义参数,并且提取基于当前格式:

env_id: str = "CartPole-v1"
    total_timesteps: int = 500000
    learning_rate: float = 2.5e-4
    num_envs: int = 4
    num_steps: int = 128
    anneal_lr: bool = True

这是我目前的代码:

import ast
import sys
from tabulate import tabulate

def extract_parameters_with_values_from_file(file_path: str) -> dict:
    with open(file_path, 'r') as file:
        source_code = file.read()

    parameter_values = {}

    tree = ast.parse(source_code)
    for node in ast.walk(tree):
        if isinstance(node, ast.AnnAssign):
            parameter_name = node.target.id
            if isinstance(node.annotation, ast.Name):
                data_type = node.annotation.id
            else:
                data_type = str(ast.dump(node.annotation))
            if isinstance(node.value, ast.Call) and isinstance(node.value.func, ast.Name) and node.value.func.id == 'int':
                parameter_value = node.value.args[0].n
            elif isinstance(node.value, ast.Constant):
                parameter_value = node.value.value
            elif isinstance(node.value, ast.NameConstant):
                parameter_value = node.value.value
            else:
                try:
                    parameter_value = ast.literal_eval(node.value)
                except (ValueError, SyntaxError):
                    parameter_value = ast.dump(node.value)
            parameter_values[parameter_name] = (data_type, parameter_value)

    return parameter_values

def read_parameters_from_txt(file_path: str) -> list:
    with open(file_path, 'r') as file:
        parameters = file.read().splitlines()
    return parameters

def extract_parameters_with_values(parameter_names: list, python_file_path: str) -> list:
    parameter_values = extract_parameters_with_values_from_file(python_file_path)
    extracted_parameters = []

    for parameter_name in parameter_names:
        if parameter_name in parameter_values:
            data_type, value = parameter_values[parameter_name]
            extracted_parameters.append([parameter_name, data_type, value])

    return extracted_parameters

if __name__ == "__main__":
    if len(sys.argv) != 3:
        sys.exit(1)

    parameter_txt_path = sys.argv[1]
    python_file_path = sys.argv[2]

    parameter_names = read_parameters_from_txt(parameter_txt_path)
    extracted_parameters = extract_parameters_with_values(parameter_names, python_file_path)

    if not extracted_parameters:
        print("No parameters found in the source code")
    else:
        print(tabulate(extracted_parameters, headers=['Parameter', 'Data Type', 'Value'], tablefmt="github"))

这些是我需要包含其格式的文件示例:Code1Code2

我的

params.txt
文件如下所示:

device
policy_noise
ent_coef
vf_coef
clip_coef
gamma
batch_size
stack_size
frame_size
max_cycles
total_episodes
env_id
learning_rate
total_timesteps
buffer_size
nums_envs

我对此有点困惑,任何建议或想法将不胜感激。

python extract abstract-syntax-tree reinforcement-learning
1个回答
0
投票

请看一下

Adapter
设计模式。基本上,您有多个检索数据的入口点。每个入口点都有不同形式的数据。所以每个数据点的数据首先应该转换为一致的结构(预处理)。从那里您可以将其转换为所需的输出。

重构Guru适配器设计模式

为来自终端的用户输入创建一个适配器,并为类参数创建一个适配器。您可以从类构造函数派生签名。请检查下面的链接以获取更多信息。

Stackoverflow 参考

我建议不要使用文本文件来读取参数。您可能会使其工作,但如果在更新库时在代码中调整参数,您的代码将会中断,并且您必须每次都手动更新文本文件。这是不可持续的。尝试动态(自动)获取参数。

祝你好运!

© www.soinside.com 2019 - 2024. All rights reserved.