diff --git a/src/core/function/typed.js b/src/core/function/typed.js index 8de1b761dd..09c1cc4bde 100644 --- a/src/core/function/typed.js +++ b/src/core/function/typed.js @@ -152,15 +152,16 @@ export const createTyped = /* #__PURE__ */ factory('typed', dependencies, functi { name: 'FunctionNode', test: isFunctionNode }, { name: 'FunctionAssignmentNode', test: isFunctionAssignmentNode }, { name: 'IndexNode', test: isIndexNode }, - { name: 'Node', test: isNode }, { name: 'ObjectNode', test: isObjectNode }, { name: 'OperatorNode', test: isOperatorNode }, { name: 'ParenthesisNode', test: isParenthesisNode }, { name: 'RangeNode', test: isRangeNode }, { name: 'RelationalNode', test: isRelationalNode }, { name: 'SymbolNode', test: isSymbolNode }, + { name: 'Node', test: isNode }, { name: 'Map', test: isMap }, + { name: 'Set', test: entity => entity instanceof Set }, { name: 'Object', test: isObject } // order 'Object' last, it matches on other classes too ]) diff --git a/src/function/algebra/resolve.js b/src/function/algebra/resolve.js index 5e8f80f155..7a840865e3 100644 --- a/src/function/algebra/resolve.js +++ b/src/function/algebra/resolve.js @@ -1,5 +1,5 @@ import { createMap } from '../../utils/map.js' -import { isFunctionNode, isNode, isOperatorNode, isParenthesisNode, isSymbolNode } from '../../utils/is.js' +import { isNode } from '../../utils/is.js' import { factory } from '../../utils/factory.js' const name = 'resolve' @@ -46,65 +46,86 @@ export const createResolve = /* #__PURE__ */ factory(name, dependencies, ({ * If there is a cyclic dependency among the variables in `scope`, * resolution is impossible and a ReferenceError is thrown. */ - function _resolve (node, scope, within = new Set()) { // note `within`: - // `within` is not documented, since it is for internal cycle - // detection only - if (!scope) { - return node - } - if (isSymbolNode(node)) { - if (within.has(node.name)) { - const variables = Array.from(within).join(', ') - throw new ReferenceError( - `recursive loop of variable definitions among {${variables}}` - ) + return typed('resolve', { + // First, the specific implementations that handle different Node types: + // (Note these take a "within" argument for cycle detection that is not + // part of the documented operation, as it is used only for internal + // cycle detection.) + 'SymbolNode, Map | null | undefined, Set': typed.referToSelf(self => + (symbol, scope, within) => { + // The key case for resolve; most other nodes we just recurse. + if (!scope) return symbol + if (within.has(symbol.name)) { + const variables = Array.from(within).join(', ') + throw new ReferenceError( + `recursive loop of variable definitions among {${variables}}` + ) + } + const value = scope.get(symbol.name) + if (isNode(value)) { + const nextWithin = new Set(within) + nextWithin.add(symbol.name) + return self(value, scope, nextWithin) + } + if (typeof value === 'number') { + return parse(String(value)) // ?? is this just to get the currently + // defined behavior for number literals, i.e. maybe numbers are + // currently being coerced to BigNumber? + } + if (value !== undefined) { + return new ConstantNode(value) + } + return symbol } - const value = scope.get(node.name) - if (isNode(value)) { - const nextWithin = new Set(within) - nextWithin.add(node.name) - return _resolve(value, scope, nextWithin) - } else if (typeof value === 'number') { - return parse(String(value)) - } else if (value !== undefined) { - return new ConstantNode(value) - } else { - return node + ), + 'OperatorNode, Map | null | undefined, Set': typed.referToSelf(self => + (operator, scope, within) => { + const args = operator.args.map(arg => self(arg, scope, within)) + // Has its own implementation because we don't recurse on the op also + return new OperatorNode( + operator.op, operator.fn, args, operator.implicit) } - } else if (isOperatorNode(node)) { - const args = node.args.map(function (arg) { - return _resolve(arg, scope, within) - }) - return new OperatorNode(node.op, node.fn, args, node.implicit) - } else if (isParenthesisNode(node)) { - return new ParenthesisNode(_resolve(node.content, scope, within)) - } else if (isFunctionNode(node)) { - const args = node.args.map(function (arg) { - return _resolve(arg, scope, within) - }) - return new FunctionNode(node.name, args) - } - - // Otherwise just recursively resolve any children (might also work - // for some of the above special cases) - return node.map(child => _resolve(child, scope, within)) - } - - return typed('resolve', { - Node: _resolve, - 'Node, Map | null | undefined': _resolve, - 'Node, Object': (n, scope) => _resolve(n, createMap(scope)), - // For arrays and matrices, we map `self` rather than `_resolve` - // because resolve is fairly expensive anyway, and this way - // we get nice error messages if one entry in the array has wrong type. - 'Array | Matrix': typed.referToSelf(self => A => A.map(n => self(n))), - 'Array | Matrix, null | undefined': typed.referToSelf( - self => A => A.map(n => self(n))), - 'Array, Object': typed.referTo( - 'Array,Map', selfAM => (A, scope) => selfAM(A, createMap(scope))), - 'Matrix, Object': typed.referTo( - 'Matrix,Map', selfMM => (A, scope) => selfMM(A, createMap(scope))), - 'Array | Matrix, Map': typed.referToSelf( - self => (A, scope) => A.map(n => self(n, scope))) + ), + 'FunctionNode, Map | null | undefined, Set': typed.referToSelf(self => + (func, scope, within) => { + const args = func.args.map(arg => self(arg, scope, within)) + // The only reason this has a separate implementation of its own + // is that we don't resolve the func.name itself. But is that + // really right? If the tree being resolved was the parse of + // 'f(x,y)' and 'f' is defined in the scope, is it clear that we + // don't want to replace the function symbol, too? Anyhow, leaving + // the implementation as it was before the refactoring. + return new FunctionNode(func.name, args) + } + ), + 'Node, Map | null | undefined, Set': typed.referToSelf(self => + (node, scope, within) => { + // The generic case: just recurse + return node.map(child => self(child, scope, within)) + } + ), + // Second, generic forwarders to deal with optional arguments and different types: + Node: typed.referToSelf( + self => node => self(node, undefined, new Set()) + ), + 'Node, Map | null | undefined': typed.referToSelf( + self => (node, scope) => self(node, scope, new Set()) + ), + 'Node, Object': typed.referToSelf( + self => (node, objScope) => self(node, createMap(objScope), new Set()) + ), + // And finally, the array/matrix handlers: + 'Array | Matrix': typed.referToSelf( + self => A => A.map(n => self(n, undefined, new Set())) + ), + 'Array | Matrix, Map | null | undefined': typed.referToSelf( + self => (A, scope) => A.map(n => self(n, scope, new Set())) + ), + 'Array | Matrix, Object': typed.referToSelf( + self => (A, objScope) => A.map(n => self(n, createMap(objScope), new Set())) + ), + 'Array | Matrix, Map | null | undefined, Set': typed.referToSelf( + self => (A, scope, within) => A.map(n => self(n, scope, within)) + ) }) }) diff --git a/test/unit-tests/function/algebra/resolve.test.js b/test/unit-tests/function/algebra/resolve.test.js index c0e73b7504..5fdadf05a6 100644 --- a/test/unit-tests/function/algebra/resolve.test.js +++ b/test/unit-tests/function/algebra/resolve.test.js @@ -102,4 +102,53 @@ describe('resolve', function () { }), /ReferenceError.*\{x, y, z\}/) }) + + it('should allow resolving custom nodes in custom ways', function () { + const mymath = math.create() + const Node = mymath.Node + class IntervalNode extends Node { + // a node that represents any value in a closed interval + constructor (left, right) { + super() + this.left = left + this.right = right + } + + static name = 'IntervalNode' + get type () { return 'IntervalNode' } + get isIntervalNode () { return true } + clone () { + return new IntervalNode(this.left, this.right) + } + + _toString (options) { + return `[|${this.left}, ${this.right}|]` + } + + midpoint () { + return (this.left + this.right) / 2 + } + } + + mymath.typed.addTypes( + [{ + name: 'IntervalNode', + test: entity => entity && entity.isIntervalNode + }], + 'RangeNode') // Insert just before RangeNode in type order + + // IntervalNodes resolve to their midpoint: + const resolveInterval = mymath.typed({ + 'IntervalNode, Map|null|undefined, Set': + (node, _scope, _within) => new mymath.ConstantNode(node.midpoint()) + }) + // Merge with standard resolve: + mymath.import({ resolve: resolveInterval }) + + // And finally test: + const innerNode = new IntervalNode(1, 3) + const outerNode = new mymath.OperatorNode( + '+', 'add', [innerNode, new mymath.ConstantNode(4)]) + assert.strictEqual(mymath.resolve(outerNode).toString(), '2 + 4') + }) })