Haskell GridWorld infinite loop

Problem description

I am trying to write a GridWorld simulation in Haskell with reinforcement learning. I'm stuck because I keep running into an infinite loop at line 109. I've been staring at this problem for a week and have rewritten the code from scratch several times, so I'd like some help from a fresh pair of eyes.

Here is my program's output:

Initial Grid:
0 | 1 | 2 | 3 | 4 | 
1 | 2 | 3 | 4 | 5 | 
2 | 3 | 4 | 5 | 6 | 
3 | 4 | 5 | 6 | 7 | 
4 | 5 | 6 | 7 | 8 | 
Training Q-learning agent...
Q-learning training finished.
Final Grid:

Execution reaches, but never gets past, line 109, which is:

visualizeGrid (\(x, y) -> maximum [finalQTable ((x, y), a) | a <- [minBound .. maxBound]])

The code compiles fine. In the error window (I'm using an online environment) I get the message Main: <<loop>>
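
If I understand the message correctly, GHC's runtime prints <<loop>> when it detects a value whose evaluation depends on itself. For example, this tiny program (unrelated to my code, just an illustration) produces the same message:

main :: IO ()
main = do
    -- 'x' is defined in terms of itself, so forcing it re-enters its own
    -- thunk and the runtime reports <<loop>> instead of hanging forever
    let x = x + 1 :: Int
    print x

But I can't see where something like that happens in my own code.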

The full code follows:

import System.Random
import Data.List (maximumBy)
import Data.Ord (comparing)
import Control.Monad (foldM)
import Debug.Trace (trace)

type Position = (Int, Int)
type Reward = Float
type GridWorld = Position -> Reward
data Action = Up | Down | MoveLeft | MoveRight deriving (Eq, Enum, Bounded)

instance Show Action where
    show Up = "↑"
    show Down = "↓"
    show MoveLeft = "←"
    show MoveRight = "→"

step :: GridWorld -> Position -> Action -> (Position, Reward)
step world (x, y) action = case action of
    Up       -> ((x, max 0 (y-1)), world (x, max 0 (y-1)))
    Down     -> ((x, min 4 (y+1)), world (x, min 4 (y+1)))
    MoveLeft -> ((max 0 (x-1), y), world (max 0 (x-1), y))
    MoveRight-> ((min 4 (x+1), y), world (min 4 (x+1), y))

type QTable = ((Position, Action) -> Reward)

trainQ :: GridWorld -> QTable -> Float -> Float -> Int -> IO QTable
trainQ world qtable alpha gamma episodes = do
    gen <- newStdGen
    let actions = [minBound .. maxBound] :: [Action]
        positions = [(x, y) | x <- [0..4], y <- [0..4]]
    finalQTable <- snd <$> foldM (\(g, q) _ -> do
                                    let (q', g') = trainEpisode world q alpha gamma actions positions g
                                    return (g', q')) (gen, qtable) [1..episodes]
    return finalQTable

trainEpisode :: GridWorld -> QTable -> Float -> Float -> [Action] -> [Position] -> StdGen -> (QTable, StdGen)
trainEpisode world qtable alpha gamma actions positions gen = 
    let (startPos, newGen) = randomR (0, length positions - 1) gen
        startState = positions !! startPos
        (_, finalQTable, _) = foldl (\(prevPos, q, g) _ -> 
                                        let (newPos, _) = step world prevPos (chooseAction q newPos actions)
                                            (updatedQTable, _) = trainStep world q alpha gamma actions newPos
                                        in (newPos, updatedQTable, g)) 
                                    (startState, qtable, newGen) [1..10]
    in (finalQTable, newGen)

trainStep :: GridWorld -> QTable -> Float -> Float -> [Action] -> Position -> (QTable, Position)
trainStep world qtable alpha gamma actions pos =
    let action = chooseAction qtable pos actions
        (newPos, reward) = step world pos action
        oldValue = qtable (pos, action)
        futureValue = maximum [qtable (newPos, a) | a <- actions]
        newValue = oldValue + alpha * (reward + gamma * futureValue - oldValue)
        qtable' = \s -> if s == (pos, action) then newValue else qtable s
    in (qtable', newPos)

chooseAction :: QTable -> Position -> [Action] -> Action
chooseAction qtable (x, y) actions =
    let validActions = filter (\action -> isValidAction action (x, y)) actions
        bestAction = maximumBy (comparing (\a -> qtable ((x, y), a))) validActions
    in bestAction

isValidAction :: Action -> Position -> Bool
isValidAction Up    (_, y) = y > 0
isValidAction Down  (_, y) = y < 4
isValidAction MoveLeft  (x, _) = x > 0
isValidAction MoveRight (x, _) = x < 4

runAgent :: GridWorld -> QTable -> Position -> [Action] -> IO ()
runAgent world qtable pos actions = do
    putStrLn $ "Current Position: " ++ show pos
    let action = chooseAction qtable pos actions
    putStrLn $ "Chosen Action: " ++ show action
    let (newPos, reward) = step world pos action
    putStrLn $ "Action: " ++ show action ++ ", Reward: " ++ show reward
    if newPos == pos
        then do
            putStrLn "Agent is stuck! Generating random action."
            gen <- newStdGen
            let (randomActionIndex, newGen) = randomR (0, length actions - 1) gen
                randomAction = actions !! randomActionIndex
            putStrLn $ "Random Action: " ++ show randomAction
            let (newPos', reward') = step world pos randomAction
            putStrLn $ "Random Action Result: " ++ show randomAction ++ ", New Position: " ++ show newPos' ++ ", Reward: " ++ show reward'
            runAgent world qtable newPos' actions
        else do
            putStrLn $ "New Position: " ++ show newPos
            runAgent world qtable newPos actions

visualizeGrid :: GridWorld -> IO ()
visualizeGrid world = mapM_ putStrLn [concat [show (round (world (x, y))) ++ " | " | x <- [0..4]] | y <- [0..4]]

gridWorld :: GridWorld
gridWorld (x, y) = fromIntegral (x + y)

main :: IO ()
main = do
    putStrLn "Initial Grid:"
    visualizeGrid gridWorld
    putStrLn "Training Q-learning agent..."
    let initialQTable = \_ -> 0.0
        alpha = 0.1
        gamma = 0.9
        episodes = 4
    finalQTable <- trainQ gridWorld initialQTable alpha gamma episodes
    putStrLn "Q-learning training finished."
    putStrLn "Final Grid:"
    visualizeGrid (\(x, y) -> maximum [finalQTable ((x, y), a) | a <- [minBound .. maxBound]])
    putStrLn "Running agent:"
    runAgent gridWorld finalQTable (0, 0) [minBound .. maxBound]

haskell infinite-loop reinforcement-learning q-learning
1 Answer

The <<loop>> is probably caused by this line:

let (newPos, _) = step world prevPos (chooseAction q newPos actions)

since newPos is recursively defined in terms of itself. Use another name, e.g. newPos2, for the new variable.
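
A minimal sketch of the fix, assuming the action is meant to be chosen from the position the agent currently occupies (prevPos); the fresh name newPos2 breaks the self-reference:

trainEpisode :: GridWorld -> QTable -> Float -> Float -> [Action] -> [Position] -> StdGen -> (QTable, StdGen)
trainEpisode world qtable alpha gamma actions positions gen =
    let (startPos, newGen) = randomR (0, length positions - 1) gen
        startState = positions !! startPos
        (_, finalQTable, _) = foldl (\(prevPos, q, g) _ ->
                -- choose the action from prevPos, and bind the stepped
                -- position to a fresh name so the definition no longer
                -- refers to itself
                let (newPos2, _) = step world prevPos (chooseAction q prevPos actions)
                    (updatedQTable, _) = trainStep world q alpha gamma actions newPos2
                in (newPos2, updatedQTable, g))
            (startState, qtable, newGen) [1..10]
    in (finalQTable, newGen)

With the cycle gone, forcing finalQTable in visualizeGrid no longer demands a value that depends on itself, so the <<loop>> at that call should disappear.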
