我正在尝试解决这个leetcode问题:https://leetcode.com/problems/minimum-falling-path-sum/description
给定一个 n x n 整数矩阵数组,返回通过矩阵的任何下降路径的最小和。
下降路径从第一行中的任何元素开始,并选择 下一行中直接下方或对角线的元素 左右。具体来说,位置 (row, col) 的下一个元素 将是 (row + 1, col - 1)、(row + 1, col) 或 (row + 1, col + 1)。
这是一个动态编程问题,我想使用递归和记忆来解决。编辑部分提供了一个使用
row
和 col
进行记忆的 java 解决方案,如下所示:
class Solution {
public int minFallingPathSum(int[][] matrix) {
int minFallingSum = Integer.MAX_VALUE;
Integer memo[][] = new Integer[matrix.length][matrix[0].length];
// start a DFS (with memoization) from each cell in the top row
for (int startCol = 0; startCol < matrix.length; startCol++) {
minFallingSum = Math.min(minFallingSum,
findMinFallingPathSum(matrix, 0, startCol, memo));
}
return minFallingSum;
}
public int findMinFallingPathSum(int[][] matrix, int row, int col, Integer[][] memo) {
//base cases
if (col < 0 || col == matrix.length) {
return Integer.MAX_VALUE;
}
//check if we have reached the last row
if (row == matrix.length - 1) {
return matrix[row][col];
}
//check if the results are calculated before
if (memo[row][col] != null) {
return memo[row][col];
}
// calculate the minimum falling path sum starting from each possible next step
int left = findMinFallingPathSum(matrix, row + 1, col, memo);
int middle = findMinFallingPathSum(matrix, row + 1, col + 1, memo);
int right = findMinFallingPathSum(matrix, row + 1, col - 1, memo);
memo[row][col] = Math.min(left, Math.min(middle, right)) + matrix[row][col];
return memo[row][col];
}
}
我最初使用 python 的方法如下:
class Solution:
def minFallingPathSum(self, matrix: List[List[int]]) -> int:
d = {}
min_sum = sys.maxsize
for i in range(len(matrix)):
min_sum = min(min_sum, self.recur(matrix, 1, i, matrix[0][i], d))
return min_sum
def recur(self, matrix: [], row: int, col: int, sum: int, d: {}):
if row >= len(matrix):
return sum
if (row, col) not in d:
l = []
l.append(self.recur(matrix, row + 1, col, sum + matrix[row][col], d))
if col - 1 >= 0:
l.append(self.recur(matrix, row + 1, col - 1, sum + matrix[row][col-1], d))
if col + 1 < len(matrix):
l.append(self.recur(matrix, row + 1, col + 1, sum + matrix[row][col+1], d))
d[row,col] = min(l)
return d[row,col]
但是在 18/50 测试用例之后它因错误答案而失败。我通过使用
sum
以及 row
和 col
来将其更改为下面的内容,如下所示:
class Solution:
def minFallingPathSum(self, matrix: List[List[int]]) -> int:
d = {}
min_sum = sys.maxsize
for i in range(len(matrix)):
min_sum = min(min_sum, self.recur(matrix, 1, i, matrix[0][i], d))
return min_sum
def recur(self, matrix: [], row: int, col: int, sum: int, d: {}):
if row >= len(matrix):
return sum
if (row, col, sum) not in d:
l = []
l.append(self.recur(matrix, row + 1, col, sum + matrix[row][col], d))
if col - 1 >= 0:
l.append(self.recur(matrix, row + 1, col - 1, sum + matrix[row][col-1], d))
if col + 1 < len(matrix):
l.append(self.recur(matrix, row + 1, col + 1, sum + matrix[row][col+1], d))
d[row,col,sum] = min(l)
return d[row,col,sum]
这是可行的,但在 43/50 个测试用例之后超出了时间限制。
我想知道为什么我的使用
(row, col)
进行记忆的 Python 代码在编辑中的 Java 代码中不起作用。
如有任何帮助,我们将不胜感激。
...因为它适用于社论中的 Java 代码
但是你并没有真正将 Java 算法复制到你的 Python 版本中:
Java 版本自下而上工作,返回总和从给定坐标向下到矩阵的底部,而您的算法尝试自上而下工作,累积从顶部到当前单元格的路径的总和。
Java 版本使用单元格坐标作为记忆键,而 Python 版本使用单元格坐标与总和作为键(这会破坏您可能从记忆化中获得的好处)。
在
l
中,您还可以收集包含 matrix[row][col+1]
或 matrix[row][col-1]
的总和,然后将其中最好的存储在 dp[row,col,sum]
中。但这些金额将归入 row
和 col
的 兄弟姐妹,因此这是不正确的。
这里的版本采用与 Java 版本相同的方法——自下而上,因此您不需要传递部分总和作为参数。我选择将
recur
函数放在主函数中,因此不需要传递 matrix
或 d
作为参数:
class Solution:
def minFallingPathSum(self, matrix: List[List[int]]) -> int:
d = {}
def recur(row: int, col: int):
if col < 0 or col >= len(matrix[0]):
return 10000000 # larger than any value
if row == len(matrix) - 1:
return matrix[row][col]
if (row, col) not in d:
d[row,col] = min(recur(row + 1, col),
recur(row + 1, col - 1),
recur(row + 1, col + 1)) + matrix[row][col]
return d[row,col]
return min(recur(0, i) for i in range(len(matrix[0])))
您可以通过实现“迭代”算法(逐行)来改进内存使用情况,并且记忆仅存储先前访问的行的结果。另外,您可以使用列表而不是字典:
class Solution:
def minFallingPathSum(self, matrix: List[List[int]]) -> int:
n = len(matrix[0])
dp = [0] * n
for row in matrix:
dp = [
min(dp[max(i-1, 0): min(n, i+2)]) + val
for i, val in enumerate(row)
]
return min(dp)