Python继承和关键字参数

问题描述 投票:1回答:1

我正在为pytorch变压器编写包装器。为了简单起见,我将提供一个最小的示例。类Parent,它将是类My_BERT(Parent)My_GPT2(Parent)的抽象类。因为pytorch包含用于model_Bertmodel_gpt2的LM模型,所以它们具有许多相似的功能,因此我想通过在Partent中编码其他相同的函数来最大程度地减少代码冗余。

My_bert

My_gpt2在模型初始化方面基本上有所不同,并且一个参数传递给模型,但是99%的函数以相同的方式使用两个模型。

问题出在函数“模型”上,它接受不同的参数:

  • 对于model_Bert,它被定义为model(input_ids,masekd_lm_labels)
  • 对于model_gpt2
  • ,它被定义为模型(input_id,标签)

    最小代码示例:

class Parent():
    """ My own class that is an abstract class for My_bert and My_gpt2 """
    def __init__(self):
        pass

    def fancy_arithmetic(self, text):
        print("do_fancy_stuff_that_works_identically_for_both_models(text=text)")

    def compute_model(self, text):
        return self.model(input_ids=text, masked_lm_labels=text) #this line works for My_Bert
        #return self.model(input_ids=text, labels=text) #I'd need this line for My_gpt2

class My_bert(Parent): 
    """ My own My_bert class that is initialized with BERT pytorch 
    model (here model_bert), and uses methods from Parent """
    def __init__(self):
        self.model = model_bert()

class My_gpt2(Parent):
    """ My own My_gpt2 class that is initialized with gpt2 pytorch model (here model_gpt2), and uses methods from Parent """
    def __init__(self):
        self.model = model_gpt2()

class model_gpt2:
    """ This class mocks pytorch transformers gpt2 model, thus I'm writing just bunch of code that allows you run this example"""
    def __init__(self):
        pass

    def __call__(self,*input, **kwargs):
        return self.model( *input, **kwargs)

    def model(self, input_ids, labels):
        print("gpt2")

class model_bert:
    """ This class mocks pytorch transformers bert model"""
    def __init__(self):
        pass

    def __call__(self, *input, **kwargs):
        self.model(*input, **kwargs)

    def model(self, input_ids, masked_lm_labels):
        print("bert")


foo = My_bert()
foo.compute_model("bar")  # this works
bar = My_gpt2()
#bar.compute_model("rawr") #this does not work.

我知道我可以覆盖Parent::compute_modelMy_bert类中的My_gpt2函数。

但是由于两种“模型”方法是如此相似,我想知道是否有一种说法:“我将向您传递三个参数,您可以使用您所知道的参数”

def compute_model(self, text):
    return self.model(input_ids=text, masked_lm_labels=text, labels=text) # ignore the arguments you dont know

我正在为pytorch变压器编写包装器。为了简单起见,我将提供一个最小的示例。父类,它将是My_BERT(Parent)和My_GPT2(Parent)类的抽象类。 ...

python inheritance kwargs
1个回答
1
投票

*args**kwargs应该解决您遇到的问题。

© www.soinside.com 2019 - 2024. All rights reserved.