Skip to content

Commit

Permalink
refactor(dse): expose state coder parameters
Browse files Browse the repository at this point in the history
  • Loading branch information
kris7t committed Jun 20, 2024
1 parent e0beca7 commit 1140488
Show file tree
Hide file tree
Showing 5 changed files with 46 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
import java.util.Collection;
import java.util.Set;

// This class is used as a fluent builder, so it's not necessary to use the return value of all of its methods.
@SuppressWarnings("UnusedReturnValue")
public final class ModelGeneratorFactory {
@Inject
private Provider<ModelInitializer> initializerProvider;
Expand All @@ -38,6 +40,8 @@ public final class ModelGeneratorFactory {

private boolean partialInterpretationBasedNeighbourhoods;

private int stateCoderDepth = NeighbourhoodCalculator.DEFAULT_DEPTH;

public ModelGeneratorFactory cancellationToken(CancellationToken cancellationToken) {
this.cancellationToken = cancellationToken;
return this;
Expand All @@ -48,8 +52,15 @@ public ModelGeneratorFactory debugPartialInterpretations(boolean debugPartialInt
return this;
}

public void partialInterpretationBasedNeighbourhoods(boolean partialInterpretationBasedNeighbourhoods) {
public ModelGeneratorFactory partialInterpretationBasedNeighbourhoods(
boolean partialInterpretationBasedNeighbourhoods) {
this.partialInterpretationBasedNeighbourhoods = partialInterpretationBasedNeighbourhoods;
return this;
}

public ModelGeneratorFactory stateCoderDepth(int stateCoderDepth) {
this.stateCoderDepth = stateCoderDepth;
return this;
}

public ModelGenerator createGenerator(Problem problem) {
Expand All @@ -68,7 +79,7 @@ public ModelGenerator createGenerator(Problem problem) {
initializer.configureStoreBuilder(storeBuilder);
var store = storeBuilder.build();
return new ModelGenerator(initializer.getProblemTrace(), store, initializer.getModelSeed(),
solutionSerializerProvider);
solutionSerializerProvider);
}

private Collection<Concreteness> getRequiredInterpretations() {
Expand All @@ -78,7 +89,8 @@ private Collection<Concreteness> getRequiredInterpretations() {
}

private StateCodeCalculatorFactory getStateCoderCalculatorFactory() {
return partialInterpretationBasedNeighbourhoods ? PartialNeighbourhoodCalculator.FACTORY :
NeighbourhoodCalculator::new;
return partialInterpretationBasedNeighbourhoods ?
PartialNeighbourhoodCalculator.factory(Concreteness.PARTIAL, stateCoderDepth) :
NeighbourhoodCalculator.factory(stateCoderDepth);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,12 @@

public class PartialNeighbourhoodCalculator extends AbstractNeighbourhoodCalculator<PartialInterpretation<?, ?>> {
private final ModelQueryAdapter queryAdapter;
private final Concreteness concreteness;

public static final StateCodeCalculatorFactory FACTORY =
(model, ignoredInterpretations, individuals) -> new PartialNeighbourhoodCalculator(model, individuals);

protected PartialNeighbourhoodCalculator(Model model, IntSet individuals) {
super(model, individuals);
protected PartialNeighbourhoodCalculator(Model model, IntSet individuals, Concreteness concreteness, int depth) {
super(model, individuals, depth);
queryAdapter = model.getAdapter(ModelQueryAdapter.class);
this.concreteness = concreteness;
}

@Override
Expand All @@ -42,7 +41,7 @@ public StateCoderResult calculateCodes() {
var partialSymbols = adapter.getStoreAdapter().getPartialSymbols();
return partialSymbols.stream()
.<PartialInterpretation<?, ?>>map(partialSymbol ->
adapter.getPartialInterpretation(Concreteness.PARTIAL, (PartialSymbol<?, ?>) partialSymbol))
adapter.getPartialInterpretation(concreteness, (PartialSymbol<?, ?>) partialSymbol))
.toList();
}

Expand All @@ -60,4 +59,9 @@ protected Object getNullValue(PartialInterpretation<?, ?> interpretation) {
protected Cursor<Tuple, ?> getCursor(PartialInterpretation<?, ?> interpretation) {
return interpretation.getAll();
}

public static StateCodeCalculatorFactory factory(Concreteness concreteness, int depth) {
return (model, interpretations, individuals) -> new PartialNeighbourhoodCalculator(model, individuals,
concreteness, depth);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ public class StateCoderBuilderImpl extends AbstractModelAdapterBuilder<StateCode
implements StateCoderBuilder {
private final Set<AnySymbol> excluded = new HashSet<>();
private final MutableIntSet individuals = IntSets.mutable.empty();
private StateCodeCalculatorFactory calculator = NeighbourhoodCalculator::new;
private StateCodeCalculatorFactory calculator = NeighbourhoodCalculator.factory();
private StateEquivalenceChecker checker = new StateEquivalenceCheckerImpl();

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
public abstract class AbstractNeighbourhoodCalculator<T> implements StateCodeCalculator {
private final Model model;
private final IntSet individuals;
private final int depth;
private List<T> nullImpactValues;
private LinkedHashMap<T, long[]> impactValues;
private MutableIntLongMap individualHashValues;
Expand All @@ -28,9 +29,10 @@ public abstract class AbstractNeighbourhoodCalculator<T> implements StateCodeCal

protected static final long PRIME = 31;

protected AbstractNeighbourhoodCalculator(Model model, IntSet individuals) {
protected AbstractNeighbourhoodCalculator(Model model, IntSet individuals, int depth) {
this.model = model;
this.individuals = individuals;
this.depth = depth;
}

protected Model getModel() {
Expand Down Expand Up @@ -64,7 +66,7 @@ public StateCoderResult calculateCodes() {
nextObjectCode = tempObjectCode;
nextObjectCode.clear();
rounds++;
} while (rounds <= 7 && rounds <= previousObjectCode.getEffectiveSize());
} while (rounds <= depth && rounds <= previousObjectCode.getEffectiveSize());

long result = calculateLastSum(previousObjectCode);
return new StateCoderResult((int) result, previousObjectCode);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,19 @@
import tools.refinery.store.map.Cursor;
import tools.refinery.store.model.Interpretation;
import tools.refinery.store.model.Model;
import tools.refinery.store.statecoding.StateCodeCalculatorFactory;
import tools.refinery.store.tuple.Tuple;

import java.util.List;

public class NeighbourhoodCalculator extends AbstractNeighbourhoodCalculator<Interpretation<?>> {
public static final int DEFAULT_DEPTH = 7;

private final List<Interpretation<?>> interpretations;

public NeighbourhoodCalculator(Model model, List<? extends Interpretation<?>> interpretations,
IntSet individuals) {
super(model, individuals);
protected NeighbourhoodCalculator(Model model, List<? extends Interpretation<?>> interpretations,
IntSet individuals, int depth) {
super(model, individuals, depth);
this.interpretations = List.copyOf(interpretations);
}

Expand All @@ -41,4 +44,13 @@ protected Object getNullValue(Interpretation<?> interpretation) {
protected Cursor<Tuple, ?> getCursor(Interpretation<?> interpretation) {
return interpretation.getAll();
}

public static StateCodeCalculatorFactory factory(int depth) {
return (model, interpretations, individuals) -> new NeighbourhoodCalculator(model, interpretations,
individuals, depth);
}

public static StateCodeCalculatorFactory factory() {
return factory(DEFAULT_DEPTH);
}
}

0 comments on commit 1140488

Please sign in to comment.