我有这个主要课程
def main(args):
if type == train_pipeline_type:
strategy = TrainPipelineStrategy()
else:
strategy = TestPipelineStrategy()
for table in fetch_table_information_by_region(region):
split_required = DataUtils.load_from_dict(table, "split_required")
if split_required:
strategy.split(spark=spark, table_name=table_name,
data_loc=filtered_data_location, partition_column=partition_column,
split_output_dir= split_output_dir)
logger.info("Data Split for table : {} completed".format(table_name))
我的火车管道战略和TestPipeline战略看起来像这样 -
class PipelineTypeStrategy(object):
def partition_data(self, x):
# Something
def prepare_split_data(self, y):
# Something
def write_split_data(self, z):
# Something
def split(self, p):
# Something
class TrainPipelineStrategy(PipelineTypeStrategy):
""""""
class TestPipelineStrategy(PipelineTypeStrategy):
def write_split_data(self, y):
# Something else
我的测试用例 - 我需要通过在main方法中模拟split功能来测试split被调用了多少次。
这是我试过的 -
@patch('module.PipelineTypeStrategy.TrainPipelineStrategy')
def test_split_data_main_split_data_call_count(self, fake_train):
fake_train_functions = mock.Mock()
fake_train_functions.split.return_value = None
fake_train.return_value = fake_train_functions
test_args = ["", "--x=6"]
SplitData.main(args=test_args)
assert fake_train_functions.split.call_count == 10
当我尝试运行我的测试时,它会创建模拟,但最终会调用实际的分割函数。我究竟做错了什么 ?
这段代码的主要问题是你设置patch
的方式是TrainPipelineStrategy
是PipelineTypeStrategy
的嵌套类,但TrainPipelineStrategy
是PipelineTypeStrategy
的子类。
由于TrainPipelineStrategy
继承自PipelineTypeStrategy
,它可以直接访问split
,所以你可以在不引用split
的情况下修补PipelineTypeStrategy
(除非你特别想修补split
中定义的PipelineTypeStrategy
版本)。
但是,如果你只想模拟split
类的PipelineTypeStrategy
方法,你应该使用patch.object
装饰器来模拟split
而不是嘲笑整个类,因为它更干净一点。这是一个例子:
class TestClass(unittest.TestCase):
@patch.object(TrainPipelineStrategy, 'split', return_value=None)
def test_split_data_main_split_data_call_count(self, mock_split):
test_args = ["", "--x=6"]
SplitData.main(args=test_args)
self.assertEqual(mock_split.call_count, 10)