Skip to content

Commit

Permalink
Add 4-partite exact oracle
Browse files Browse the repository at this point in the history
  • Loading branch information
sebastiendesignolle committed Dec 7, 2024
1 parent c6f5e2f commit 4438a63
Showing 1 changed file with 65 additions and 1 deletion.
66 changes: 65 additions & 1 deletion src/fw_methods.jl
Original file line number Diff line number Diff line change
Expand Up @@ -438,7 +438,7 @@ function FrankWolfe.compute_extreme_point(
empty!(setm)
end
if verbose && sc scm
println(rpad(string([λa2]), 2 + ndigits(lmo.o[1]^lmo.m[1])), " ", string(-scm))
println(rpad(string([λa1]), 2 + ndigits(lmo.o[1]^lmo.m[1])), " ", string(-scm))
end
if count && sc scm
push!(setm, collect(BellProbabilitiesDS(ax, lmo)))
Expand Down Expand Up @@ -511,6 +511,70 @@ function FrankWolfe.compute_extreme_point(
return dsm
end

function FrankWolfe.compute_extreme_point(
lmo::BellProbabilitiesLMO{T, 8, 1},
A::Array{T, 8};
verbose = false,
count = false,
sym = false,
kwargs...,
) where {T <: Number}
ax = [ones(Int, lmo.m[n]) for n in 1:4]
sc = zero(T)
axm = [zeros(Int, lmo.m[n]) for n in 1:4]
scm = typemax(T)
# set containing all optimal strategies when count=true
setm = Set{Array{T, 4}}()
for λa4 in 0:(lmo.o[4]^lmo.m[4] - 1)
digits!(ax[4], λa4; base = lmo.o[4])
ax[4] .+= 1
for λa3 in (sym ? λa4 : 0):(lmo.o[3]^lmo.m[3] - 1)
digits!(ax[3], λa3; base = lmo.o[3])
ax[3] .+= 1
for λa2 in (sym ? λa3 : 0):(lmo.o[2]^lmo.m[2] - 1)
digits!(ax[2], λa2; base = lmo.o[2])
ax[2] .+= 1
for x1 in 1:length(ax[1])
for a1 in 1:lmo.o[1]
s = zero(T)
for x2 in 1:length(ax[2]), x3 in 1:length(ax[3]), x4 in 1:length(ax[4])
s += A[a1, ax[2][x2], ax[3][x3], ax[4][x4], x1, x2, x3, x4]
end
lmo.tmp[1][x1, a1] = s
end
end
for x1 in 1:length(ax[1])
ax[1][x1] = argmin(lmo.tmp[1][x1, :])[1]
end
sc = zero(T)
for x1 in 1:length(ax[1])
sc += lmo.tmp[1][x1, ax[1][x1]]
end
if sc < scm
scm = sc
for n in 1:4
axm[n] .= ax[n]
end
empty!(setm)
end
if verbose && sc scm
println(rpad(string([λa4, λa3, λa2]), 6 + ndigits(lmo.o[4]^lmo.m[4]) + ndigits(lmo.o[3]^lmo.m[3]) + ndigits(lmo.o[2]^lmo.m[2])), " ", string(-scm))
end
if count && sc scm
push!(setm, collect(BellProbabilitiesDS(ax, lmo)))
end
end
end
end
if count
println(length(setm))
end
dsm = BellProbabilitiesDS(axm, lmo)
lmo.cnt += 1
return dsm
end


##############
# ACTIVE SET #
##############
Expand Down

0 comments on commit 4438a63

Please sign in to comment.