Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support context propagation without breaking changes #943

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 35 additions & 0 deletions cel/cel_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1043,6 +1043,41 @@ func BenchmarkContextEval(b *testing.B) {
}
}

func TestContextEvalPropagation(t *testing.T) {
env, err := NewEnv(Function("test",
Overload("test_int", []*Type{}, IntType,
FunctionBindingContext(func(ctx context.Context, _ ...ref.Val) ref.Val {
md := ctx.Value("metadata")
if md == nil {
return types.NewErr("cannot find metadata value")
}
return types.Int(md.(int))
}),
),
))
if err != nil {
t.Fatalf("NewEnv() failed: %v", err)
}
ast, iss := env.Compile("test()")
if iss.Err() != nil {
t.Fatalf("env.Compile(expr) failed: %v", iss.Err())
}
prg, err := env.Program(ast)
if err != nil {
t.Fatalf("env.Program() failed: %v", err)
}

expected := 10
ctx := context.WithValue(context.Background(), "metadata", expected)
out, _, err := prg.ContextEval(ctx, map[string]interface{}{})
if err != nil {
t.Fatalf("prg.ContextEval() failed: %v", err)
}
if out != types.Int(expected) {
t.Errorf("prg.ContextEval() got %v, but wanted %d", out, expected)
}
}

