Skip to content

Commit

Permalink
More next item rules refactoring (#75)
Browse files Browse the repository at this point in the history
* More next item rules refactoring

 * Criteria => MultiCriterion
 * Get rid of most functors and convert into named methods
 * Introduce abstract types for pointwise item criteria

* Add tests for Stateful

* Use best_item in StatefulCatConfig

* Fix up benchmark
  • Loading branch information
frankier authored Nov 28, 2024
1 parent 67a2587 commit 16feb96
Show file tree
Hide file tree
Showing 20 changed files with 271 additions and 130 deletions.
14 changes: 10 additions & 4 deletions benchmark/benchmarks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,11 @@ function prepare_4pls(group)
tracked_responses = TrackedResponses(BareResponses(ResponseType(item_bank)),
item_bank,
NullAbilityTracker())
group["$(est_nick)_point_mepv_bare"] = @benchmarkable ($next_item_rule)(
$tracked_responses, $item_bank)
group["$(est_nick)_point_mepv_bare"] = @benchmarkable best_item(
$next_item_rule,
$tracked_responses,
$item_bank
)
bare_responses = BareResponses(
ResponseType(item_bank),
response_idxs,
Expand All @@ -60,8 +63,11 @@ function prepare_4pls(group)
bare_responses,
item_bank,
NullAbilityTracker())
group["$(est_nick)_point_mepv_10"] = @benchmarkable ($next_item_rule)(
$tracked_responses, $item_bank)
group["$(est_nick)_point_mepv_10"] = @benchmarkable best_item(
$next_item_rule,
$tracked_responses,
$item_bank
)
end
return group
end
Expand Down
4 changes: 2 additions & 2 deletions src/Sim.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ using FittedItemBanks: AbstractItemBank, ResponseType
using ..Responses
using ..CatConfig: CatLoopConfig, CatRules
using ..Aggregators: TrackedResponses, add_response!, Speculator, Aggregators
using ..NextItemRules: compute_criteria
using ..NextItemRules: compute_criteria, best_item

export run_cat, prompt_response, auto_responder

Expand Down Expand Up @@ -56,7 +56,7 @@ function run_cat(cat_config::CatLoopConfig{RulesT},
"Best items"
end criteria
try
next_index = next_item(responses, item_bank)
next_index = best_item(next_item, responses, item_bank)
catch exc
if isa(exc, NextItemError)
@warn "Terminating early due to error getting next item" err=sprint(
Expand Down
4 changes: 2 additions & 2 deletions src/Stateful.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ using FittedItemBanks: AbstractItemBank, ResponseType
using ..Aggregators: TrackedResponses, Aggregators
using ..CatConfig: CatLoopConfig, CatRules
using ..Responses: BareResponses, Response
using ..NextItemRules: compute_criteria
using ..NextItemRules: compute_criteria, best_item

## StatefulCat interface
abstract type StatefulCat end
Expand Down Expand Up @@ -73,7 +73,7 @@ function StatefulCatConfig(rules, item_bank)
end

function next_item(config::StatefulCatConfig)
return config.rules.next_item(config.tracked_responses, config.item_bank)
return best_item(config.rules.next_item, config.tracked_responses, config.item_bank)
end

function ranked_items(config::StatefulCatConfig)
Expand Down
2 changes: 1 addition & 1 deletion src/decision_tree/DecisionTree.jl
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ function generate_dt_cat(config::DecisionTreeGenerationConfig, item_bank)
while true
track!(responses, config.ability_tracker)
ability = config.ability_estimator(responses)
next_item = config.next_item(responses, item_bank)
next_item = best_item(config.next_item, responses, item_bank)

insert!(decision_tree_result, responses.responses, ability, next_item)
if state_tree.cur_depth == state_tree.max_depth
Expand Down
7 changes: 4 additions & 3 deletions src/next_item_rules/NextItemRules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,17 +39,18 @@ export RandomNextItemRule
export ExhaustiveSearch
export catr_next_item_aliases
export preallocate
export compute_criteria
export compute_criteria, compute_criterion, compute_multi_criterion,
compute_pointwise_criterion
export best_item
export PointResponseExpectation, DistributionResponseExpectation
export MatrixScalarizer, DeterminantScalarizer, TraceScalarizer
export AbilityCovarianceStateCriteria, StateCriteria, ItemCriteria
export AbilityCovarianceStateMultiCriterion, StateMultiCriterion, ItemMultiCriterion
export InformationMatrixCriteria
export ScalarizedStateCriteron, ScalarizedItemCriteron

# Prelude
include("./prelude/abstract.jl")
include("./prelude/next_item_rule.jl")
include("./prelude/strategy.jl")
include("./prelude/criteria.jl")
include("./prelude/preallocate.jl")

Expand Down
8 changes: 5 additions & 3 deletions src/next_item_rules/combinators/expectation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -96,14 +96,16 @@ function init_thread(::ExpectationBasedItemCriterion, responses::TrackedResponse
end

function _generic_criterion(criterion::StateCriterion, tracked_responses, item_idx)
criterion(tracked_responses)
compute_criterion(criterion, tracked_responses)
end
# TODO: Support init_thread for wrapped ItemCriterion
function _generic_criterion(criterion::ItemCriterion, tracked_responses, item_idx)
criterion(tracked_responses, item_idx)
compute_criterion(criterion, tracked_responses, item_idx)
end

function (item_criterion::ExpectationBasedItemCriterion)(speculator::Speculator,
function compute_criterion(
item_criterion::ExpectationBasedItemCriterion,
speculator::Speculator,
tracked_responses::TrackedResponses,
item_idx)
exp_resp = Aggregators.response_expectation(item_criterion.response_expectation,
Expand Down
45 changes: 23 additions & 22 deletions src/next_item_rules/combinators/scalarizers.jl
Original file line number Diff line number Diff line change
@@ -1,57 +1,58 @@
struct DeterminantScalarizer <: MatrixScalarizer end
(::DeterminantScalarizer)(mat) = det(mat)
scalarize(::DeterminantScalarizer, mat) = det(mat)

struct TraceScalarizer <: MatrixScalarizer end
(::TraceScalarizer)(mat) = tr(mat)
scalarize(::TraceScalarizer, mat) = tr(mat)

struct ScalarizedItemCriteron{
ItemCriteriaT <: ItemCriteria,
ItemMultiCriterionT <: ItemMultiCriterion,
MatrixScalarizerT <: MatrixScalarizer
} <: ItemCriterion
criteria::ItemCriteriaT
criteria::ItemMultiCriterionT
scalarizer::MatrixScalarizerT
end

function (ssc::ScalarizedItemCriteron)(tracked_responses, item_idx)
res = ssc.criteria(
init_thread(ssc.criteria, tracked_responses), tracked_responses, item_idx) |>
ssc.scalarizer
if !should_minimize(ssc.criteria)
res = -res
end
res
end

struct ScalarizedStateCriteron{
StateCriteriaT <: StateCriteria,
StateMultiCriterionT <: StateMultiCriterion,
MatrixScalarizerT <: MatrixScalarizer
} <: StateCriterion
criteria::StateCriteriaT
criteria::StateMultiCriterionT
scalarizer::MatrixScalarizerT
end

function (ssc::ScalarizedStateCriteron)(tracked_responses)
res = ssc.criteria(tracked_responses) |> ssc.scalarizer
function compute_criterion(ssc::Union{ScalarizedItemCriteron, ScalarizedStateCriteron},
tracked_responses::TrackedResponses, item_idx...)
res = scalarize(
ssc.scalarizer,
compute_multi_criterion(
ssc.criteria,
init_thread(ssc.criteria, tracked_responses),
tracked_responses,
item_idx...
)
)
if !should_minimize(ssc.criteria)
res = -res
end
res
end

struct WeightedStateCriteria{InnerT <: StateCriteria} <: StateCriteria
struct WeightedStateMultiCriterion{InnerT <: StateMultiCriterion} <: StateMultiCriterion
weights::Vector{Float64}
criteria::InnerT
end

function (wsc::WeightedStateCriteria)(tracked_responses, item_idx)
function compute_multi_criterion(
wsc::WeightedStateMultiCriterion, tracked_responses::TrackedResponses, item_idx)
wsc.weights' * wsc.criteria(tracked_responses, item_idx) * wsc.weights
end

struct WeightedItemCriteria{InnerT <: ItemCriteria} <: ItemCriteria
struct WeightedItemMultiCriterion{InnerT <: ItemMultiCriterion} <: ItemMultiCriterion
weights::Vector{Float64}
criteria::InnerT
end

function (wsc::WeightedItemCriteria)(tracked_responses, item_idx)
function compute_multi_criterion(
wsc::WeightedItemMultiCriterion, tracked_responses::TrackedResponses, item_idx)
wsc.weights' * wsc.criteria(tracked_responses, item_idx) * wsc.weights
end
9 changes: 6 additions & 3 deletions src/next_item_rules/criteria/item/information.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,17 @@ function InformationItemCriterion(ability_estimator)
InformationItemCriterion(ability_estimator, expected_item_information)
end

function (item_criterion::InformationItemCriterion)(tracked_responses::TrackedResponses,
function compute_criterion(
item_criterion::InformationItemCriterion, tracked_responses::TrackedResponses,
item_idx)
ability = maybe_tracked_ability_estimate(tracked_responses,
item_criterion.ability_estimator)
ir = ItemResponse(tracked_responses.item_bank, item_idx)
return -item_criterion.expected_item_information(ir, ability)
end

struct InformationMatrixCriteria{AbilityEstimatorT <: AbilityEstimator, F} <: ItemCriteria
struct InformationMatrixCriteria{AbilityEstimatorT <: AbilityEstimator, F} <:
ItemMultiCriterion
ability_estimator::AbilityEstimatorT
expected_item_information::F
end
Expand All @@ -35,7 +37,8 @@ function init_thread(item_criterion::InformationMatrixCriteria,
responses_information(responses.item_bank, responses.responses, ability)
end

function (item_criterion::InformationMatrixCriteria)(acc_info::Matrix{Float64},
function compute_multi_criterion(
item_criterion::InformationMatrixCriteria, acc_info::Matrix{Float64},
tracked_responses::TrackedResponses,
item_idx)
# TODO: Add in information from the prior
Expand Down
3 changes: 2 additions & 1 deletion src/next_item_rules/criteria/item/urry.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@ function raw_difficulty(item_bank, item_idx)
item_params(item_bank, item_idx).difficulty
end

function (item_criterion::UrryItemCriterion)(tracked_responses::TrackedResponses, item_idx)
function compute_criterion(
item_criterion::UrryItemCriterion, tracked_responses::TrackedResponses, item_idx)
ability = maybe_tracked_ability_estimate(tracked_responses,
item_criterion.ability_estimator)
diff = raw_difficulty(tracked_responses.item_bank, item_idx)
Expand Down
28 changes: 17 additions & 11 deletions src/next_item_rules/criteria/state/ability_variance.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,18 +36,20 @@ function AbilityVarianceStateCriterion(bits...)
return AbilityVarianceStateCriterion(dist_est, integrator, skip_zero)
end

function (criterion::AbilityVarianceStateCriterion)(tracked_responses::TrackedResponses)::Float64
function compute_criterion(criterion::AbilityVarianceStateCriterion,
tracked_responses::TrackedResponses)::Float64
# XXX: Not sure if the estimator should come from somewhere else here
denom = normdenom(criterion.integrator,
criterion.dist_est,
tracked_responses)
if denom == 0.0 && criterion.skip_zero
return Inf
end
criterion(DomainType(tracked_responses.item_bank), tracked_responses, denom)
compute_criterion(
criterion, DomainType(tracked_responses.item_bank), tracked_responses, denom)
end

function (criterion::AbilityVarianceStateCriterion)(
function compute_criterion(criterion::AbilityVarianceStateCriterion,
::Union{OneDimContinuousDomain, DiscreteDomain},
tracked_responses::TrackedResponses,
denom)::Float64
Expand All @@ -59,9 +61,12 @@ function (criterion::AbilityVarianceStateCriterion)(
)
end

function (criterion::AbilityVarianceStateCriterion)(::Vector,
function compute_criterion(
criterion::AbilityVarianceStateCriterion,
::Vector,
tracked_responses::TrackedResponses,
denom)::Float64
denom
)::Float64
# XXX: Not quite sure about this --- is it useful, the MIRT rules cover this case
mean = expectation(IntegralCoeffs.id,
ndims(tracked_responses.item_bank),
Expand All @@ -77,25 +82,26 @@ function (criterion::AbilityVarianceStateCriterion)(::Vector,
denom)
end

struct AbilityCovarianceStateCriteria{
struct AbilityCovarianceStateMultiCriterion{
DistEstT <: DistributionAbilityEstimator,
IntegratorT <: AbilityIntegrator
} <: StateCriteria
} <: StateMultiCriterion
dist_est::DistEstT
integrator::IntegratorT
skip_zero::Bool
end

function AbilityCovarianceStateCriteria(bits...)
function AbilityCovarianceStateMultiCriterion(bits...)
skip_zero = false
@requiresome (dist_est, integrator) = _get_dist_est_and_integrator(bits...)
return AbilityCovarianceStateCriteria(dist_est, integrator, skip_zero)
return AbilityCovarianceStateMultiCriterion(dist_est, integrator, skip_zero)
end

# XXX: Should be at type level
should_minimize(::AbilityCovarianceStateCriteria) = true
should_minimize(::AbilityCovarianceStateMultiCriterion) = true

function (criteria::AbilityCovarianceStateCriteria)(
function compute_multi_criterion(
criteria::AbilityCovarianceStateMultiCriterion,
tracked_responses::TrackedResponses,
denom = normdenom(criteria.integrator,
criteria.dist_est,
Expand Down
25 changes: 21 additions & 4 deletions src/next_item_rules/prelude/abstract.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,22 +4,32 @@ $(TYPEDEF)
Abstract base type for all item selection rules. All descendants of this type
are expected to implement the interface
`(rule::NextItemRule)(responses::TrackedResponses, items::AbstractItemBank)::Int`
`(::NextItemRule)(responses::TrackedResponses, items::AbstractItemBank)::Int`.
In practice, all adaptive rules in this package use `ItemStrategyNextItemRule`.
$(FUNCTIONNAME)(bits...; ability_estimator=nothing, parallel=true)
Implicit constructor for $(FUNCTIONNAME). Uses any given `NextItemRule` or
delegates to `ItemStrategyNextItemRule`.
delegates to `ItemStrategyNextItemRule` the default instance.
"""
abstract type NextItemRule <: CatConfigBase end

"""
$(TYPEDEF)
Abstract type for next item strategies, tightly coupled with `ItemStrategyNextItemRule`.
All descendants of this type are expected to implement the interface
`(rule::ItemStrategyNextItemRule{::NextItemStrategy, ::ItemCriterion})(responses::TrackedResponses,
items) where {ItemCriterionT <: }
`(strategy::NextItemStrategy)(; parallel=true)::NextItemStrategy`
"""
abstract type NextItemStrategy <: CatConfigBase end

"""
$(TYPEDEF)
Abstract type for next item criteria
"""
abstract type ItemCriterion <: CatConfigBase end

Expand All @@ -28,6 +38,13 @@ $(TYPEDEF)
"""
abstract type StateCriterion <: CatConfigBase end

"""
$(TYPEDEF)
"""
abstract type PointwiseItemCriterion <: CatConfigBase end

abstract type PurePointwiseItemCriterion <: PointwiseItemCriterion end

abstract type MatrixScalarizer end
abstract type StateCriteria end
abstract type ItemCriteria end
abstract type StateMultiCriterion end
abstract type ItemMultiCriterion end
Loading

0 comments on commit 16feb96

Please sign in to comment.