Data Preprocessing
To illustrate how data is preprocessed, we consider a simple toy dataset with three categorical features (`name`, `grade` and `sex`) and one continuous feature (`height`):
using MLJ # assumed setup: MLJ re-exports `categorical` and `schema`

X = (
    name=categorical(["Danesh", "Lee", "Mary", "John"]),
    grade=categorical(["A", "B", "A", "C"], ordered=true),
    sex=categorical(["male","female","male","male"]),
    height=[1.85, 1.67, 1.5, 1.67],
)
schema(X)
┌────────┬──────────────────┬──────────────────────────────────┐
│ names  │ scitypes         │ types                            │
├────────┼──────────────────┼──────────────────────────────────┤
│ name   │ Multiclass{4}    │ CategoricalValue{String, UInt32} │
│ grade  │ OrderedFactor{3} │ CategoricalValue{String, UInt32} │
│ sex    │ Multiclass{2}    │ CategoricalValue{String, UInt32} │
│ height │ Continuous       │ Float64                          │
└────────┴──────────────────┴──────────────────────────────────┘
Categorical features are expected to be one-hot or dummy encoded. To this end, we could use MLJ, for example:
hot = OneHotEncoder(drop_last=true)
mach = fit!(machine(hot, X))
W = transform(mach, X)
schema(W)
┌ Info: Training machine(OneHotEncoder(features = Symbol[], …), …).
└ @ MLJBase /Users/patrickaltmeyer/.julia/packages/MLJBase/9Nkjh/src/machines.jl:492
┌ Info: Spawning 3 sub-features to one-hot encode feature :name.
└ @ MLJModels /Users/patrickaltmeyer/.julia/packages/MLJModels/8Nrhi/src/builtins/Transformers.jl:878
┌ Info: Spawning 2 sub-features to one-hot encode feature :grade.
└ @ MLJModels /Users/patrickaltmeyer/.julia/packages/MLJModels/8Nrhi/src/builtins/Transformers.jl:878
┌ Info: Spawning 1 sub-features to one-hot encode feature :sex.
└ @ MLJModels /Users/patrickaltmeyer/.julia/packages/MLJModels/8Nrhi/src/builtins/Transformers.jl:878
┌──────────────┬────────────┬─────────┐
│ names        │ scitypes   │ types   │
├──────────────┼────────────┼─────────┤
│ name__Danesh │ Continuous │ Float64 │
│ name__John   │ Continuous │ Float64 │
│ name__Lee    │ Continuous │ Float64 │
│ grade__A     │ Continuous │ Float64 │
│ grade__B     │ Continuous │ Float64 │
│ sex__female  │ Continuous │ Float64 │
│ height       │ Continuous │ Float64 │
└──────────────┴────────────┴─────────┘
The matrix that will actually be perturbed during the counterfactual search looks as follows:
X = permutedims(MLJBase.matrix(W))
7×4 Matrix{Float64}:
1.0 0.0 0.0 0.0
0.0 0.0 0.0 1.0
0.0 1.0 0.0 0.0
1.0 0.0 1.0 0.0
0.0 1.0 0.0 0.0
0.0 1.0 0.0 0.0
1.85 1.67 1.5 1.67
The `CounterfactualData` constructor takes two optional arguments that can be used to specify the indices of categorical and continuous features. If nothing is supplied, all features are assumed to be continuous. For categorical features, the constructor expects an array of arrays of integers (`Vector{Vector{Int}}`) where each subarray includes the indices of all one-hot encoded rows related to a single categorical feature. In the example above, the `name` feature is one-hot encoded across rows 1, 2 and 3 of `X`.
features_categorical = [
    [1,2,3], # name
    [4,5],   # grade
    [6]      # sex
]
features_continuous = [7]
1-element Vector{Int64}:
7
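For completeness, the snippet below sketches how these index vectors could be handed to the constructor. It assumes `using CounterfactualExplanations` and a made-up target vector `y` (the toy data has no labels); the keyword argument names simply mirror the index vectors defined above, so consult the package reference for the exact signature.

```julia
using CounterfactualExplanations

# Hypothetical binary labels for the four samples, for illustration only.
y = [0, 1, 0, 1]

# Hand over the feature matrix together with the indices of the one-hot
# encoded rows (categorical) and the remaining rows (continuous).
counterfactual_data = CounterfactualData(
    X, y;
    features_categorical=features_categorical,
    features_continuous=features_continuous,
)
```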
We propose the following simple logic for reconstructing categorical encodings after perturbations:
- For one-hot encoded features with multiple classes, choose the class with the maximum value (breaking ties at random).
- For binary features, clip the perturbed value to fall into \([0,1]\) and round to the nearest of the two integers.
function reconstruct_cat_encoding(x)
    map(features_categorical) do cat_group_index
        if length(cat_group_index) > 1
            # One-hot group: set the maximum entry to 1 and all others to 0.
            x[cat_group_index] = Int.(x[cat_group_index] .== maximum(x[cat_group_index]))
            if sum(x[cat_group_index]) > 1
                # Break ties at random if several entries share the maximum.
                ties = findall(x[cat_group_index] .== 1)
                _x = zeros(length(x[cat_group_index]))
                winner = rand(ties, 1)[1]
                _x[winner] = 1
                x[cat_group_index] = _x
            end
        else
            # Binary feature: clamp to [0,1] and round to the nearest integer.
            x[cat_group_index] = [round(clamp(x[cat_group_index][1], 0, 1))]
        end
    end
    return x
end
reconstruct_cat_encoding (generic function with 1 method)
Perturbing a Single Element
x = X[:,1]
x[1] = 1.1
x
7-element Vector{Float64}:
1.1
0.0
0.0
1.0
0.0
0.0
1.85
reconstruct_cat_encoding(x)
7-element Vector{Float64}:
1.0
0.0
0.0
1.0
0.0
0.0
1.85
Perturbing Multiple Elements
x[2] = 1.1
x[3] = -1.2
x
7-element Vector{Float64}:
1.0
1.1
-1.2
1.0
0.0
0.0
1.85
reconstruct_cat_encoding(x)
7-element Vector{Float64}:
0.0
1.0
0.0
1.0
0.0
0.0
1.85
Breaking Ties
x[1] = 1.0
x
7-element Vector{Float64}:
1.0
1.0
0.0
1.0
0.0
0.0
1.85
reconstruct_cat_encoding(x)
7-element Vector{Float64}:
1.0
0.0
0.0
1.0
0.0
0.0
1.85
Binary
x[features_categorical[3]] = [0.75]
x
7-element Vector{Float64}:
1.0
0.0
0.0
1.0
0.0
0.75
1.85
reconstruct_cat_encoding(x)
7-element Vector{Float64}:
1.0
0.0
0.0
1.0
0.0
1.0
1.85