我有一个函数 Black_Cox()
其中调用其他函数,如下图所示。
import numpy as np
from scipy import stats
# Parameters
D = 100
r = 0.05
γ = 0.1
# Normal CDF
N = lambda x: stats.norm.cdf(x)
H = lambda V, T, L, σ: np.exp(-r*T) * N( (np.log(V/L) + (r-0.5*σ**2)*T) / (σ*np.sqrt(T)) )
# Black-Scholes
def C_BS(V, K, T, σ):
d1 = (np.log(V/K) + (r + 0.5*σ**2)*T ) / ( σ*np.sqrt(T) )
d2 = d1 - σ*np.sqrt(T)
return V*N(d1) - np.exp(-r*T)*K*N(d2)
def BL(V, T, D, L, σ):
return L * H(V, T, L, σ) - L * (L/V)**(2*r/σ**2-1) * H(L**2/V, T, L, σ) \
+ C_BS(V, L, T, σ) - (L/V)**(2*r/σ**2-1) * C_BS(L**2/V, L, T, σ) \
- C_BS(V, D, T, σ) + (L/V)**(2*r/σ**2-1) * C_BS(L**2/V, D, T, σ)
def Bb(V, T, C, γ, σ, a):
b = (np.log(C/V) - γ*T) / σ
μ = (r - a - 0.5*σ**2 - γ) / σ
m = np.sqrt(μ**2 + 2*r)
return C*np.exp(b*(μ-m)) * ( N((b-m*T)/np.sqrt(T)) + np.exp(2*m*b)*N((b+m*T)/np.sqrt(T)) )
def Black_Cox(V, T, C=160, σ=0.1, a=0):
return np.exp(γ*T)*BL(V*np.exp(-γ*T), T, D*np.exp(-γ*T), C*np.exp(-γ*T), σ) + Bb(V, T, C, γ, σ, a)
我需要处理这个函数的导数 Black_Cox
功能 w.r.t. V
. 更准确地说,我需要在数千条路径上评估这个导数,在这些路径上,我改变其他参数,找到导数,并在某些地方评估。V
.
什么是最好的方法?
我应该使用 sympy
找出这个导数,然后在我的 V
的选择,就像我在Mathematica中做的那样。D[BlackCox[V, 10, 100, 160], V] /. V -> 180
,或者
我应该只用 jax
?
如果 sympy
,你会建议我怎么做?
用 jax
我明白,我需要做以下导入。
import jax.numpy as np
from jax.scipy import stats
from jax import grad
并在得到梯度之前重新评估我的函数。
func = lambda x: Black_Cox(x,10,160,0.1)
grad(func)(180.0)
如果我还需要使用 numpy
函数的版本,我是否必须为每个函数创建两个实例,或者是否有一种优雅的方法来复制一个函数的 jax
的目的?
Jax并没有提供任何内置的方法来重新编译一个使用jax版本的numpy和scipy函数。但是你可以使用一个像下面这样的代码段来自动完成它。
import inspect
from functools import wraps
import numpy as np
import jax.numpy
def replace_globals(func, globals_):
"""Recompile a function with replaced global values."""
namespace = func.__globals__.copy()
namespace.update(globals_)
source = inspect.getsource(func)
exec(source, namespace)
return wraps(func)(namespace[func.__name__])
它的工作原理是这样的:
def numpy_func(N):
return np.arange(N) ** 2
jax_func = replace_globals(numpy_func, {"np": jax.numpy})
现在你可以评估numpy版本:
numpy_func(10)
# array([ 0, 1, 4, 9, 16, 25, 36, 49, 64, 81])
和jax版本
jax_func(10)
# DeviceArray([ 0, 1, 4, 9, 16, 25, 36, 49, 64, 81], dtype=int32)
只要确保你在封装更复杂的函数时 替换掉所有相关的全局变量就可以了。