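I am testing the m.if3 function in gekko against an if-else conditional statement, but I get two different outputs. The optimal number of sites I get from the code below is 12. I plugged that value into the next script, which uses an if-else statement, to check that the costs match, but they don't. Am I using if3/if2 incorrectly? The enroll_rate is 0.1 for the first 5 months and switches to 0.3 for the remaining 45 months.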
from gekko import GEKKO
m = GEKKO(remote=False)
# parameters
sitecost = 9 # Cost per site
patcost = 12 # Cost per patient
num_pat = 50 # Required number of patients
recruit_duration = 50 # Duration of recruitment
# Define a piecewise function for the enrollment rate
enroll_rate = m.if3(recruit_duration - 5, 0.3, 0.1)
x = m.Var(integer=True, lb=1) # Number of sites
cost = m.if3(recruit_duration - 5, (0.3 * patcost * recruit_duration + sitecost) * x,
(0.1 * patcost * recruit_duration + sitecost) * x)
pat_count = m.if3(recruit_duration - 5, (0.3 * recruit_duration * x),
(0.1 * recruit_duration * x))
m.Minimize(cost)
m.Equation(pat_count >= num_pat)
m.solve(disp=False)
num_sites = int(x.value[0])
print(f'num_sites = {num_sites}')
print(cost.value[0])
print(pat_count.value[0])
# Parameters
sitecost = 9 # Cost per site
patcost = 12 # Cost per patient
num_pat = 50 # Required number of patients
recruit_duration = 50 # Duration of recruitment
# Define an enrollment rate based on recruit_duration
if recruit_duration > 5:
    enroll_rate = 0.1
else:
    enroll_rate = 0.3
# GEKKO Variables
x = 10 # Number of sites
# Calculate cost based on enrollment rate and site cost
cost1 = (0.1 * patcost * 45 + sitecost) * 12
cost2 = (0.3 * patcost * 5 + sitecost) * 12
cost = cost1 + cost2
# Calculate total patient count
pat_count1 = 0.1 * 5 * 12
pat_count2 = 0.3 * 45 * 12
pat_count = pat_count1 + pat_count2
print(pat_count)
print(cost)
Even though I am doing the same thing in both, I get different outputs.
I have tried everything, from if-else statements to the if2 function.
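There is no need for the if2 or if3 function, because the switching argument recruit_duration-5 is a constant value, not a function of a Gekko variable. With a constant switching argument, if3 selects the same branch for the whole recruitment period, so it cannot represent a rate that changes after month 5. Just like in the validation script, the two parts can be calculated separately and added together to get the total cost and patient count.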
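A minimal pure-Python sketch of the branch selection makes the difference concrete. It assumes the documented Gekko convention that `m.if3(condition, x1, x2)` evaluates to `x1` when `condition < 0` and to `x2` when `condition >= 0`; the helper `if3_constant` is a hypothetical stand-in for that rule and is not part of Gekko.

```python
# Hypothetical stand-in for the selection rule of Gekko's m.if3 when the
# switching argument is a plain number (assumption: x1 if condition < 0,
# otherwise x2). This is NOT Gekko code; it only illustrates the branch choice.
def if3_constant(condition, x1, x2):
    return x1 if condition < 0 else x2

sitecost = 9           # cost per site
patcost = 12           # cost per patient
recruit_duration = 50  # months of recruitment

# The question's switch: the condition is the constant 45 (>= 0), so the
# second branch is chosen and applied to all 50 months.
rate = if3_constant(recruit_duration - 5, 0.3, 0.1)
cost_per_site_if3 = rate * patcost * recruit_duration + sitecost
pat_per_site_if3 = rate * recruit_duration

# Piecewise schedule used in this answer: 0.3 for 5 months, 0.1 for 45 months,
# with sitecost counted once per phase, as in cost1 + cost2.
cost_per_site_split = (0.3 * patcost * 5 + sitecost) + (0.1 * patcost * 45 + sitecost)
pat_per_site_split = 0.3 * 5 + 0.1 * 45

print(rate)
print(cost_per_site_if3, pat_per_site_if3)
print(cost_per_site_split, pat_per_site_split)
```

The single-branch model gives about 69 per site and 5 patients per site, while the split schedule gives about 90 per site and 6 patients per site, so the two scripts are effectively optimizing different problems.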
from gekko import GEKKO
m = GEKKO(remote=False)
# parameters
sitecost = 9 # Cost per site
patcost = 12 # Cost per patient
num_pat = 50 # Required number of patients
recruit_duration = 50 # Duration of recruitment
x = m.Var(integer=True, lb=1) # Number of sites
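enroll_rate1 = 0.3  # enrollment rate for the first 5 months
enroll_rate2 = 0.1  # enrollment rate for the remaining recruit_duration-5 months
# cost of each phase (sitecost is added once per phase)
cost1 = m.Intermediate((enroll_rate1 * patcost * 5 + sitecost) * x)
cost2 = m.Intermediate((enroll_rate2 * patcost * (recruit_duration-5) + sitecost) * x)
cost = m.Intermediate(cost1+cost2)  # total cost
# patients enrolled in each phase
pat_count1 = m.Intermediate(enroll_rate1 * 5 * x)
pat_count2 = m.Intermediate(enroll_rate2 * (recruit_duration-5) * x)
pat_count = m.Intermediate(pat_count1+pat_count2)  # total patients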
m.Minimize(cost)
m.Equation(pat_count >= num_pat)
m.options.SOLVER = 1 # for MINLP solution
m.solve(disp=False)
num_sites = x.value[0]
print(f'num_sites = {num_sites}')
print(f'cost: {cost.value[0]}')
print(f'pat_count: {pat_count.value[0]}')
The optimal solution is:
num_sites = 9.0
cost: 810.0
pat_count: 54.0
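Because both terms are linear in x, the optimum can also be checked by hand: the cost per site is (0.3 * 12 * 5 + 9) + (0.1 * 12 * 45 + 9) = 27 + 63 = 90, and each site enrolls 0.3 * 5 + 0.1 * 45 = 6 patients, so the smallest integer x with 6x >= 50 is 9. The sketch below repeats that check with exact fractions to avoid floating-point noise; it is an independent arithmetic check, not part of the Gekko model.

```python
from fractions import Fraction
import math

# Exact rational arithmetic so 0.3 and 0.1 do not pick up float rounding.
sitecost = 9
patcost = 12
num_pat = 50
phases = [(Fraction(3, 10), 5), (Fraction(1, 10), 45)]  # (enroll rate, months)

# Each phase contributes rate*patcost*months + sitecost per site (cost1 + cost2).
cost_per_site = sum(rate * patcost * months + sitecost for rate, months in phases)
# Each phase enrolls rate*months patients per site (pat_count1 + pat_count2).
pat_per_site = sum(rate * months for rate, months in phases)

# Cost increases with x, so the optimum is the smallest integer x that meets
# the patient constraint pat_per_site * x >= num_pat.
num_sites = math.ceil(num_pat / pat_per_site)

print(num_sites, num_sites * cost_per_site, num_sites * pat_per_site)  # 9 810 54
```

This agrees with the Gekko result of 9 sites, a cost of 810 and 54 patients.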
The solution validation agrees with this answer:
# Solution validation
# Parameters
sitecost = 9 # Cost per site
patcost = 12 # Cost per patient
num_pat = 50 # Required number of patients
recruit_duration = 50 # Duration of recruitment
# Define an enrollment rate based on recruit_duration
if recruit_duration > 5:
    enroll_rate = 0.1
else:
    enroll_rate = 0.3
x = 9 # Number of sites
# Calculate cost based on enrollment rate and site cost
cost1 = (0.1 * patcost * 45 + sitecost) * x
cost2 = (0.3 * patcost * 5 + sitecost) * x
cost = cost1 + cost2
# Calculate total patient count
pat_count1 = 0.3 * 5 * x
pat_count2 = 0.1 * 45 * x
pat_count = pat_count1 + pat_count2
print(f'cost (validation): {cost}')
print(f'pat_count (validation): {pat_count}')
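In the question's validation script, cost1 pairs the 0.1 rate with 45 months while pat_count2 pairs the 0.3 rate with 45 months, so the cost and the patient count describe different schedules. One way to avoid that kind of mismatch is to keep each phase's rate and duration together and derive both quantities from the same list. The `totals` helper below is a hypothetical illustration of that idea, using the schedule from this answer.

```python
# Hypothetical helper: compute cost and patient count from one schedule so the
# rate/duration pairing cannot differ between the two calculations.
def totals(phases, x, patcost=12, sitecost=9):
    """phases: list of (enroll_rate, months); x: number of sites."""
    cost = sum((rate * patcost * months + sitecost) * x for rate, months in phases)
    pat_count = sum(rate * months * x for rate, months in phases)
    return cost, pat_count

schedule = [(0.3, 5), (0.1, 45)]  # 0.3 for 5 months, then 0.1 for 45 months
cost, pat_count = totals(schedule, x=9)
print(f'cost: {cost:.1f}')
print(f'pat_count: {pat_count:.1f}')
```

Changing the schedule then updates both results consistently, since neither calculation can use a different rate for the same phase.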