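I am writing a wrapper for pytorch-transformers. For simplicity, I'll give a minimal example: a class Parent, which will be the abstract base class for My_bert(Parent) and My_gpt2(Parent). Because pytorch-transformers provides LM models for both BERT and GPT-2 (model_bert and model_gpt2 below), the two wrappers share many functions, so I want to minimize code duplication by implementing the otherwise identical functions once in Parent.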
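My_bert and My_gpt2 differ essentially only in how the model is initialized and in one argument passed to the model; 99% of the functions use both models in the same way. The problem is the "model" call, because the two models accept different keyword arguments: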
Minimal code example:

    class Parent():
        """ My own class that is an abstract class for My_bert and My_gpt2 """
        def __init__(self):
            pass

        def fancy_arithmetic(self, text):
            print("do_fancy_stuff_that_works_identically_for_both_models(text=text)")

        def compute_model(self, text):
            return self.model(input_ids=text, masked_lm_labels=text)  # this line works for My_bert
            # return self.model(input_ids=text, labels=text)          # I'd need this line for My_gpt2


    class My_bert(Parent):
        """ My own My_bert class that is initialized with the BERT pytorch
        model (here model_bert), and uses methods from Parent """
        def __init__(self):
            self.model = model_bert()


    class My_gpt2(Parent):
        """ My own My_gpt2 class that is initialized with the gpt2 pytorch
        model (here model_gpt2), and uses methods from Parent """
        def __init__(self):
            self.model = model_gpt2()


    class model_gpt2:
        """ This class mocks the pytorch-transformers gpt2 model; it is just
        enough code to let you run this example """
        def __init__(self):
            pass

        def __call__(self, *input, **kwargs):
            return self.model(*input, **kwargs)

        def model(self, input_ids, labels):
            print("gpt2")


    class model_bert:
        """ This class mocks the pytorch-transformers bert model """
        def __init__(self):
            pass

        def __call__(self, *input, **kwargs):
            return self.model(*input, **kwargs)

        def model(self, input_ids, masked_lm_labels):
            print("bert")


    foo = My_bert()
    foo.compute_model("bar")     # this works
    bar = My_gpt2()
    # bar.compute_model("rawr")  # this does not work: TypeError, unexpected keyword argument 'masked_lm_labels'
I know I could override Parent::compute_model in the My_bert and My_gpt2 classes.
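If overriding is acceptable, the override can be kept very small. A minimal sketch, assuming the only difference really is the name of the label keyword: each subclass sets a class attribute (the name label_kwarg is hypothetical) and Parent builds the call from it. The mocks here mirror the ones in the question but return a string instead of printing, so the result can be checked.

```python
class model_bert:
    """Mock of the BERT LM model (same signature as in the question)."""
    def __call__(self, input_ids, masked_lm_labels):
        return "bert"


class model_gpt2:
    """Mock of the GPT-2 LM model (same signature as in the question)."""
    def __call__(self, input_ids, labels):
        return "gpt2"


class Parent:
    # Hypothetical hook: subclasses name the keyword their model uses for labels.
    label_kwarg = None

    def compute_model(self, text):
        # Pass the labels under whichever keyword the subclass declared.
        return self.model(input_ids=text, **{self.label_kwarg: text})


class My_bert(Parent):
    label_kwarg = "masked_lm_labels"

    def __init__(self):
        self.model = model_bert()


class My_gpt2(Parent):
    label_kwarg = "labels"

    def __init__(self):
        self.model = model_gpt2()


print(My_bert().compute_model("bar"))   # prints: bert
print(My_gpt2().compute_model("rawr"))  # prints: gpt2
```

The shared logic stays in Parent; each subclass only states the one thing that differs.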
But since the two "model" methods are so similar, I wonder whether there is a way to say: "I'll pass you three arguments; use the ones you know":

    def compute_model(self, text):
        return self.model(input_ids=text, masked_lm_labels=text, labels=text)  # ignore the arguments you don't know
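A plain Python call never drops keyword arguments the callee does not declare, but the caller can filter them first. A minimal sketch using the standard library's inspect.signature; call_with_known_kwargs is a hypothetical helper name, and the mocks return strings instead of printing. With a real torch.nn.Module, __call__ is a generic wrapper, so it is probably safer to inspect model.forward instead (an assumption worth checking for your version).

```python
import inspect


class model_bert:
    """Mock BERT model: only knows `masked_lm_labels`."""
    def __call__(self, input_ids, masked_lm_labels):
        return "bert"


class model_gpt2:
    """Mock GPT-2 model: only knows `labels`."""
    def __call__(self, input_ids, labels):
        return "gpt2"


def call_with_known_kwargs(fn, /, **kwargs):
    """Call fn with only the keyword arguments its signature declares (hypothetical helper)."""
    params = inspect.signature(fn).parameters
    # If fn itself accepts **kwargs, it can take everything: pass all through.
    if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()):
        return fn(**kwargs)
    accepted = {name: value for name, value in kwargs.items() if name in params}
    return fn(**accepted)


class Parent:
    def compute_model(self, text):
        # Offer every label keyword; each model receives only the ones it declares.
        return call_with_known_kwargs(
            self.model, input_ids=text, masked_lm_labels=text, labels=text
        )


class My_bert(Parent):
    def __init__(self):
        self.model = model_bert()


class My_gpt2(Parent):
    def __init__(self):
        self.model = model_gpt2()


print(My_bert().compute_model("bar"))   # prints: bert
print(My_gpt2().compute_model("rawr"))  # prints: gpt2
```

The trade-off is that a misspelled keyword is now silently discarded instead of raising a TypeError.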
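Using *args and **kwargs should solve the problem you are running into: if each model's method also accepts **kwargs, keyword arguments it does not recognize are collected there instead of raising a TypeError.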
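A minimal sketch of the *args/**kwargs idea applied to the question's mock models: each mock's model method gains a **ignored catch-all, so Parent can pass all three keywords and each model uses the ones it knows. This assumes you can edit the model signatures, which is true for the mocks here but may not be for library model classes whose forward methods you do not control. The mocks return strings instead of printing.

```python
class model_bert:
    """Mock BERT model that tolerates unknown keyword arguments."""
    def __call__(self, *args, **kwargs):
        return self.model(*args, **kwargs)

    def model(self, input_ids, masked_lm_labels, **ignored):
        # `ignored` absorbs keywords meant for other models (e.g. `labels`).
        return "bert"


class model_gpt2:
    """Mock GPT-2 model that tolerates unknown keyword arguments."""
    def __call__(self, *args, **kwargs):
        return self.model(*args, **kwargs)

    def model(self, input_ids, labels, **ignored):
        # `ignored` absorbs keywords meant for other models (e.g. `masked_lm_labels`).
        return "gpt2"


class Parent:
    def compute_model(self, text):
        # Pass all label keywords; each model picks the one it declares.
        return self.model(input_ids=text, masked_lm_labels=text, labels=text)


class My_bert(Parent):
    def __init__(self):
        self.model = model_bert()


class My_gpt2(Parent):
    def __init__(self):
        self.model = model_gpt2()


print(My_bert().compute_model("bar"))   # prints: bert
print(My_gpt2().compute_model("rawr"))  # prints: gpt2
```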