Python - 无法模拟对Inherited类的调用

问题描述 投票:-1回答:1

我有这个主要课程

def main(args):
    if type == train_pipeline_type:
        strategy = TrainPipelineStrategy()
    else:
        strategy = TestPipelineStrategy()
    for table in fetch_table_information_by_region(region):
        split_required = DataUtils.load_from_dict(table, "split_required")
        if split_required:
            strategy.split(spark=spark, table_name=table_name,
                           data_loc=filtered_data_location, partition_column=partition_column,
                           split_output_dir= split_output_dir)
            logger.info("Data Split for table : {} completed".format(table_name))

我的火车管道战略和TestPipeline战略看起来像这样 -

class PipelineTypeStrategy(object):

    def partition_data(self, x):
        # Something

    def prepare_split_data(self, y):
        # Something

    def write_split_data(self, z):
        # Something

    def split(self, p):
        # Something


class TrainPipelineStrategy(PipelineTypeStrategy):
    """"""


class TestPipelineStrategy(PipelineTypeStrategy):

    def write_split_data(self, y):
        # Something else

我的测试用例 - 我需要通过在main方法中模拟split功能来测试split被调用了多少次。

这是我试过的 -

@patch('module.PipelineTypeStrategy.TrainPipelineStrategy')
    def test_split_data_main_split_data_call_count(self, fake_train):
        fake_train_functions = mock.Mock()
        fake_train_functions.split.return_value = None
        fake_train.return_value = fake_train_functions
        test_args = ["", "--x=6"]
        SplitData.main(args=test_args)
        assert fake_train_functions.split.call_count == 10

当我尝试运行我的测试时,它会创建模拟,但最终会调用实际的分割函数。我究竟做错了什么 ?

python python-3.x python-unittest
1个回答
0
投票

这段代码的主要问题是你设置patch的方式是TrainPipelineStrategyPipelineTypeStrategy的嵌套类,但TrainPipelineStrategyPipelineTypeStrategy的子类。

由于TrainPipelineStrategy继承自PipelineTypeStrategy,它可以直接访问split,所以你可以在不引用split的情况下修补PipelineTypeStrategy(除非你特别想修补split中定义的PipelineTypeStrategy版本)。

但是,如果你只想模拟split类的PipelineTypeStrategy方法,你应该使用patch.object装饰器来模拟split而不是嘲笑整个类,因为它更干净一点。这是一个例子:

class TestClass(unittest.TestCase):
    @patch.object(TrainPipelineStrategy, 'split', return_value=None)
    def test_split_data_main_split_data_call_count(self, mock_split):
        test_args = ["", "--x=6"]
        SplitData.main(args=test_args)
        self.assertEqual(mock_split.call_count, 10)
© www.soinside.com 2019 - 2024. All rights reserved.