Unable to implement logistic regression using the "equinox" and "optax" libraries


I am trying to implement logistic regression using the equinox and optax libraries on top of JAX. While training the model, the loss does not decrease over time and the model does not learn. I am attaching reproducible code with a toy dataset for reference:

import jax
import jax.nn as jnn
import jax.numpy as jnp
import jax.random as jrandom
import equinox as eqx
import optax

data_key,model_key = jax.random.split(jax.random.PRNGKey(0),2)

### Generating toy-data

X_train = jax.random.normal(data_key, (1000,2))
y_train = X_train[:,0]+X_train[:,1]
y_train = jnp.where(y_train>0.5,1,0)

### Using equinox and optax
print("Training using equinox and optax")

epochs = 10000             
learning_rate = 0.1
n_inputs = X_train.shape[1]

class Logistic_Regression(eqx.Module):
    weight: jax.Array
    bias: jax.Array
    def __init__(self, in_size, out_size, key):
        wkey, bkey = jax.random.split(key)
        self.weight = jax.random.normal(wkey, (out_size, in_size))
        self.bias = jax.random.normal(bkey, (out_size,))
        #self.weight = jnp.zeros((out_size, in_size))
        #self.bias = jnp.zeros((out_size,))
    def __call__(self, x):
        return jax.nn.sigmoid(self.weight @ x + self.bias)

@eqx.filter_value_and_grad
def loss_fn(model, x, y):
    pred_y = jax.vmap(model)(x) 
    return -jnp.mean(y * jnp.log(pred_y) + (1 - y) * jnp.log(1 - pred_y))

@eqx.filter_jit
def make_step(model, x, y, opt_state):
    loss, grads = loss_fn(model, x, y)
    updates, opt_state = optim.update(grads, opt_state)
    model = eqx.apply_updates(model, updates)
    return loss, model, opt_state

in_size, out_size = n_inputs, 1
model = Logistic_Regression(in_size, out_size, key=model_key)
optim = optax.sgd(learning_rate)
opt_state = optim.init(model)
for epoch in range(epochs):
    loss, model, opt_state = make_step(model,X_train,y_train, opt_state)
    loss = loss.item()
    if (epoch+1)%1000 ==0:
        print(f"loss at epoch {epoch+1}:{loss}")

# The following code implements logistic regression using scikit-learn and PyTorch; both work well and are included only for reference.


### Using scikit-learn
print("Training using scikit-learn")
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
model = LogisticRegression()
model.fit(X_train,y_train)
y_pred = model.predict(X_train)
print("Train accuracy:",accuracy_score(y_train,y_pred))

## Using pytorch
print("Training using pytorch")
import numpy as np
import torch
import torch.nn as nn
from torch.optim import SGD
from torch.nn import Sequential

X_train = np.array(X_train)
y_train = np.array(y_train)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("device:",device)
torch_LR= Sequential(nn.Linear(n_inputs, 1),
                nn.Sigmoid())
torch_LR.to(device)
criterion = nn.BCELoss() # define the optimization
optimizer = SGD(torch_LR.parameters(), lr=learning_rate)

train_loss = []
for epoch in range(epochs):
    inputs, targets = torch.tensor(X_train).to(device), torch.tensor(y_train).to(device) # move the data to GPU if available
    optimizer.zero_grad() # clear the gradients
    yhat = torch_LR(inputs.float()) # compute the model output
    loss = criterion(yhat, targets.unsqueeze(1).float()) # calculate loss
    #train_loss_batch.append(loss.cpu().detach().numpy()) # store the loss
    loss.backward() # update model weights
    optimizer.step()
    if (epoch+1)%1000 ==0:
        print(f"loss at epoch {epoch+1}:{loss.cpu().detach().numpy()}")


I have tried the SGD and Adam optimizers with different learning rates, but the result is the same. I have also tried both zero and random weight initialization. On the same data, I tried the LogisticRegression module from scikit-learn and a PyTorch implementation (I understand that sklearn does not use SGD; it is included only as a reference for the achievable performance). The scikit-learn and PyTorch models are included in the code block above for reference. I have tried several classification datasets and still face this problem.

deep-learning logistic-regression equinox jax
1 Answer

The first time the loss is printed is after 1000 epochs. If you change this to also print the loss for the first 10 epochs, you will see that the optimizer is converging quickly:

    # ...
    if epoch < 10 or (epoch + 1)%1000 ==0:
        print(f"loss at epoch {epoch+1}:{loss}")

The result looks like this:

Training using equinox and optax
loss at epoch 1:1.237254023551941
loss at epoch 2:1.216030478477478
loss at epoch 3:1.1952687501907349
loss at epoch 4:1.174972414970398
loss at epoch 5:1.1551438570022583
loss at epoch 6:1.1357849836349487
loss at epoch 7:1.1168975830078125
loss at epoch 8:1.098482370376587
loss at epoch 9:1.0805412530899048
loss at epoch 10:1.0630732774734497
loss at epoch 1000:0.6320337057113647
loss at epoch 2000:0.6320337057113647
loss at epoch 3000:0.6320337057113647

By epoch 1000, the loss has already converged to a minimum and no longer moves.

Given this, your optimizer appears to be working correctly.


Edit: I did some debugging and found that y_pred = jax.vmap(model)(X_train) returns an array of shape (1000, 1), while y_train has shape (1000,). Broadcasting the two inside the loss therefore produces not a length-1000 array of per-example terms but a (1000, 1000) array of all pairwise combinations. The log loss over these pairwise terms is not the standard logistic regression objective.
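A minimal sketch of the shape mismatch, using only jax.numpy with dummy arrays standing in for the question's labels and predictions. The same .ravel() call can be applied to pred_y inside the question's loss_fn (i.e. jax.vmap(model)(x).ravel()) to get elementwise per-example terms:

```python
import jax.numpy as jnp

y = jnp.zeros((1000,))              # labels, shape (1000,)
pred_y = jnp.full((1000, 1), 0.5)   # vmap'd model output, shape (1000, 1)

# Broadcasting (1000,) against (1000, 1) yields a (1000, 1000) matrix,
# so the mean in the loss is taken over all pairwise terms.
bad = y * jnp.log(pred_y)
print(bad.shape)    # (1000, 1000)

# Fix: flatten the predictions so shapes line up elementwise.
good = y * jnp.log(pred_y.ravel())
print(good.shape)   # (1000,)
```

As an aside, optax also ships optax.sigmoid_binary_cross_entropy, which takes raw logits rather than sigmoid outputs and avoids the log(0) instability of the hand-written loss; using it would mean dropping the sigmoid from the model's __call__.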
