Chapter 7.2 n-step Sarsa
xxxxxxxxxx
6
1
begin
2
using ReinforcementLearning
3
using Statistics
4
using Flux
5
using Plots
6
end
# RandomWalk1D
## Traits
| Trait Type | Value |
|:----------------- | ----------------------------------------------:|
| NumAgentStyle | ReinforcementLearningBase.SingleAgent() |
| DynamicStyle | ReinforcementLearningBase.Sequential() |
| InformationStyle | ReinforcementLearningBase.PerfectInformation() |
| ChanceStyle | ReinforcementLearningBase.Deterministic() |
| RewardStyle | ReinforcementLearningBase.TerminalReward() |
| UtilityStyle | ReinforcementLearningBase.GeneralSum() |
| ActionStyle | ReinforcementLearningBase.MinimalActionSet() |
| StateStyle | ReinforcementLearningBase.Observation{Int64}() |
| DefaultStateStyle | ReinforcementLearningBase.Observation{Int64}() |
## Is Environment Terminated?
No
## State Space
`Base.OneTo(21)`
## Action Space
`Base.OneTo(2)`
## Current State
```
11
```
xxxxxxxxxx
1
1
# Build the 1-D random-walk environment with N=21; it is used below to read the
# sizes of the state and action spaces.
env = RandomWalk1D(N=21)
21
2
xxxxxxxxxx
1
1
# Number of states and number of actions of `env`.
ns, na = length(state_space(env)), length(action_space(env))
-1.0:0.1:1.0
xxxxxxxxxx
1
1
# Reference values (21 entries, from -1 to 1 in steps of 0.1) that the RMS hook
# compares the learned value table against.
true_values = -1:0.1:1
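Again, we first define a hook that records, after each episode, the root-mean-square (RMS) error between the estimated state values and the true values.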
xxxxxxxxxx
4
1
"""
    RecordRMS()

Hook that stores a growing list of RMS error values in its `rms` field.
A freshly constructed hook starts with an empty list.
"""
struct RecordRMS <: AbstractHook
    rms::Vector{Float64}

    function RecordRMS()
        return new(Float64[])
    end
end
xxxxxxxxxx
1
1
"""
    (f::RecordRMS)(::PostEpisodeStage, agent, env)

Append to `f.rms` the root-mean-square difference between the agent's value table
and `true_values`, skipping the first and last entry of each. Returns `f.rms`.
"""
function (f::RecordRMS)(::PostEpisodeStage, agent, env)
    estimates = agent.policy.learner.approximator.table
    # The first and last entries are excluded from the comparison.
    errors = estimates[2:end-1] - true_values[2:end-1]
    push!(f.rms, sqrt(mean(errors .^ 2)))
end
run_once (generic function with 1 method)
xxxxxxxxxx
17
1
"""
    run_once(α, n)

Run a single experiment on a fresh `RandomWalk1D(N=21)` environment and return the
mean of the RMS errors recorded by a `RecordRMS` hook over 10 episodes.

# Arguments
- `α`: value passed to `Descent` as the optimiser of the tabular approximator.
- `n`: value passed as the `n` keyword of `TDLearner`.

The agent selects actions uniformly at random, independent of the value estimates.
"""
function run_once(α, n)
    env = RandomWalk1D(N=21)
    # Size the table and the action range from this very environment instead of
    # relying on notebook-level globals, so the function is self-contained.
    n_states = length(state_space(env))
    n_actions = length(action_space(env))

    agent = Agent(
        policy=VBasedPolicy(
            learner=TDLearner(
                approximator=TabularVApproximator(; n_state=n_states, opt=Descent(α)),
                method=:SRS,
                n=n,
            ),
            # Ignore the value estimates `V`: pick an action uniformly at random.
            mapping=(env, V) -> rand(1:n_actions),
        ),
        trajectory=VectorSARTTrajectory(),
    )

    hook = RecordRMS()
    run(agent, env, StopAfterEpisode(10; is_show_progress=false), hook)
    return mean(hook.rms)
end
xxxxxxxxxx
18
1
begin
    # Values of α to sweep; also used as the x-axis of every curve.
    A = 0.:0.05:1.0
    p = plot()
    # One curve per n = 1, 2, 4, …, 512.
    for n in [2^i for i in 0:9]
        avg_rms = Float64[]
        for α in A
            # Average over 100 independent runs; the comprehension yields a
            # concretely typed vector instead of an untyped `[]` accumulator.
            rms = [run_once(α, n) for _ in 1:100]
            push!(avg_rms, mean(rms))
        end
        plot!(p, A, avg_rms, label="n = $n")
    end

    ylims!(p, 0.25, 0.55)
    p
end