Chapter 8.6 Trajectory Sampling
The general function run(policy, env, stop_condition, hook)
is very flexible and powerful. However, we are not restricted to using it alone. In this notebook, we'll see how to use some of the components provided in ReinforcementLearning.jl
to carry out a few specific experiments.
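For reference, a typical call of that function looks roughly like this (an illustrative sketch built from standard components shipped with ReinforcementLearning.jl; it is not part of the experiments below):

# illustrative only: run a random policy on CartPole for 10 episodes,
# recording each episode's total reward in the hook
run(RandomPolicy(), CartPoleEnv(), StopAfterEpisode(10), TotalRewardPerEpisode())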
First, let's define the environment mentioned in Chapter 8.6:
begin
    using ReinforcementLearning
    using Flux
    using Statistics
    using Plots
end
begin
    mutable struct TestEnv <: AbstractEnv
        transitions::Array{Int, 3}       # transitions[i, a, s]: the i-th of b possible next states for (a, s)
        rewards::Array{Float64, 3}       # rewards[i, a, s]: reward of the i-th branch for (a, s)
        reward_table::Array{Float64, 2}  # reward_table[a, s]: reward received when the episode terminates
        terminate_prob::Float64
        # cache
        s_init::Int
        s::Int
        reward::Float64
        is_terminated::Bool
    end

    function TestEnv(;ns=1000, na=2, b=1, terminate_prob=0.1, init_state=1)
        transitions = rand(1:ns, b, na, ns)
        rewards = randn(b, na, ns)
        reward_table = randn(na, ns)
        TestEnv(
            transitions,
            rewards,
            reward_table,
            terminate_prob,
            init_state,
            init_state,
            0.,
            false
        )
    end

    function (env::TestEnv)(a::Int)
        t = rand() < env.terminate_prob     # terminate with probability terminate_prob
        bᵢ = rand(axes(env.transitions, 1)) # pick one of the b branches uniformly

        env.is_terminated = t
        if t
            env.reward = env.reward_table[a, env.s]
        else
            env.reward = env.rewards[bᵢ, a, env.s]
        end

        env.s = env.transitions[bᵢ, a, env.s]
    end

    RLBase.state_space(env::TestEnv) = Base.OneTo(size(env.rewards, 3))
    RLBase.action_space(env::TestEnv) = Base.OneTo(size(env.rewards, 2))

    function RLBase.reset!(env::TestEnv)
        env.s = env.s_init
        env.reward = 0.0
        env.is_terminated = false
    end

    RLBase.is_terminated(env::TestEnv) = env.is_terminated
    RLBase.state(env::TestEnv) = env.s
    RLBase.reward(env::TestEnv) = env.reward
end
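To get a feel for the interface we just implemented, a quick sanity check could look like the following (an illustrative sketch, not part of the original experiments; the small sizes and the name env_demo are arbitrary):

env_demo = TestEnv(ns=10, na=2, b=3)   # hypothetical small instance
reset!(env_demo)
env_demo(1)                            # take action 1
state(env_demo), reward(env_demo), is_terminated(env_demo)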
Note that this environment is not described very clearly in the book. Part of the information is inferred from the Lisp source code.
Info
Actually, the Lisp code is not perfect either; I spent a whole afternoon figuring out its logic. So good luck if you also want to understand it.
The definitions above are just like those of any other environment we've defined in previous chapters. Now we'll add an extra function to make it work for our planning purposes.
"""
Return all the possible next states and corresponding rewards.
"""
function successors(env::TestEnv, s, a)
    S = env.transitions[:, a, s]
    R = env.rewards[:, a, s]
    zip(R, S)
end
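Given an environment like the one above, successors simply slices the transition and reward tables for a single (s, a) pair. Reusing the hypothetical env_demo from the sketch above, a quick illustrative check (again, not part of the original notebook) would be:

collect(successors(env_demo, 1, 1))   # b = 3 (reward, next_state) pairs for s = 1, a = 1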
γ = 0.9

n_sweep = 10
"""
Here we are only interested in the performance of `Q`
with the environment starting at state `1`. Note that here
we're calculating the discounted reward.
"""
function eval_Q(Q, env; n_eval=100)
    R = 0.
    for _ in 1:n_eval
        reset!(env)
        i = 0
        while !is_terminated(env)
            a = Q(state(env)) |> argmax  # act greedily
            env(a)
            R += reward(env) * γ^i
            i += 1
        end
    end
    R / n_eval
end
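In other words, eval_Q returns a Monte Carlo estimate of the discounted return of the greedy policy from the start state, averaged over n_eval episodes (the notation below is mine):

$$\hat{v} \approx \frac{1}{n_{\text{eval}}} \sum_{k=1}^{n_{\text{eval}}} \sum_{i \ge 0} \gamma^{i} r^{(k)}_{i}$$

where r^{(k)}_i is the reward received at step i of the k-th evaluation episode.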
"""
Calculate the expected gain.
"""
function gain(Q, env, s, a)
    p = env.terminate_prob
    r = env.reward_table[a, s]
    p * r + (1 - p) * mean(r̄ + γ * maximum(Q(s′)) for (r̄, s′) in successors(env, s, a))
end
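Spelled out, gain is the full expected one-step backup at a state–action pair. With termination probability p, branching factor b, and writing r_term(s, a) for the entry of reward_table (the symbols here are mine, not from the book), it computes

$$G(s, a) = p \, r_{\text{term}}(s, a) + (1 - p) \, \frac{1}{b} \sum_{i=1}^{b} \left[ r_i + \gamma \max_{a'} Q(s'_i, a') \right]$$

where the b pairs (rᵢ, s′ᵢ) are exactly what successors returns.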
function sweep(;b=1, ns=1000)
    na = 2
    α = 1.0
    p = 0.1

    env = TestEnv(;ns=ns, na=na, b=b, terminate_prob=p)
    Q = TabularQApproximator(;n_state=ns, n_action=na, opt=Descent(α))

    i = 1
    vals = [eval_Q(Q, env)]
    # expected updates, sweeping uniformly over all (s, a) pairs
    for _ in 1:n_sweep
        for s in 1:ns
            for a in 1:na
                G = gain(Q, env, s, a)
                update!(Q, (s, a) => Q(s, a) - G)
                if i % 100 == 0
                    push!(vals, eval_Q(Q, env))
                end
                i += 1
            end
        end
    end
    vals
end
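A note on the update (this is my reading of how the tabular approximator behaves, assuming update! applies the Descent(α) optimizer to the supplied error Q(s, a) - G):

$$Q(s, a) \leftarrow Q(s, a) - \alpha \bigl( Q(s, a) - G \bigr) = G \qquad (\alpha = 1)$$

so with α = 1 each expected update simply overwrites Q(s, a) with the one-step backup value G.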
function on_policy(;b=1, ns=1000)
    na = 2
    α = 1.0
    p = 0.1

    env = TestEnv(;ns=ns, na=na, b=b, terminate_prob=p)
    Q = TabularQApproximator(;n_state=ns, n_action=na, opt=Descent(α))

    vals = [eval_Q(Q, env)]

    # expected updates at the (s, a) pairs visited while following an ϵ-greedy policy
    explorer = EpsilonGreedyExplorer(0.1)
    for i in 1:(n_sweep * ns * na)
        is_terminated(env) && reset!(env)
        s = state(env)
        a = Q(s) |> explorer
        env(a)
        G = gain(Q, env, s, a)
        update!(Q, (s, a) => Q(s, a) - G)
        if i % 100 == 0
            push!(vals, eval_Q(Q, env))
        end
    end

    vals
end
begin
    # average the learning curves over 200 randomly generated tasks
    fig_8_8 = plot(legend=:bottomright)
    for b in [1, 3, 10]
        plot!(fig_8_8, mean(sweep(;b=b) for _ in 1:200), label="uniform b=$b")
        plot!(fig_8_8, mean(on_policy(;b=b) for _ in 1:200), label="on policy b=$b")
    end
    fig_8_8
end
begin
    fig_8_8_2 = plot(legend=:bottomright)
    plot!(fig_8_8_2, mean(sweep(;ns=10_000) for _ in 1:200), label="uniform")
    plot!(fig_8_8_2, mean(on_policy(;ns=10_000) for _ in 1:200), label="on_policy")
    fig_8_8_2
end