func TestEvalRecover(t *testing.T) {
e := testEnv(t,
Function("panic",
Expand Down
18 changes: 18 additions & 0 deletions cel/decls.go
Original file line number Diff line number Diff line change
Expand Up @@ -275,18 +275,36 @@ func UnaryBinding(binding functions.UnaryOp) OverloadOpt {
return decls.UnaryBinding(binding)
}

// UnaryBindingContext provides the implementation of a unary overload. The provided function is protected by a runtime
// type-guard which ensures runtime type agreement between the overload signature and runtime argument types.
func UnaryBindingContext(binding functions.UnaryOpContext) OverloadOpt {
return decls.UnaryBindingContext(binding)
}

// BinaryBinding provides the implementation of a binary overload. The provided function is protected by a runtime
// type-guard which ensures runtime type agreement between the overload signature and runtime argument types.
func BinaryBinding(binding functions.BinaryOp) OverloadOpt {
return decls.BinaryBinding(binding)
}

// BinaryBindingContext provides the implementation of a binary overload. The provided function is protected by a runtime
// type-guard which ensures runtime type agreement between the overload signature and runtime argument types.
func BinaryBindingContext(binding functions.BinaryOpContext) OverloadOpt {
return decls.BinaryBindingContext(binding)
}

// FunctionBinding provides the implementation of a variadic overload. The provided function is protected by a runtime
// type-guard which ensures runtime type agreement between the overload signature and runtime argument types.
func FunctionBinding(binding functions.FunctionOp) OverloadOpt {
return decls.FunctionBinding(binding)
}

// FunctionBindingContext provides the implementation of a variadic overload. The provided function is protected by a runtime
// type-guard which ensures runtime type agreement between the overload signature and runtime argument types.
func FunctionBindingContext(binding functions.FunctionOpContext) OverloadOpt {
return decls.FunctionBindingContext(binding)
}

// OverloadIsNonStrict enables the function to be called with error and unknown argument values.
//
// Note: do not use this option unless absoluately necessary as it should be an uncommon feature.
Expand Down
8 changes: 7 additions & 1 deletion cel/env.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
package cel

import (
"context"
"errors"
"sync"

Expand Down Expand Up @@ -452,14 +453,19 @@ func (e *Env) ParseSource(src Source) (*Ast, *Issues) {

// Program generates an evaluable instance of the Ast within the environment (Env).
func (e *Env) Program(ast *Ast, opts ...ProgramOption) (Program, error) {
return e.ProgramContext(context.Background(), ast, opts...)
}

// ProgramContext generates an evaluable instance of the Ast within the environment (Env).
func (e *Env) ProgramContext(ctx context.Context, ast *Ast, opts ...ProgramOption) (Program, error) {
optSet := e.progOpts
if len(opts) != 0 {
mergedOpts := []ProgramOption{}
mergedOpts = append(mergedOpts, e.progOpts...)
mergedOpts = append(mergedOpts, opts...)
optSet = mergedOpts
}
return newProgram(e, ast, optSet)
return newProgram(ctx, e, ast, optSet)
}

// CELTypeAdapter returns the `types.Adapter` configured for the environment.
Expand Down
2 changes: 1 addition & 1 deletion cel/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -376,7 +376,7 @@ type ProgramOption func(p *prog) (*prog, error)
// InterpretableDecorators can be used to inspect, alter, or replace the Program plan.
func CustomDecorator(dec interpreter.InterpretableDecorator) ProgramOption {
return func(p *prog) (*prog, error) {
p.decorators = append(p.decorators, dec)
p.decorators = append(p.decorators, interpreter.ToInterpretableDecoratorContext(dec))
return p, nil
}
}
Expand Down
61 changes: 35 additions & 26 deletions cel/program.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,16 +118,16 @@ type prog struct {
evalOpts EvalOption
defaultVars interpreter.Activation
dispatcher interpreter.Dispatcher
interpreter interpreter.Interpreter
interpreter interpreter.InterpreterContext
interruptCheckFrequency uint

// Intermediate state used to configure the InterpretableDecorator set provided
// to the initInterpretable call.
decorators []interpreter.InterpretableDecorator
decorators []interpreter.InterpretableDecoratorContext
regexOptimizations []*interpreter.RegexOptimization

// Interpretable configured from an Ast and aggregate decorator set based on program options.
interpretable interpreter.Interpretable
interpretable interpreter.InterpretableContext
callCostEstimator interpreter.ActualCostEstimator
costOptions []interpreter.CostTrackerOption
costLimit *uint64
Expand All @@ -151,15 +151,15 @@ func (p *prog) clone() *prog {
// ProgramOption values.
//
// If the program cannot be configured the prog will be nil, with a non-nil error response.
func newProgram(e *Env, a *Ast, opts []ProgramOption) (Program, error) {
func newProgram(ctx context.Context, e *Env, a *Ast, opts []ProgramOption) (Program, error) {
// Build the dispatcher, interpreter, and default program value.
disp := interpreter.NewDispatcher()
disp := interpreter.NewDispatcherContext()

// Ensure the default attribute factory is set after the adapter and provider are
// configured.
p := &prog{
Env: e,
decorators: []interpreter.InterpretableDecorator{},
decorators: []interpreter.InterpretableDecoratorContext{},
dispatcher: disp,
costOptions: []interpreter.CostTrackerOption{},
}
Expand All @@ -175,45 +175,46 @@ func newProgram(e *Env, a *Ast, opts []ProgramOption) (Program, error) {

// Add the function bindings created via Function() options.
for _, fn := range e.functions {
bindings, err := fn.Bindings()
bindings, err := fn.BindingsContext()
if err != nil {
return nil, err
}
err = disp.Add(bindings...)
err = disp.AddContext(bindings...)
if err != nil {
return nil, err
}
}

// Set the attribute factory after the options have been set.
var attrFactory interpreter.AttributeFactory
var attrFactory interpreter.AttributeFactoryContext
attrFactorOpts := []interpreter.AttrFactoryOption{
interpreter.EnableErrorOnBadPresenceTest(p.HasFeature(featureEnableErrorOnBadPresenceTest)),
}
if p.evalOpts&OptPartialEval == OptPartialEval {
attrFactory = interpreter.NewPartialAttributeFactory(e.Container, e.adapter, e.provider, attrFactorOpts...)
attrFactory = interpreter.NewPartialAttributeFactoryContext(e.Container, e.adapter, e.provider, attrFactorOpts...)
} else {
attrFactory = interpreter.NewAttributeFactory(e.Container, e.adapter, e.provider, attrFactorOpts...)
attrFactory = interpreter.NewAttributeFactoryContext(e.Container, e.adapter, e.provider, attrFactorOpts...)
}
interp := interpreter.NewInterpreter(disp, e.Container, e.provider, e.adapter, attrFactory)

interp := interpreter.NewInterpreterContext(disp, e.Container, e.provider, e.adapter, attrFactory)
p.interpreter = interp

// Translate the EvalOption flags into InterpretableDecorator instances.
decorators := make([]interpreter.InterpretableDecorator, len(p.decorators))
// Translate the EvalOption flags into InterpretableDecoratorContext instances.
decorators := make([]interpreter.InterpretableDecoratorContext, len(p.decorators))
copy(decorators, p.decorators)

// Enable interrupt checking if there's a non-zero check frequency
if p.interruptCheckFrequency > 0 {
decorators = append(decorators, interpreter.InterruptableEval())
decorators = append(decorators, interpreter.InterruptableEvalContext())
}
// Enable constant folding first.
if p.evalOpts&OptOptimize == OptOptimize {
decorators = append(decorators, interpreter.Optimize())
decorators = append(decorators, interpreter.OptimizeContext())
p.regexOptimizations = append(p.regexOptimizations, interpreter.MatchesRegexOptimization)
}
// Enable regex compilation of constants immediately after folding constants.
if len(p.regexOptimizations) > 0 {
decorators = append(decorators, interpreter.CompileRegexConstants(p.regexOptimizations...))
decorators = append(decorators, interpreter.CompileRegexConstantsContext(p.regexOptimizations...))
}

// Enable exhaustive eval, state tracking and cost tracking last since they require a factory.
Expand Down Expand Up @@ -243,21 +244,21 @@ func newProgram(e *Env, a *Ast, opts []ProgramOption) (Program, error) {

// Enable exhaustive eval over a basic observer since it offers a superset of features.
if p.evalOpts&OptExhaustiveEval == OptExhaustiveEval {
decs = append(decs, interpreter.ExhaustiveEval(), interpreter.Observe(observers...))
decs = append(decs, interpreter.ExhaustiveEvalContext(), interpreter.ObserveContext(observers...))
} else if len(observers) > 0 {
decs = append(decs, interpreter.Observe(observers...))
decs = append(decs, interpreter.ObserveContext(observers...))
}

return p.clone().initInterpretable(a, decs)
return p.clone().initInterpretable(ctx, a, decs)
}
return newProgGen(factory)
}
return p.initInterpretable(a, decorators)
return p.initInterpretable(ctx, a, decorators)
}

func (p *prog) initInterpretable(a *Ast, decs []interpreter.InterpretableDecorator) (*prog, error) {
func (p *prog) initInterpretable(ctx context.Context, a *Ast, decs []interpreter.InterpretableDecoratorContext) (*prog, error) {
// When the AST has been exprAST it contains metadata that can be used to speed up program execution.
interpretable, err := p.interpreter.NewInterpretable(a.impl, decs...)
interpretable, err := p.interpreter.NewInterpretableContext(ctx, a.impl, decs...)
if err != nil {
return nil, err
}
Expand All @@ -267,6 +268,10 @@ func (p *prog) initInterpretable(a *Ast, decs []interpreter.InterpretableDecorat

// Eval implements the Program interface method.
func (p *prog) Eval(input any) (v ref.Val, det *EvalDetails, err error) {
return p.eval(context.Background(), input)
}

func (p *prog) eval(ctx context.Context, input any) (v ref.Val, det *EvalDetails, err error) {
// Configure error recovery for unexpected panics during evaluation. Note, the use of named
// return values makes it possible to modify the error response during the recovery
// function.
Expand Down Expand Up @@ -294,7 +299,7 @@ func (p *prog) Eval(input any) (v ref.Val, det *EvalDetails, err error) {
if p.defaultVars != nil {
vars = interpreter.NewHierarchicalActivation(p.defaultVars, vars)
}
v = p.interpretable.Eval(vars)
v = p.interpretable.EvalContext(ctx, vars)
// The output of an internal Eval may have a value (`v`) that is a types.Err. This step
// translates the CEL value to a Go error response. This interface does not quite match the
// RPC signature which allows for multiple errors to be returned, but should be sufficient.
Expand Down Expand Up @@ -324,7 +329,7 @@ func (p *prog) ContextEval(ctx context.Context, input any) (ref.Val, *EvalDetail
default:
return nil, nil, fmt.Errorf("invalid input, wanted Activation or map[string]any, got: (%T)%v", input, input)
}
return p.Eval(vars)
return p.eval(ctx, vars)
}

// progFactory is a helper alias for marking a program creation factory function.
Expand Down Expand Up @@ -352,6 +357,10 @@ func newProgGen(factory progFactory) (Program, error) {

// Eval implements the Program interface method.
func (gen *progGen) Eval(input any) (ref.Val, *EvalDetails, error) {
return gen.eval(context.Background(), input)
}

func (gen *progGen) eval(ctx context.Context, input any) (ref.Val, *EvalDetails, error) {
// The factory based Eval() differs from the standard evaluation model in that it generates a
// new EvalState instance for each call to ensure that unique evaluations yield unique stateful
// results.
Expand All @@ -371,7 +380,7 @@ func (gen *progGen) Eval(input any) (ref.Val, *EvalDetails, error) {
}

// Evaluate the input, returning the result and the 'state' within EvalDetails.
v, _, err := p.Eval(input)
v, _, err := p.ContextEval(ctx, input)
if err != nil {
return v, det, err
}
Expand Down
Loading