为什么if2/if3提供两种不同的输出?

问题描述 投票:0回答:1

我正在使用 if-else 条件语句测试 gekko 中的 m.if3 函数,但我得到了两个不同的输出。我从下面的代码中获得的最佳站点是 12。我将其插入到带有 if-else 语句的下一个代码中,以确保成本匹配,但事实并非如此。我是否错误地使用了 if3/if2?前 5 个月的 enroll_rate 为 0.1,其余 45 个月切换为 0.3。

m = GEKKO(remote=False)

# parameters
sitecost = 9  # Cost per site
patcost = 12  # Cost per patient
num_pat = 50  # Required number of patients
recruit_duration = 50  # Duration of recruitment
# Define a piecewise function for the enrollment rate
enroll_rate = m.if3(recruit_duration - 5, 0.3, 0.1)
x = m.Var(integer=True, lb=1)  # Number of sites

cost = m.if3(recruit_duration - 5, (0.3 * patcost * recruit_duration + sitecost) * x,
             (0.1 * patcost * recruit_duration + sitecost) * x)

pat_count = m.if3(recruit_duration - 5, (0.3 * recruit_duration * x),
                  (0.1 * recruit_duration * x))

m.Minimize(cost)
m.Equation(pat_count >= num_pat)
m.solve(disp=False)
num_sites = int(x.value[0])
print(f'num_sites = {num_sites}')
print(cost.value[0])
print(pat_count.value[0])
# Parameters
sitecost = 9  # Cost per site
patcost = 12  # Cost per patient
num_pat = 50  # Required number of patients
recruit_duration = 50  # Duration of recruitment

# Define an enrollment rate based on recruit_duration
if recruit_duration > 5:
    enroll_rate = 0.1
else:
    enroll_rate = 0.3

# GEKKO Variables
x = 10 # Number of sites

# Calculate cost based on enrollment rate and site cost

cost1 = (0.1 * patcost * 45 + sitecost) * 12

cost2 = (0.3 * patcost * 5 + sitecost) * 12
cost = cost1 + cost2

# Calculate total patient count

pat_count1 = 0.1 * 5 * 12

pat_count2 = 0.3 * 45 * 12

pat_count = pat_count1 + pat_count2

print(pat_count)
print(cost)

即使我在两者中做同样的事情,我也会得到不同的输出。

我尝试了从使用 if-else 语句到使用 if2 语句的所有方法。

python gekko
1个回答
0
投票

不需要

if2
if3
函数,因为切换参数
recruit_duration-5
是一个常量值,不是 Gekko 变量的函数。就像验证脚本一样,这两个部分可以单独计算并加在一起以获得总成本和患者数量。

from gekko import GEKKO
m = GEKKO(remote=False)

# parameters
sitecost = 9  # Cost per site
patcost = 12  # Cost per patient
num_pat = 50  # Required number of patients
recruit_duration = 50  # Duration of recruitment
x = m.Var(integer=True, lb=1)  # Number of sites

enroll_rate1 = 0.3
enroll_rate2 = 0.1

cost1 = m.Intermediate((enroll_rate1 * patcost * 5 + sitecost) * x)
cost2 = m.Intermediate((enroll_rate2 * patcost * (recruit_duration-5) + sitecost) * x)
cost = m.Intermediate(cost1+cost2)

pat_count1 = m.Intermediate(enroll_rate1 * 5 * x)
pat_count2 = m.Intermediate(enroll_rate2 * (recruit_duration-5) * x)
pat_count = m.Intermediate(pat_count1+pat_count2)

m.Minimize(cost)
m.Equation(pat_count >= num_pat)
m.options.SOLVER = 1 # for MINLP solution
m.solve(disp=False)
num_sites = x.value[0]
print(f'num_sites = {num_sites}')
print(f'cost: {cost.value[0]}')
print(f'pat_count: {pat_count.value[0]}')

最优解是:

num_sites = 9.0
cost: 810.0
pat_count: 54.0

解决方案验证与此答案一致:

# Solution validation
# Parameters
sitecost = 9  # Cost per site
patcost = 12  # Cost per patient
num_pat = 50  # Required number of patients
recruit_duration = 50  # Duration of recruitment
# Define an enrollment rate based on recruit_duration
if recruit_duration > 5:
    enroll_rate = 0.1
else:
    enroll_rate = 0.3
x = 9 # Number of sites

# Calculate cost based on enrollment rate and site cost
cost1 = (0.1 * patcost * 45 + sitecost) * x
cost2 = (0.3 * patcost * 5 + sitecost) * x
cost = cost1 + cost2

# Calculate total patient count
pat_count1 = 0.3 * 5 * x
pat_count2 = 0.1 * 45 * x
pat_count = pat_count1 + pat_count2

print(f'cost (validation): {cost}')
print(f'pat_count (validation): {pat_count}')
© www.soinside.com 2019 - 2024. All rights reserved.