有没有一种好方法可以对类进行严格的类型检查并以某种方式批量覆盖它的字段?我需要一个定义良好且类型明确的类来保存一些信息,以及一种指定该信息的覆盖列表的方法。 第一个想法当然只是一个字典列表,但是这不允许轻松进行类型检查,而且设置属性也很麻烦。
import copy
from typing import TypedDict
from dataclasses import dataclass
class SubParameters(TypedDict):
par1: str
par2: list[int]
@dataclass
class Parameters:
subs: SubParameters
others: list[float]
if __name__=="__main__":
base = Parameters(subs=SubParameters(par1="placeholder", par2=[1, 2, 3]), others=[3.14, 10e-6])
overwrites = [{"others": []}, {"subs": SubParameters(par1="placeholder", par2=[4, 5, 6])}]
param_list = []
for change in overwrites:
p = copy.copy(base)
for k, v in change.items():
p.__setattr__(k, v)
param_list.append(p)
print(param_list)
这可行,但没有对覆盖进行类型检查,非常麻烦,并且很难覆盖更复杂的子参数,这些子参数必须完全重新定义。有更好的办法吗?
为了增强类型安全性和批量覆盖类字段的便利性,您可以利用带有类型注释的数据类并实现用于应用类型化覆盖的通用实用程序:
import copy
from dataclasses import dataclass, field
from typing import TypedDict, List, Dict, TypeVar, Generic
# Define your subparameters as before
class SubParameters(TypedDict):
par1: str
par2: List[int]
# Use a dataclass for your main parameters
@dataclass
class Parameters:
subs: SubParameters
others: List[float]
T = TypeVar('T')
# A generic class to handle typed updates
class TypedUpdate(Generic[T]):
def __init__(self, initial_data: T):
self.data = initial_data
def apply_updates(self, updates: Dict[str, T]) -> T:
for key, value in updates.items():
if isinstance(value, dict) and isinstance(getattr(self.data, key), dict):
# For nested dictionaries, we can recursively apply updates
current_value = getattr(self.data, key)
updated_value = TypedUpdate(current_value).apply_updates(value)
setattr(self.data, key, updated_value)
else:
setattr(self.data, key, value)
return self.data
# Utility function to copy and update parameters
def update_parameters(base: Parameters, changes: List[Dict[str, Dict[str, any]]]) -> List[Parameters]:
param_list = []
for change in changes:
p = copy.deepcopy(base) # Use deepcopy to avoid mutating nested structures
typed_update = TypedUpdate(p)
for k, v in change.items():
typed_update.apply_updates({k: v})
param_list.append(p)
return param_list
if __name__ == "__main__":
base = Parameters(subs={"par1": "placeholder", "par2": [1, 2, 3]}, others=[3.14, 10e-6])
overwrites = [{"others": []}, {"subs": {"par1": "new_placeholder", "par2": [4, 5, 6]}}]
param_list = update_parameters(base, overwrites)
print(param_list)