diff --git a/subprojects/generator/src/main/java/tools/refinery/generator/ModelGeneratorFactory.java b/subprojects/generator/src/main/java/tools/refinery/generator/ModelGeneratorFactory.java index ec273cf4d..877db0bbe 100644 --- a/subprojects/generator/src/main/java/tools/refinery/generator/ModelGeneratorFactory.java +++ b/subprojects/generator/src/main/java/tools/refinery/generator/ModelGeneratorFactory.java @@ -25,6 +25,8 @@ import java.util.Collection; import java.util.Set; +// This class is used as a fluent builder, so it's not necessary to use the return value of all of its methods. +@SuppressWarnings("UnusedReturnValue") public final class ModelGeneratorFactory { @Inject private Provider initializerProvider; @@ -38,6 +40,8 @@ public final class ModelGeneratorFactory { private boolean partialInterpretationBasedNeighbourhoods; + private int stateCoderDepth = NeighbourhoodCalculator.DEFAULT_DEPTH; + public ModelGeneratorFactory cancellationToken(CancellationToken cancellationToken) { this.cancellationToken = cancellationToken; return this; @@ -48,8 +52,15 @@ public ModelGeneratorFactory debugPartialInterpretations(boolean debugPartialInt return this; } - public void partialInterpretationBasedNeighbourhoods(boolean partialInterpretationBasedNeighbourhoods) { + public ModelGeneratorFactory partialInterpretationBasedNeighbourhoods( + boolean partialInterpretationBasedNeighbourhoods) { this.partialInterpretationBasedNeighbourhoods = partialInterpretationBasedNeighbourhoods; + return this; + } + + public ModelGeneratorFactory stateCoderDepth(int stateCoderDepth) { + this.stateCoderDepth = stateCoderDepth; + return this; } public ModelGenerator createGenerator(Problem problem) { @@ -68,7 +79,7 @@ public ModelGenerator createGenerator(Problem problem) { initializer.configureStoreBuilder(storeBuilder); var store = storeBuilder.build(); return new ModelGenerator(initializer.getProblemTrace(), store, initializer.getModelSeed(), - solutionSerializerProvider); + solutionSerializerProvider); } private Collection getRequiredInterpretations() { @@ -78,7 +89,8 @@ private Collection getRequiredInterpretations() { } private StateCodeCalculatorFactory getStateCoderCalculatorFactory() { - return partialInterpretationBasedNeighbourhoods ? PartialNeighbourhoodCalculator.FACTORY : - NeighbourhoodCalculator::new; + return partialInterpretationBasedNeighbourhoods ? + PartialNeighbourhoodCalculator.factory(Concreteness.PARTIAL, stateCoderDepth) : + NeighbourhoodCalculator.factory(stateCoderDepth); } } diff --git a/subprojects/store-reasoning/src/main/java/tools/refinery/store/reasoning/interpretation/PartialNeighbourhoodCalculator.java b/subprojects/store-reasoning/src/main/java/tools/refinery/store/reasoning/interpretation/PartialNeighbourhoodCalculator.java index 859cf7c12..67e20dbc9 100644 --- a/subprojects/store-reasoning/src/main/java/tools/refinery/store/reasoning/interpretation/PartialNeighbourhoodCalculator.java +++ b/subprojects/store-reasoning/src/main/java/tools/refinery/store/reasoning/interpretation/PartialNeighbourhoodCalculator.java @@ -21,13 +21,12 @@ public class PartialNeighbourhoodCalculator extends AbstractNeighbourhoodCalculator> { private final ModelQueryAdapter queryAdapter; + private final Concreteness concreteness; - public static final StateCodeCalculatorFactory FACTORY = - (model, ignoredInterpretations, individuals) -> new PartialNeighbourhoodCalculator(model, individuals); - - protected PartialNeighbourhoodCalculator(Model model, IntSet individuals) { - super(model, individuals); + protected PartialNeighbourhoodCalculator(Model model, IntSet individuals, Concreteness concreteness, int depth) { + super(model, individuals, depth); queryAdapter = model.getAdapter(ModelQueryAdapter.class); + this.concreteness = concreteness; } @Override @@ -42,7 +41,7 @@ public StateCoderResult calculateCodes() { var partialSymbols = adapter.getStoreAdapter().getPartialSymbols(); return partialSymbols.stream() .>map(partialSymbol -> - adapter.getPartialInterpretation(Concreteness.PARTIAL, (PartialSymbol) partialSymbol)) + adapter.getPartialInterpretation(concreteness, (PartialSymbol) partialSymbol)) .toList(); } @@ -60,4 +59,9 @@ protected Object getNullValue(PartialInterpretation interpretation) { protected Cursor getCursor(PartialInterpretation interpretation) { return interpretation.getAll(); } + + public static StateCodeCalculatorFactory factory(Concreteness concreteness, int depth) { + return (model, interpretations, individuals) -> new PartialNeighbourhoodCalculator(model, individuals, + concreteness, depth); + } } diff --git a/subprojects/store/src/main/java/tools/refinery/store/statecoding/internal/StateCoderBuilderImpl.java b/subprojects/store/src/main/java/tools/refinery/store/statecoding/internal/StateCoderBuilderImpl.java index eed591e71..36d7d4c71 100644 --- a/subprojects/store/src/main/java/tools/refinery/store/statecoding/internal/StateCoderBuilderImpl.java +++ b/subprojects/store/src/main/java/tools/refinery/store/statecoding/internal/StateCoderBuilderImpl.java @@ -27,7 +27,7 @@ public class StateCoderBuilderImpl extends AbstractModelAdapterBuilder excluded = new HashSet<>(); private final MutableIntSet individuals = IntSets.mutable.empty(); - private StateCodeCalculatorFactory calculator = NeighbourhoodCalculator::new; + private StateCodeCalculatorFactory calculator = NeighbourhoodCalculator.factory(); private StateEquivalenceChecker checker = new StateEquivalenceCheckerImpl(); @Override diff --git a/subprojects/store/src/main/java/tools/refinery/store/statecoding/neighbourhood/AbstractNeighbourhoodCalculator.java b/subprojects/store/src/main/java/tools/refinery/store/statecoding/neighbourhood/AbstractNeighbourhoodCalculator.java index 5bfc47253..a4060e74c 100644 --- a/subprojects/store/src/main/java/tools/refinery/store/statecoding/neighbourhood/AbstractNeighbourhoodCalculator.java +++ b/subprojects/store/src/main/java/tools/refinery/store/statecoding/neighbourhood/AbstractNeighbourhoodCalculator.java @@ -20,6 +20,7 @@ public abstract class AbstractNeighbourhoodCalculator implements StateCodeCalculator { private final Model model; private final IntSet individuals; + private final int depth; private List nullImpactValues; private LinkedHashMap impactValues; private MutableIntLongMap individualHashValues; @@ -28,9 +29,10 @@ public abstract class AbstractNeighbourhoodCalculator implements StateCodeCal protected static final long PRIME = 31; - protected AbstractNeighbourhoodCalculator(Model model, IntSet individuals) { + protected AbstractNeighbourhoodCalculator(Model model, IntSet individuals, int depth) { this.model = model; this.individuals = individuals; + this.depth = depth; } protected Model getModel() { @@ -64,7 +66,7 @@ public StateCoderResult calculateCodes() { nextObjectCode = tempObjectCode; nextObjectCode.clear(); rounds++; - } while (rounds <= 7 && rounds <= previousObjectCode.getEffectiveSize()); + } while (rounds <= depth && rounds <= previousObjectCode.getEffectiveSize()); long result = calculateLastSum(previousObjectCode); return new StateCoderResult((int) result, previousObjectCode); diff --git a/subprojects/store/src/main/java/tools/refinery/store/statecoding/neighbourhood/NeighbourhoodCalculator.java b/subprojects/store/src/main/java/tools/refinery/store/statecoding/neighbourhood/NeighbourhoodCalculator.java index f6071c5ba..5859ccc26 100644 --- a/subprojects/store/src/main/java/tools/refinery/store/statecoding/neighbourhood/NeighbourhoodCalculator.java +++ b/subprojects/store/src/main/java/tools/refinery/store/statecoding/neighbourhood/NeighbourhoodCalculator.java @@ -9,16 +9,19 @@ import tools.refinery.store.map.Cursor; import tools.refinery.store.model.Interpretation; import tools.refinery.store.model.Model; +import tools.refinery.store.statecoding.StateCodeCalculatorFactory; import tools.refinery.store.tuple.Tuple; import java.util.List; public class NeighbourhoodCalculator extends AbstractNeighbourhoodCalculator> { + public static final int DEFAULT_DEPTH = 7; + private final List> interpretations; - public NeighbourhoodCalculator(Model model, List> interpretations, - IntSet individuals) { - super(model, individuals); + protected NeighbourhoodCalculator(Model model, List> interpretations, + IntSet individuals, int depth) { + super(model, individuals, depth); this.interpretations = List.copyOf(interpretations); } @@ -41,4 +44,13 @@ protected Object getNullValue(Interpretation interpretation) { protected Cursor getCursor(Interpretation interpretation) { return interpretation.getAll(); } + + public static StateCodeCalculatorFactory factory(int depth) { + return (model, interpretations, individuals) -> new NeighbourhoodCalculator(model, interpretations, + individuals, depth); + } + + public static StateCodeCalculatorFactory factory() { + return factory(DEFAULT_DEPTH); + } }