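I am trying to write a gradient descent function in R that uses backtracking line search to determine the step size. Ultimately, I want to find the minimum of a function (call it f). Mathematically, the procedure should start at i = 1 and set x_i = x_{i-1} - t * f'(x_{i-1}), where t is the step size returned by backtrack. If |f(x_i) - f(x_{i-1})| is less than epsilon, the function should return x_i. Otherwise, it should increment i by 1 and repeat the previous step. The backtrack function computes the step size t by starting at t = 1 and, as long as f(x - t * f'(x)) is greater than f(x) - alpha * t * f'(x)^2, setting t to beta * t. Otherwise, it should stop and return the value of t.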
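For reference, the procedure described above can be written out as follows. This is only a restatement of the description, assuming 0 < beta < 1 so that the step shrinks each time the test fails; the subscript on t is added here for illustration.

```latex
\begin{aligned}
x_i &= x_{i-1} - t_i \, f'(x_{i-1}), \\
t_i &= \max\Bigl\{\, t \in \{1, \beta, \beta^2, \dots\} :
       f\bigl(x_{i-1} - t\, f'(x_{i-1})\bigr) \le f(x_{i-1}) - \alpha\, t\, f'(x_{i-1})^2 \Bigr\}, \\
\text{stop when}\quad & \bigl| f(x_i) - f(x_{i-1}) \bigr| < \varepsilon .
\end{aligned}
```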
backtrack_desc <- function(fn, deriv, start, alpha, beta, epsilon) {
  x = start
  while(TRUE) {
    step_size = backtrack(fn, deriv, x, alpha, beta)
    new_x = fn(x) - step_size * deriv(x)
    if(abs(deriv(new_x)) <= epsilon) {
      break
    }
    x = new_x
  }
  return(x)
}
backtrack <- function(fn, deriv, x, alpha, beta) {
  t = 1
  while(fn(x - t * deriv(x)) > (fn(x) - alpha * t * deriv(x)^2)) {
    t = beta * t
  }
}
# This should return something close to zero
backtrack_desc(function(x) x^2, function(x) 2 * x, start = 10,
               alpha = .03, beta = .8, epsilon = 1e-10)
backtrack_desc(function(x) x^2, function(x) 2 * x, start = 1,
               alpha = .03, beta = .8, epsilon = 1e-10)
The function should return a small number close to zero, but when I run it I get this error message: Error in if (abs(deriv(new_x)) <= epsilon) { : argument is of length zero.
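You have two problems. First, backtrack never returns t: its last expression is the while loop, which evaluates to NULL, so step_size is always NULL. Arithmetic on NULL produces a zero-length vector, and that is why if() fails with "argument is of length zero". Second, I think new_x should be x - step_size * deriv(x) rather than fn(x) - step_size * deriv(x).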
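To see why a missing return value produces that particular error, here is a minimal sketch. The function no_return is a hypothetical stand-in for the original backtrack, not code from the question, and the numbers are arbitrary.

```r
# A function whose last expression is a while loop returns NULL (invisibly).
# no_return is a hypothetical stand-in for the original backtrack.
no_return <- function() {
  t <- 1
  while (t > 0.5) {
    t <- 0.8 * t
  }
}

step_size <- no_return()
is.null(step_size)            # TRUE

# Arithmetic with NULL yields a zero-length numeric vector ...
new_x <- 10 - step_size * 2
length(new_x)                 # 0

# ... and a zero-length condition is exactly what if() rejects.
result <- tryCatch(
  if (abs(new_x) <= 1e-10) "converged" else "not converged",
  error = function(e) conditionMessage(e)
)
result
```

In an English locale, result holds the same "argument is of length zero" message that appears in the question.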
Fixing these problems, we have:
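backtrack_desc <- function(fn, deriv, start, alpha, beta, epsilon) {
  x = start
  while(TRUE) {
    # Step size from the backtracking line search at the current point
    step_size = backtrack(fn, deriv, x, alpha, beta)
    # Move against the derivative, starting from x (not from fn(x))
    new_x = x - step_size * deriv(x)
    # Stop once the derivative at the new point is small enough
    if(abs(deriv(new_x)) <= epsilon) {
      break
    }
    x = new_x
  }
  return(x)
}
backtrack <- function(fn, deriv, x, alpha, beta) {
  t = 1
  # Shrink t until the sufficient-decrease condition holds
  while(fn(x - t * deriv(x)) > (fn(x) - alpha * t * deriv(x)^2)) {
    t = beta * t
  }
  # Return the step size explicitly; a bare while loop evaluates to NULL
  return(t)
}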
This gives:
# This should return something close to zero
backtrack_desc(function(x) x^2, function(x) 2 * x, start = 10,
               alpha = .03, beta = .8, epsilon = 1e-10)
#> [1] 8.082813e-11
# This should return something close to 1
backtrack_desc(function(x) (x - 1)^2, function(x) 2 * x - 2, start = 4,
               alpha = .03, beta = .8, epsilon = 1e-10)
#> [1] 1
# This should return something close to pi
backtrack_desc(function(x) cos(x), function(x) -sin(x), start = 1,
               alpha = .03, beta = .8, epsilon = 1e-10)
#> [1] 3.141593
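The code above stops when the derivative is small, whereas the question describes stopping when |f(x_i) - f(x_{i-1})| falls below epsilon. A possible variant using that criterion is sketched below. The names backtrack_capped and backtrack_desc_fdiff and the max_shrinks and max_iter caps are assumptions added here so that neither loop can run forever; they are not part of the original code.

```r
# Backtracking line search: shrink t until the sufficient-decrease
# condition holds. max_shrinks is a hypothetical safety cap.
backtrack_capped <- function(fn, deriv, x, alpha, beta, max_shrinks = 100) {
  t <- 1
  g <- deriv(x)
  for (k in seq_len(max_shrinks)) {
    if (fn(x - t * g) <= fn(x) - alpha * t * g^2) break
    t <- beta * t
  }
  t
}

# Gradient descent that stops when |f(x_i) - f(x_{i-1})| < epsilon,
# or after max_iter iterations (hypothetical cap), whichever comes first.
backtrack_desc_fdiff <- function(fn, deriv, start, alpha, beta, epsilon,
                                 max_iter = 10000) {
  x <- start
  for (i in seq_len(max_iter)) {
    step_size <- backtrack_capped(fn, deriv, x, alpha, beta)
    new_x <- x - step_size * deriv(x)
    if (abs(fn(new_x) - fn(x)) < epsilon) {
      return(new_x)
    }
    x <- new_x
  }
  warning("max_iter reached before the stopping criterion was met")
  x
}

# Same test problem as above; the result should again be close to zero.
backtrack_desc_fdiff(function(x) x^2, function(x) 2 * x, start = 10,
                     alpha = .03, beta = .8, epsilon = 1e-10)
```

Because this criterion measures the change in f rather than the size of the derivative, the returned point is generally less precise for the same epsilon; which test is preferable depends on the problem.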