```julia
begin
    using ReinforcementLearning
    using Flux
    using Statistics
    using Plots
end
```
## The Baird Count Environment
```julia
begin
    const DASH_SOLID = (:dashed, :solid)

    Base.@kwdef mutable struct BairdCounterEnv <: AbstractEnv
        current::Int = rand(1:7)
    end

    RLBase.state_space(env::BairdCounterEnv) = Base.OneTo(7)
    RLBase.action_space(env::BairdCounterEnv) = Base.OneTo(length(DASH_SOLID))

    # The dashed action moves to one of the six upper states uniformly at random;
    # the solid action always moves to the lower state 7.
    function (env::BairdCounterEnv)(a)
        if DASH_SOLID[a] == :dashed
            env.current = rand(1:6)
        else
            env.current = 7
        end
        nothing
    end

    RLBase.reward(env::BairdCounterEnv) = 0.0
    RLBase.is_terminated(env::BairdCounterEnv) = false  # continuing task
    RLBase.state(env::BairdCounterEnv) = env.current
    RLBase.reset!(env::BairdCounterEnv) = env.current = rand(1:6)
end
```
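A quick sanity check of the interface defined above (a minimal sketch; the visited states are random, so the exact values will vary):

```julia
# Minimal interaction sketch for BairdCounterEnv (illustrative only).
env = BairdCounterEnv()
reset!(env)                       # start in one of the upper states 1:6
env(2)                            # the :solid action always leads to state 7
@assert state(env) == 7
env(1)                            # the :dashed action leads to a random state in 1:6
@assert state(env) in 1:6
reward(env), is_terminated(env)   # always (0.0, false): the task never terminates
```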
## Off Policy

```julia
# Kept commented out for reference: the `OffPolicy` used below is already
# provided by the loaded ReinforcementLearning.jl packages.
# Base.@kwdef struct OffPolicy{P,B} <: AbstractPolicy
#     π_target::P
#     π_behavior::B
# end
```

```julia
# Acting in the environment always follows the behavior policy.
# (π::OffPolicy)(env) = π.π_behavior(env)
```
```julia
begin
    # Likewise kept commented out: a weighted SART trajectory and the `update!`
    # methods that push the importance-sampling ratio
    # prob(π_target, s, a) / prob(π_behavior, s, a) alongside each (state, action).

    # const VectorWSARTTrajectory = Trajectory{<:NamedTuple{(:weight, SART...)}}

    # function VectorWSARTTrajectory(; weight=Float64, state=Int, action=Int, reward=Float32, terminal=Bool)
    #     VectorTrajectory(; weight=weight, state=state, action=action, reward=reward, terminal=terminal)
    # end

    # function RLBase.update!(
    #     p::OffPolicy,
    #     t::VectorWSARTTrajectory,
    #     e::AbstractEnv,
    #     s::AbstractStage,
    # )
    #     update!(p.π_target, t, e, s)
    # end

    # function RLBase.update!(
    #     t::VectorWSARTTrajectory,
    #     p::OffPolicy,
    #     env::AbstractEnv,
    #     s::PreActStage,
    #     a,
    # )
    #     push!(t[:state], state(env))
    #     push!(t[:action], a)

    #     w = prob(p.π_target, s, a) / prob(p.π_behavior, s, a)
    #     push!(t[:weight], w)
    # end

    # function RLBase.update!(
    #     t::VectorWSARTTrajectory,
    #     p::OffPolicy{<:QBasedPolicy{<:TDLearner}},
    #     env::AbstractEnv,
    #     s::PreEpisodeStage,
    # )
    #     empty!(t)
    # end

    # function RLBase.update!(
    #     t::VectorWSARTTrajectory,
    #     p::OffPolicy{<:QBasedPolicy{<:TDLearner}},
    #     env::AbstractEnv,
    #     s::PostEpisodeStage,
    # )
    #     action = rand(action_space(env))

    #     push!(t[:state], state(env))
    #     push!(t[:action], action)
    #     push!(t[:weight], 1.0)
    # end
end
```
## Figure 11.2
```julia
world = BairdCounterEnv()
```

The rendered summary of `world`:

| Trait Type        | Value                                            |
|:----------------- | ------------------------------------------------:|
| NumAgentStyle     | ReinforcementLearningBase.SingleAgent()          |
| DynamicStyle      | ReinforcementLearningBase.Sequential()           |
| InformationStyle  | ReinforcementLearningBase.ImperfectInformation() |
| ChanceStyle       | ReinforcementLearningBase.Stochastic()           |
| RewardStyle       | ReinforcementLearningBase.StepReward()           |
| UtilityStyle      | ReinforcementLearningBase.GeneralSum()           |
| ActionStyle       | ReinforcementLearningBase.MinimalActionSet()     |
| StateStyle        | ReinforcementLearningBase.Observation{Any}()     |
| DefaultStateStyle | ReinforcementLearningBase.Observation{Any}()     |

State space: `Base.OneTo(7)`; action space: `Base.OneTo(2)`; the environment is never terminated; current state: `1`.
```julia
begin
    # A hook that snapshots the learner's weight vector after every step.
    Base.@kwdef struct RecordWeights <: AbstractHook
        weights::Vector{Vector{Float64}} = []
    end

    (h::RecordWeights)(::PostActStage, agent, env) = push!(
        h.weights,
        agent.policy.π_target.learner.approximator.weights |> deepcopy
    )
end
```
```julia
NW = 8
```

```julia
INIT_WEIGHT = ones(8)
```
```julia
INIT_WEIGHT[7] = 10   # w₇ starts at 10, all other weights at 1
```

```julia
STATE_MAPPING = zeros(NW, length(state_space(world)))
```
```julia
begin
    # Each upper state i (1:6) gets the feature vector 2e_i + e_8;
    # the lower state 7 gets e_7 + 2e_8.
    for i in 1:6
        STATE_MAPPING[i, i] = 2
        STATE_MAPPING[8, i] = 1
    end
    STATE_MAPPING[7, 7] = 1
    STATE_MAPPING[8, 7] = 2
end
```
```julia
STATE_MAPPING
```

```
8×7 Array{Float64,2}:
 2.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  2.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  2.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  2.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  2.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  2.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  1.0
 1.0  1.0  1.0  1.0  1.0  1.0  2.0
```
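Each column of `STATE_MAPPING` is the feature vector x(s) of one state, so the linear value estimate is v̂(s) = w⋅x(s): 2wᵢ + w₈ for the six upper states and w₇ + 2w₈ for state 7. A minimal check against the initial weights (illustrative only, not part of the learning setup):

```julia
# With INIT_WEIGHT == [1, 1, 1, 1, 1, 1, 10, 1]:
#   upper states 1:6 -> 2*1 + 1  = 3
#   lower state 7    -> 10 + 2*1 = 12
v_hat = [INIT_WEIGHT' * STATE_MAPPING[:, s] for s in 1:7]   # [3, 3, 3, 3, 3, 3, 12]
```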
```julia
# The behavior policy: the dashed action with probability 6/7, the solid action with probability 1/7.
π_b = x -> rand() < 6/7 ? 1 : 2
```
```julia
# The target policy always selects the solid action (index 2).
π_t = VBasedPolicy(
    learner=TDLearner(
        approximator=RLZoo.LinearApproximator(INIT_WEIGHT, Descent(0.01)),
        γ=0.99,
        n=0,
        method=:SRS
    ),
    mapping = (env, V) -> 2
)
```
```julia
prob_b = [6/7, 1/7]   # behavior-policy action probabilities
```

```julia
prob_t = [0., 1.]     # target-policy action probabilities
```
Admittedly, the next step is a little tricky: we define `prob` methods for the two policies by hand so that the importance-sampling weight can be computed.
```julia
RLBase.prob(::typeof(π_b), s, a::Integer) = prob_b[a]
```

```julia
RLBase.prob(::typeof(π_t), s, a::Integer) = prob_t[a]
```
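With these methods in place, the importance-sampling ratio ρ = π_target(a|s) / π_behavior(a|s) pushed into the weighted trajectory is 0 for the dashed action and 7 for the solid one. A quick check (illustrative only; both methods ignore the state argument):

```julia
ρ_dashed = prob(π_t, 1, 1) / prob(π_b, 1, 1)   # 0.0 / (6/7) = 0.0
ρ_solid  = prob(π_t, 1, 2) / prob(π_b, 1, 2)   # 1.0 / (1/7) = 7.0
```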
```julia
agent = Agent(
    policy=OffPolicy(
        π_target=π_t,
        π_behavior=π_b
    ),
    trajectory=VectorWSARTTrajectory(state=Any)
)
```

```
Agent
├─ policy => OffPolicy
│  ├─ π_target => VBasedPolicy
│  │  ├─ learner => TDLearner
│  │  │  ├─ approximator => LinearApproximator
│  │  │  │  ├─ weights => 8-element Array{Float64,1}
│  │  │  │  └─ optimizer => Descent (eta => 0.01)
│  │  │  ├─ γ => 0.99
│  │  │  ├─ method => SRS
│  │  │  └─ n => 0
│  │  └─ mapping => Main.var"#3#4"
│  └─ π_behavior => Main.var"#1#2"
└─ trajectory => Trajectory
   └─ traces => NamedTuple with :weight, :state, :action, :reward, :terminal (all empty)
```
```julia
# Wrap the environment so that the observed state is the feature vector of the
# current state rather than its index.
new_env = StateOverriddenEnv(
    BairdCounterEnv(),
    s -> STATE_MAPPING[:, s]
)
```

The rendered summary of `new_env` (`BairdCounterEnv |> StateOverriddenEnv`) repeats the trait table shown above for `world`; the only difference is that the current state is now a feature vector, e.g.

```
[0.0, 2.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0]
```
```julia
hook = RecordWeights()
```
```julia
run(agent, new_env, StopAfterStep(1000), hook)
```

The returned hook records one 8-element weight snapshot per step. In this run the weights grow steadily, reaching roughly `[116.5, 94.6, 107.2, 79.4, 101.4, 103.3, 6.8, 292.7]` after 1000 steps: semi-gradient off-policy TD(0) diverges on Baird's counterexample, which is what Figure 11.2 illustrates.
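To inspect the recorded snapshots directly (a minimal sketch; the exact numbers depend on the random run):

```julia
length(hook.weights)              # one snapshot per step, so 1000 here
reduce(hcat, hook.weights)'       # stacked into a 1000×8 matrix, one row per step
last(hook.weights)                # the weights after the final step
```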
```julia
begin
    # Plot the trajectory of each of the 8 weight components (Figure 11.2).
    p = plot(legend=:topleft)
    for i in 1:length(INIT_WEIGHT)
        plot!(p, [w[i] for w in hook.weights])
    end
    p
end
```