From c4dbb57f567092db081e0b5f0ab2ae00e45b00c0 Mon Sep 17 00:00:00 2001 From: Zachary Sunberg Date: Mon, 27 Jul 2020 16:00:25 -0700 Subject: [PATCH] Version 0.9 (#307) * fix #250 * travis only tests 1.1 and 1 * removed inferred_in_latest * removed all of the old deprecated generative stuff * removed ddn code * before removing programatic deprecation macros * tests pass * before switching back to master * initial steps * tests pass * started * got rid of errors, switched to distribution initialstate (#308) * DDNOut -> Val * brought back DDNOut * tests pass * working on docs * working on docs * cleaned up example * a bit more cleanup * finished documentation to fix #280 * added deprecation case for when initialstate_distribution is implemented * Changed emphasis of explit/generative explanation * Update README.md * fixed typo * Update docs/src/def_solver.md Co-authored-by: Jayesh K. Gupta * Update runtests.jl * moved available() and add_registry() to deprecated.jl * Update def_pomdp.md Co-authored-by: Jayesh K. 
Gupta --- .travis.yml | 4 +- Project.toml | 9 +- README.md | 8 +- docs/make.jl | 15 +- docs/src/api.md | 78 +---- docs/src/concepts.md | 2 +- docs/src/ddns.md | 58 ---- docs/src/def_pomdp.md | 32 +- docs/src/def_solver.md | 274 +--------------- docs/src/dynamics.md | 39 +++ docs/src/explicit.md | 23 -- docs/src/generative.md | 94 ------ docs/src/index.md | 13 +- docs/src/interfaces.md | 7 +- docs/src/offline_solver.md | 137 ++++++++ docs/src/online_solver.md | 101 ++++++ docs/src/requirements.md | 31 -- docs/src/simulation.md | 4 +- docs/src/specifying_requirements.md | 102 ------ docs/src/{basic_properties.md => static.md} | 18 +- src/POMDPs.jl | 63 +--- src/belief.jl | 1 - src/ddn_struct.jl | 235 -------------- src/deprecated.jl | 135 +++++--- src/errors.jl | 117 ------- src/gen_impl.jl | 140 ++++----- src/generative.jl | 162 ++-------- src/pomdp.jl | 59 ++-- src/requirements_interface.jl | 206 ------------ src/requirements_internals.jl | 331 -------------------- src/requirements_printing.jl | 99 ------ src/space.jl | 20 +- src/utils.jl | 66 ---- test/runtests.jl | 85 ++--- test/test_ddn_struct.jl | 51 --- test/test_deprecated_generative.jl | 36 --- test/test_generative.jl | 106 +++---- test/test_generative_backedges.jl | 20 +- test/test_requirements.jl | 97 ------ 39 files changed, 674 insertions(+), 2404 deletions(-) delete mode 100644 docs/src/ddns.md create mode 100644 docs/src/dynamics.md delete mode 100644 docs/src/explicit.md delete mode 100644 docs/src/generative.md create mode 100644 docs/src/offline_solver.md create mode 100644 docs/src/online_solver.md delete mode 100644 docs/src/requirements.md delete mode 100644 docs/src/specifying_requirements.md rename docs/src/{basic_properties.md => static.md} (57%) delete mode 100644 src/ddn_struct.jl delete mode 100644 src/errors.jl delete mode 100644 src/requirements_interface.jl delete mode 100644 src/requirements_internals.jl delete mode 100644 src/requirements_printing.jl delete mode 100644 src/utils.jl 
delete mode 100644 test/test_ddn_struct.jl delete mode 100644 test/test_deprecated_generative.jl delete mode 100644 test/test_requirements.jl diff --git a/.travis.yml b/.travis.yml index a102d882..371fab49 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,7 +1,7 @@ language: julia julia: - - 1.0 + - 1.1 - 1 os: - linux @@ -20,7 +20,7 @@ jobs: - os: windows include: - stage: "Documentation" - julia: 1.0 + julia: 1 os: linux script: - julia --project=docs/ -e 'using Pkg; Pkg.instantiate(); diff --git a/Project.toml b/Project.toml index e5958f5d..052d6ca9 100644 --- a/Project.toml +++ b/Project.toml @@ -1,22 +1,21 @@ name = "POMDPs" uuid = "a93abf59-7444-517b-a68a-c42f96afdd7d" -version = "0.8.4" +version = "0.9.0" [deps] Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" -LibGit2 = "76f85450-5226-5b5a-8eaa-529ad045b433" LightGraphs = "093fc24a-ae57-5d10-9952-331d41423f4d" -Logging = "56ddb016-857b-54e1-b83d-db4d58db5568" NamedTupleTools = "d9ec5142-1e00-5aa0-9d6a-321866360f50" +POMDPLinter = "f3bd98c0-eb40-45e2-9eb1-f2763262d755" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" [compat] -Distributions = "0.17,0.18,0.19,0.20,0.21,0.22, 0.23" +Distributions = "0.17, 0.18, 0.19, 0.20, 0.21, 0.22, 0.23" LightGraphs = "1" NamedTupleTools = "0.10, 0.11, 0.12, 0.13" -julia = "1" +julia = "1.1" [extras] Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" diff --git a/README.md b/README.md index 4a0fe12f..c952b3b6 100644 --- a/README.md +++ b/README.md @@ -11,13 +11,19 @@ This package provides a core interface for working with [Markov decision processes (MDPs)](https://en.wikipedia.org/wiki/Markov_decision_process) and [partially observable Markov decision processes (POMDPs)](https://en.wikipedia.org/wiki/Partially_observable_Markov_decision_process). 
For examples, please see [POMDPExamples](https://github.com/JuliaPOMDP/POMDPExamples.jl), [QuickPOMDPs](https://github.com/JuliaPOMDP/QuickPOMDPs.jl), and the [Gallery](https://github.com/JuliaPOMDP/POMDPGallery.jl). +--- + +**NOTE**: We are currently in the process of upgrading solvers to work with POMDPs v0.9. This process is expected to take several weeks. For compatibility with all solvers, please install POMDPs v0.8 with `pkg"add POMDPs@v0.8"`. + +--- + Our goal is to provide a common programming vocabulary for: 1. Expressing problems as MDPs and POMDPs. 2. Writing solver software. 3. Running simulations efficiently. -There are [nested interfaces for expressing and interacting with (PO)MDPs](http://juliapomdp.github.io/POMDPs.jl/stable/def_pomdp): When the *[explicit](http://juliapomdp.github.io/POMDPs.jl/stable/explicit)* interface is used, the transition and observation probabilities are explicitly defined using api [functions](http://juliapomdp.github.io/POMDPs.jl/stable/explicit/#functional-form-explicit-pomdp); when the *[generative](http://juliapomdp.github.io/POMDPs.jl/stable/generative)* interface is used, only a single step simulator (e.g. (s', o, r) = G(s,a)) needs to be defined. Problems may also be defined with probability [tables](http://juliapomdp.github.io/POMDPs.jl/stable/explicit/#tabular-form-explicit-pomdp), or with the simplified [QuickPOMDPs interfaces](https://github.com/JuliaPOMDP/QuickPOMDPs.jl). +There are [several ways to define and interact with (PO)MDPs](http://juliapomdp.github.io/POMDPs.jl/stable/def_pomdp): transition and observation distributions and rewards can be defined with explicit probability distributions or implicitly with a function that samples from the distribution, or all of the dynamics can be defined in a single step simulator function: (s', o, r) = G(s,a). 
Problems may also be defined with probability [tables](https://github.com/JuliaPOMDP/POMDPExamples.jl/blob/master/notebooks/Defining-a-tabular-POMDP.ipynb), and the [QuickPOMDPs interfaces](https://github.com/JuliaPOMDP/QuickPOMDPs.jl) make defining simple problems easier. **Python** can be used to define and solve MDPs and POMDPs via the QuickPOMDPs or tabular interfaces and [pyjulia](https://github.com/JuliaPy/pyjulia) (Example: [tiger.py](https://github.com/JuliaPOMDP/QuickPOMDPs.jl/blob/master/examples/tiger.py)). diff --git a/docs/make.jl b/docs/make.jl index d69f5566..fa1bd451 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -19,18 +19,19 @@ makedocs( "Defining (PO)MDP Models" => [ "def_pomdp.md", - "ddns.md", - "basic_properties.md", - "explicit.md", - "generative.md", + "static.md", "interfaces.md", - "requirements.md", + "dynamics.md", ], - "Writing Solvers and Updaters" => [ + "Writing Solvers" => [ "def_solver.md", - "specifying_requirements.md", + "offline_solver.md", + "online_solver.md" + ], + + "Writing Belief Updaters" => [ "def_updater.md" ], diff --git a/docs/src/api.md b/docs/src/api.md index a158fcb5..e021065b 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -1,24 +1,6 @@ # API Documentation -Documentation for the `POMDPs.jl` user interface. You can get help for any type or -function in the module by typing `?` in the Julia REPL followed by the name of -type or function. For example: - -```julia -julia> using POMDPs -julia> ? -help?> reward -search: reward - - reward{S,A,O}(pomdp::POMDP{S,A,O}, state::S, action::A, statep::S) - - Returns the immediate reward for the s-a-s triple - - reward{S,A,O}(pomdp::POMDP{S,A,O}, state::S, action::A) - - Returns the immediate reward for the s-a pair - -``` +Docstrings for POMDPs.jl interface members can be [accessed through Julia's built-in documentation system](https://docs.julialang.org/en/v1/manual/documentation/index.html#Accessing-Documentation-1) or in the list below. 
```@meta CurrentModule = POMDPs @@ -30,7 +12,6 @@ CurrentModule = POMDPs Pages = ["api.md"] ``` - ## Index ```@index @@ -50,33 +31,17 @@ Updater ## Model Functions -### [Explicit](@id explicit_api) - -These functions return *distributions*. +### Dynamics ```@docs transition observation -initialstate_distribution reward -``` - -### [Generative](@id generative_api) - -These functions should return *states*, *observations*, and/or *rewards*. - -!!! note - - `gen` in POMDPs.jl v0.8 corresponds to the `generate_` functions in previous versions - -```@docs -@gen gen -initialstate -initialobs +@gen ``` -### [Common](@id common_api) +### Static Properties ```@docs states @@ -84,6 +49,8 @@ actions observations isterminal discount +initialstate +initialobs stateindex actionindex obsindex @@ -92,29 +59,16 @@ convert_a convert_o ``` -## Distribution/Space Functions +### Distributions and Spaces ```@docs rand pdf mode mean -dimensions support ``` -## Dynamic decision networks - -```@docs -DDNStructure -DDNNode -DDNOut -DistributionDDNNode -FunctionDDNNode -ConstantDDNNode -GenericDDNNode -``` - ## Belief Functions ```@docs @@ -152,26 +106,8 @@ actiontype obstype ``` -### Requirements Specification -```@docs -check_requirements -show_requirements -get_requirements -requirements_info -@POMDP_require -@POMDP_requirements -@requirements_info -@get_requirements -@show_requirements -@warn_requirements -@req -@subreq -implemented -``` - ### Utility Tools ```@docs add_registry -available ``` diff --git a/docs/src/concepts.md b/docs/src/concepts.md index e6ec078e..87de57eb 100644 --- a/docs/src/concepts.md +++ b/docs/src/concepts.md @@ -46,7 +46,7 @@ represented by a concrete subtype of the [`POMDP`](@ref) abstract type, explicit definition is often not required), and `O` is defined by implementing [`observation`](@ref) if the [*explicit*](@ref defining_pomdps) interface is used or [`gen`](@ref) if the [*generative*](@ref defining_pomdps) interface is used. 
-POMDPs.jl can also be extended to accommodate POMDP-like problem classes with expanded [dynamic decision networks](@ref Dynamic-Decision-Networks), such as constrained or factored POMDPs, and it contains functions for defining optional problem behavior +POMDPs.jl contains additional functions for defining optional problem behavior such as a [discount factor](@ref Discount-Factor) or a set of [terminal states](@ref Terminal-States). More information can be found in the [Defining POMDPs](@ref defining_pomdps) section. diff --git a/docs/src/ddns.md b/docs/src/ddns.md deleted file mode 100644 index d689d853..00000000 --- a/docs/src/ddns.md +++ /dev/null @@ -1,58 +0,0 @@ -# Dynamic Decision Networks - -Part of the conceptual definition of a POMDP or MDP is a dynamic decision network (DDN) that defines which random variables are dependent on each other. -Usually, problem writers will not have to interact directly with the DDN, but it is a helpful concept for understanding, and it can be customized for special problem types. - -The standard POMDPs.jl DDN models are shown below: - -| Standard MDP DDN | Standard POMDP DDN | -|:---:|:---:| -|![MDP DDN](figures/mdp_ddn.svg) | ![POMDP DDN](figures/pomdp_ddn.svg) | - -!!! note - - In order to provide additional flexibility, these DDNs have `:s`→`:o`, `:sp`→`:r` and `:o`→`:r` edges that are typically absent from the DDNs traditionally used in the (PO)MDP literature. Traditional (PO)MDP algorithms are compatible with these DDNs because only ``R(s,a)``, the expectation of ``R(s, a, s', o)`` over all ``s'`` and ``o`` is needed to make optimal decisions. - -## DDN structure representation - -In POMDPs.jl, each DDN node corresponds to a [`Symbol`](https://docs.julialang.org/en/v1/base/base/#Core.Symbol). Often a `p` character (mnemonic: "prime") is appended to denote a new value for the next timestep, e.g. `:sp` represents ``s'``, the state at the next step. 
- -A [`DDNStructure`](@ref) object contains the names of all the nodes, the edges between the nodes, and an object for each node that defines its implementation. - -Currently, there are four types of nodes: -- [`DistributionDDNNode`](@ref) to define nodes with stochastic output. -- [`FunctionDDNNode`](@ref) to define a node that is a deterministic function of other nodes. -- [`ConstantDDNNode`](@ref) for a constant. -- [`GenericDDNNode`](@ref) for a node that has no implementation other than [`gen`](@ref) (see [Defining behavior for nodes](@ref) below). - -This set is not expected to handle all possible behavior, so new types are likely to be added in the future (and they should be requested when concrete needs are encountered). - -## Defining behavior for nodes - -For any node in the DDN, the function [`gen`](@ref)`(::DDNNode{:nodename}, m, parent_values..., rng)` will be called to sample a value (see the docstring for more information). This method can always be implemented to provide a generative definition for a node. - -Some nodes can alternatively have an explicit implementation. For example, a `DistributionDDNNode` contains a function that is called with the (PO)MDP models and values sampled from the parent nodes to return a distribution. The state transition node, `:sp`, is a particular case of this. If [`gen`](@ref)`(::GenVar{:sp}, m, s, a, rng)` is not defined by the problem writer, `rand(rng, transition(m, s, a))` will be called to generate values for `:sp`. - -### Mixing generative and explicit node definitions for a POMDP - -POMDP models will often contain a mixture of Generative and explicit definitions, and this is an encouraged paradigm. 
For example - -```julia -using Distributions -struct MyPOMDP <: POMDP{Float64, Float64, Float64} end -POMDPs.gen(::GenVar{:sp}, m::MyPOMDP, s, a, rng) = s+a -POMDPs.observation(::GenVar{:o}, m, s, a, sp, rng) = Normal(sp) -``` -would be a suitable distribution for a POMDP that will be solved with particle filtering methods where an explicit observation definition is needed, but a generative state transition definition is sufficient. - -!!! note - - It is usually best to *avoid* providing both a generative and explicit definition of *the same node* because it is easy to introduce inconsistency. - -## Customizing the DDN - -The DDN structure for a particular (PO)MDP problem `type` is defined with the [`DDNStructure`](@ref) trait, which should return a [`DDNStructure`](@ref) object (or something else that implements the appropriate methods). See the docstring for an example. - -If a specialized DDN structure is to be compatible with standard POMDP solvers, it should have the standard `:sp`, `:r`, and `:o` nodes. - -Currently (as of September, 2019), no solver has special behavior based on the DDN structure, but it is expected that packages will define new DDN structures for specialized cases like constrained POMDPs, mixed observability MDPs, or factored POMDPs. If you are considering creating a solver that relies on a specific DDN structure, please contact the developers so we can coordinate. diff --git a/docs/src/def_pomdp.md b/docs/src/def_pomdp.md index 4a724843..35e5e80a 100644 --- a/docs/src/def_pomdp.md +++ b/docs/src/def_pomdp.md @@ -7,7 +7,7 @@ Since POMDPs.jl was designed with performance and flexibility as first prioritie - [QuickPOMDPs.jl](https://github.com/JuliaPOMDP/QuickPOMDPs.jl) provides structures for concisely defining simple POMDPs without object-oriented programming. - [POMDPExamples.jl](https://github.com/JuliaPOMDP/POMDPExamples.jl) provides tutorials for defining problems. 
- [The Tabular(PO)MDP model](https://github.com/JuliaPOMDP/POMDPExamples.jl/blob/master/notebooks/Defining-a-tabular-POMDP.ipynb) from [POMDPModels.jl](https://github.com/JuliaPOMDP/POMDPModels.jl) allows users to define POMDPs with matrices for the transitions, observations and rewards. -- The [`gen` function](@ref generative_doc) is the easiest way to wrap a pre-existing simulator from another project or written in another programming language so that it can be used with POMDPs.jl solvers and simulators. See also [RLInterface.jl](https://github.com/JuliaPOMDP/RLInterface.jl) for an even higher level interface for simulators where the state is not accessible. +- The [`gen`](@ref) function is the easiest way to wrap a pre-existing simulator from another project or written in another programming language so that it can be used with POMDPs.jl solvers and simulators. ## Overview @@ -16,36 +16,27 @@ Custom POMDP problems are defined by implementing the functions specified by the !!! note - The main generative and explicit interfaces use an object-oriented programming paradigm and require familiarity with Julia. For users new to Julia, [QuickPOMDPs](https://github.com/JuliaPOMDP/QuickPOMDPs.jl) usually requires less knowledge of the language and no object-oriented programming. + The main POMDPs.jl interface uses an object-oriented programming paradigm and requires familiarity with Julia. For users new to Julia, [QuickPOMDPs](https://github.com/JuliaPOMDP/QuickPOMDPs.jl) usually requires less knowledge of the language and no object-oriented programming. 
In scientific papers explicit definitions are often written as ``T(s' | s, a)`` for transitions and ``O(o | s, a, s')`` for observations, while a generative definition might be expressed as ``s', o, r = G(s, a)`` (or ``s', r = G(s,a)`` for an MDP). - -Accordingly, the POMDPs.jl model API is grouped into three sections: -1. The [*explicit*](@ref explicit_api) interface containing *functions that explicitly define distributions for DDN nodes.* -2. The [*generative*](@ref generative_api) interface containing *functions that return sampled states and observations for DDN nodes.* -3. [*Common*](@ref common_api) functions used in both. +In this guide, the interface is divided into two sections: functions that define static properties of the problem, and functions that describe the dynamics - how the states, observations and rewards change over time. There are two ways of specifying the dynamic behavior of a POMDP. The problem definition may include a mixture of *explicit* definitions of probability distributions, or *generative* definitions that simulate states and observations without explicitly defining the distributions. In scientific papers explicit definitions are often written as ``T(s' | s, a)`` for transitions and ``O(o | s, a, s')`` for observations, while a generative definition might be expressed as ``s', o, r = G(s, a)`` (or ``s', r = G(s,a)`` for an MDP). ## What do I need to implement? -Because of the wide variety or problems and solvers that POMDPs.jl interfaces with, the question of which functions from the interface need to be implemented does not have a short answer for all cases. In general, a problem will be defined by implementing a combination of functions from the generative, explicit, and common parts of the interface. +Because of the wide variety or problems and solvers that POMDPs.jl interfaces with, the question of which functions from the interface need to be implemented does not have a short answer for all cases. 
In general, a problem will be defined by implementing a combination of functions. Specifically, a problem writer will need to define - Explicit or generative definitions for - - the state transition model ([DDN](@ref Dynamic-Decision-Networks) node `:sp`), - - the reward function ([DDN](@ref Dynamic-Decision-Networks) node `:r`), and - - the observation model ([DDN](@ref Dynamic-Decision-Networks) node `:o`, for POMDPs only). + - the state transition model, + - the reward function, and + - the observation model. - Functions to define some other properties of the problem such as the state, action, and observation spaces, which states are terminal, etc. -!!! note - - Since an explicit definition for a DDN node contains all of the information required for a generative definition, POMDPs.jl will automatically synthesize the generative functions for that node at runtime if an explicit model is available. Thus, there is never a need for both generative and explicit definitions of a node, and it is usually best to avoid redundant definitions because it is easy for them to become inconsistent. - The precise answer for which functions need to be implemented depends on two factors: problem complexity and which solver will be used. In particular, 2 questions should be asked: 1. Is it difficult or impossible to specify a probability distribution explicitly? 2. What solvers will be used to solve this, and what are their requirements? -If the answer to (1) is yes, then a generative definition should be used. Question (2) should be answered by reading about the solvers and trying to run them, or through the [requirements](@ref requirements) interface if it has been defined for the solver. +If the answer to (1) is yes, then a generative definition should be used. Question (2) should be answered by reading about the solvers and trying to run them. 
Some solvers have specified their requirements using the [POMDPLinter package](https://github.com/JuliaPOMDP/POMDPLinter.jl), however, these requirements are written separately from the solver code, and often the best way is to write a simple prototype problem and running the solver until all `MethodError`s have been fixed. !!! note @@ -55,9 +46,6 @@ If the answer to (1) is yes, then a generative definition should be used. Questi The following pages provide more details on specific parts of the interface: -- [Dynamic Decision Networks](@ref) -- [Explicit DDN node definitions](@ref explicit_doc) -- [Generative DDN node definitions](@ref generative_doc) -- [Basic Properties (common part of the api)](@ref basic) +- [Static Properties](@ref static) - [Spaces and Distributions](@ref) -- [Requirements](@ref requirements) +- [Dynamics](@ref dynamics) diff --git a/docs/src/def_solver.md b/docs/src/def_solver.md index 0829a620..b7cef991 100644 --- a/docs/src/def_solver.md +++ b/docs/src/def_solver.md @@ -1,268 +1,26 @@ -# Defining a Solver +# Solvers -In this section, we will walk through an implementation of the -[QMDP](http://www-anw.cs.umass.edu/~barto/courses/cs687/Cassandra-etal-POMDP.pdf) algorithm. QMDP is the fully -observable approximation of a POMDP policy, and relies on the Q-values to determine actions. +Defining a solver involves creating or using four pieces of code: -## Background +1. A subtype of [`Solver`](@ref) that holds the parameters and configuration options for the solver. +2. A subtype of [`Policy`](@ref) that holds all of the data needed to choose actions online. +3. A method of [`solve`](@ref) that takes the solver and a (PO)MDP as arguments, performs all of the offline computations for solving the problem, and returns the policy. +4. A method of [`action`](@ref) that takes in the policy and a state or belief and returns an action. 
-Let's say we are working with a POMDP defined by the tuple $(\mathcal{S}, \mathcal{A}, \mathcal{Z}, T, R, O, \gamma)$, -where $\mathcal{S}$, $\mathcal{A}$, $\mathcal{Z}$ are the discrete state, action, and observation spaces -respectively. The QMDP algorithm assumes it is given a discrete POMDP. In our model $T : \mathcal{S} \times -\mathcal{A} \times \mathcal{S} \rightarrow [0, 1]$ is the transition function, $R: \mathcal{S} \times \mathcal{A} -\rightarrow \mathbb{R}$ is the reward function, and $O: \mathcal{Z} \times \mathcal{A} \times \mathcal{S} \rightarrow -[0,1]$ is the observation function. In a POMDP, our goal is to compute a policy $\pi$ that maps beliefs to actions $\pi: b \rightarrow a$. For -QMDP, a belief can be represented by a discrete probability distribution over the state space (although there may be -other ways to define a belief in general and POMDPs.jl allows this flexibility). +In many cases, items 2 and 4 can be satisfied with an off-the-shelf solver from [POMDPPolicies.jl](https://github.com/JuliaPOMDP/POMDPPolicies.jl). [POMDPModelTools.jl](https://github.com/JuliaPOMDP/POMDPModelTools.jl) also contains many tools that are useful for defining solvers in a robust, concise, and readable manner. -It can be shown (e.g. in [1], section 6.3.2) that the optimal value function for a POMDP can be written in terms of alpha vectors. In the QMDP approximation, there is a single alpha vector that corresponds to each action ($\alpha_a$), and the policy is computed according to +## Online and Offline Solvers -$\pi(b) = \underset{a}{\text{argmax}} \, \alpha_{a}^{T}b$ +Generally, solvers can be grouped into two categories: *Offline* solvers that do most of their computational work *before* interacting with the environment, and *online* solvers that do their work online. 
+Although offline and online solvers both use the exact same [`Solver`](@ref), [`solve`](@ref), [`Policy`](@ref), [`action`](@ref) structure, the work of defining online and offline solvers is focused on different portions. -Thus, the alpha vectors can be used to compactly represent a QMDP policy. +For an offline solver, most of the implementation effort will be spent on the [`solve`](@ref) function, and an off-the-shelf policy from [POMDPPolicies.jl](https://github.com/JuliaPOMDP/POMDPPolicies.jl) will typically be used. -## QMDP Algorithm +For an online solver, the [`solve`](@ref) function typically does little or no work, but merely creates a policy object that will carry out computation online. It is typical in POMDPs.jl to use the term "Planner" to name a [`Policy`](@ref) object for an online solver that carries out a large amount of computation at interaction time. In this case most of the effort will be focused on implementing the [`action`](@ref) method for the "Planner" `Policy` type. -QMDP uses the columns of the Q-matrix obtained by solving the MDP defined by $(\mathcal{S}, \mathcal{A}, T, R, \gamma)$ (that is, the fully observable MDP that forms the basis for the POMDP we are trying to solve). -If you are familiar with the value iteration algorithm for MDPs, the procedure for finding these alpha vectors is identical. Let's first -initialize the alpha vectors $\alpha_{a}^{0} = 0$ for all $s$, and then iterate +## Examples 
However, QMDP works very well in problems where a particular choice of action has -little impact on the reduction in state uncertainty. - - -## Requirements for a Solver - -Before getting into the implementation details, let's first go through what a POMDP solver must be able to do and support. We need three custom types that inherit from abstract types in POMDPs.jl. These type are Solver, Policy, and Updater. It is usually useful to have a custom type that represents the belief used by your policy as well. - -The requirements are as follows: - -```julia -# types -QMDPSolver -QMDPPolicy -DiscreteUpdater # already implemented for us in BeliefUpdaters -DiscreteBelief # already implemented for us in BeliefUpdaters -# methods -updater(p::QMDPPolicy) # returns a belief updater suitable for use with QMDPPolicy -initialize_belief(bu::DiscreteUpdater, initial_state_dist) # returns a Discrete belief -solve(solver::QMDPSolver, pomdp::POMDP) # solves the POMDP and returns a policy -update(bu::DiscreteUpdater, belief_old::DiscreteBelief, action, obs) # returns an updated belied (already implemented) -action(policy::QMDPPolicy, b::DiscreteBelief) # returns a QMDP action -``` - -You can find the implementations of these types and methods below. - -## Defining the Solver and Policy Types - -Let's first define the Solver type. The QMDP solver type should contain all the information needed to compute a policy (other than the problem itself). This information can be thought of as the hyperparameters of the solver. In QMDP, we only need two hyper-parameters. We may want to set the maximum number of iterations that the algorithm runs for, and a tolerance value (also known as the Bellman residual). Both of these quantities define terminating criteria for the algorithm. The algorithm stops either when the maximum number of iterations has been reached or when the infinity norm of the difference in utility values between two iterations goes below the tolerance value. 
The type definition has the form: - -```julia -using POMDPs # first load the POMDPs module -type QMDPSolver <: Solver - max_iterations::Int64 # max number of iterations QMDP runs for - tolerance::Float64 # Bellman residual: terminates when max||Ut-Ut-1|| < tolerance -end -# default constructor -QMDPSolver(;max_iterations::Int64=100, tolerance::Float64=1e-3) = QMDPSolver(max_iterations, tolerance) -``` - -Note that the QMDPSolver inherits from the abstract Solver type that's part of POMDPs.jl. - -Now, let's define a policy type. In general, the policy should contain all the information needed to map a belief to an action. As mentioned earlier, we need alpha vectors to be part of our policy. We can represent the alpha vectors using a matrix of size $|\mathcal{S}| \times |\mathcal{A}|$. Recall that in POMDPs.jl, the actions can be represented in a number of ways (Int64, concrete types, etc), so we need a way to map these actions to integers so we can index into our alpha matrix. The type looks like: - -```julia -using POMDPModelTools # for ordered_actions - -type QMDPPolicy <: Policy - alphas::Matrix{Float64} # matrix of alpha vectors |S|x|A| - action_map::Vector{Any} # maps indices to actions - pomdp::POMDP # models for convenience -end -# default constructor -function QMDPPolicy(pomdp::POMDP) - ns = n_states(pomdp) - na = n_actions(pomdp) - alphas = zeros(ns, na) - am = Any[] - space = ordered_actions(pomdp) - for a in iterator(space) - push!(am, a) - end - return QMDPPolicy(alphas, am, pomdp) -end -``` - -Now that we have our solver and policy types, we can write the solve function to compute the policy. - -## Writing the Solve Function - -The solve function takes in a solver, a POMDP, and an optional policy argument. Let's compute those alpha vectors! 
- -```julia -function POMDPs.solve(solver::QMDPSolver, pomdp::POMDP) - - policy = QMDPPolicy(pomdp) - - # get solver parameters - max_iterations = solver.max_iterations - tolerance = solver.tolerance - discount_factor = discount(pomdp) - - # intialize the alpha-vectors - alphas = policy.alphas - - # initalize space - sspace = ordered_states(pomdp) # returns a discrete state space object of the pomdp - aspace = ordered_actions(pomdp) # returns a discrete action space object - - # main loop - for i = 1:max_iterations - residual = 0.0 - # state loop - for (istate, s) in enumerate(sspace) - old_alpha = maximum(alphas[istate,:]) # for residual - max_alpha = -Inf - # action loop - # alpha(s) = R(s,a) + discount_factor * sum(T(s'|s,a)max(alpha(s')) - for (iaction, a) in enumerate(aspace) - # the transition function modifies the dist argument to a distribution availible from that state-action pair - dist = transition(pomdp, s, a) # fills distribution over neighbors - q_new = 0.0 - for sp in iterator(dist) - # pdf returns the probability mass of sp in dist - p = pdf(dist, sp) - p == 0.0 ? continue : nothing # skip if zero prob - # returns the reward from s-a-sp triple - r = reward(pomdp, s, a, sp) - - # stateindex returns an integer - sidx = stateindex(pomdp, sp) - q_new += p * (r + discount_factor * maximum(alphas[sidx,:])) - end - new_alpha = q_new - alphas[istate, iaction] = new_alpha - new_alpha > max_alpha ? (max_alpha = new_alpha) : nothing - end # actiom - # update the value array - diff = abs(max_alpha - old_alpha) - diff > residual ? (residual = diff) : nothing - end # state - # check if below Bellman residual - residual < tolerance ? break : nothing - end # main - # return the policy - policy -end -``` - -At each iteration, the algorithm iterates over the state space and computes an alpha vector for each action. There is a check at the end to see if the Bellman residual has been satisfied. 
The solve function assumes the following POMDPs.jl functions are implemented by the user of QMDP: - -```julia -states(pomdp) # (in ordered_states) returns a state space object of the pomdp -actions(pomdp) # (in ordered_actions) returns the action space object of the pomdp -transition(pomdp, s, a) # returns the transition distribution for the s, a pair -reward(pomdp, s, a, sp) # returns real valued reward from s, a, sp triple -pdf(dist, sp) # returns the probability of sp being in dist -stateindex(pomdp, sp) # returns the integer index of sp (for discrete state spaces) -``` - -Now that we have a solve function, we define the [`action`](@ref) function to let users evaluate the policy: - -```julia -using LinearAlgebra - -function POMDPs.action(policy::QMDPPolicy, b::DiscreteBelief) - alphas = policy.alphas - ihi = 0 - vhi = -Inf - (ns, na) = size(alphas) - @assert length(b.b) == ns "Length of belief and alpha-vector size mismatch" - # see which action gives the highest util value - for ai = 1:na - util = dot(alphas[:,ai], b.b) - if util > vhi - vhi = util - ihi = ai - end - end - # map the index to action - return policy.action_map[ihi] -end -``` - -## Belief Updates - -Let's now talk about how we deal with beliefs. Since QMDP is a discrete POMDP solver, we can assume that the user will represent their belief as a probablity distribution over states. That means that we can also use a discrete belief to work with our policy! -Lucky for us, the JuliaPOMDP organization contains tools that we can use out of the box for working with discrete beliefs. The POMDPToolbox package contains a `DiscreteBelief` type that does exactly what we need. 
The [`updater`](@ref) function allows us to declare that the `DiscreteUpdater` is the default updater to be used with a QMDP policy: - -```julia -using BeliefUpdaters # remeber to load the package that implements discrete beliefs for us -POMDPs.updater(p::QMDPPolicy) = DiscreteUpdater(p.pomdp) +Solver implementation is most clearly explained through examples. The following sections contain examples of both online and offline solver definitions: +```@contents +Pages = ["offline_solver.md", "online_solver.md"] ``` -These are all the functions that you'll need to have a working POMDPs.jl solver. Let's now use existing benchmark models to evaluate it. - -## Evaluating the Solver - -We'll use the POMDPModels package from JuliaPOMDP to initialize a Tiger POMDP problem and solve it with QMDP. - -```julia -using POMDPModels -using POMDPSimulators - -# initialize model and solver -pomdp = TigerPOMDP() -solver = QMDPSolver() - -# compute the QMDP policy -policy = solve(solver, pomdp) - -# initalize updater and belief -b_up = updater(policy) -init_dist = initialstate_distribution(pomdp) - -# create a simulator object for recording histories -sim_hist = HistoryRecorder(max_steps=100) - -# run a simulation -r = simulate(sim_hist, pomdp, policy, b_up, init_dist) -``` - -That's all you need to define a solver and evaluate its performance! - -## Defining Requirements - -If you share your solver, in order to make it easy to use, specifying requirements as described [here](@ref specifying_requirements) is highly recommended. - -\[1\] *Decision Making Under Uncertainty: Theory and Application* by -Mykel J. 
Kochenderfer, MIT Press, 2015 - - - - - - - - - - - - - - - - - - - - - - - - diff --git a/docs/src/dynamics.md b/docs/src/dynamics.md new file mode 100644 index 00000000..398467f2 --- /dev/null +++ b/docs/src/dynamics.md @@ -0,0 +1,39 @@ +# [Defining (PO)MDP Dynamics](@id dynamics) + +The dynamics of a (PO)MDP define how states, observations, and rewards are generated at each time step. One way to visualize the structure of (PO)MDP is with a *dynamic decision network* (DDN) (see for example [*Decision Making under Uncertainty* by Kochenderfer et al.](https://ieeexplore.ieee.org/book/7288640) or [this webpage](https://artint.info/html/ArtInt_229.html) for more discussion of dynamic decision networks). + +The POMDPs.jl DDN models are shown below: + +| Standard MDP DDN | Standard POMDP DDN | +|:---:|:---:| +|![MDP DDN](figures/mdp_ddn.svg) | ![POMDP DDN](figures/pomdp_ddn.svg) | + +!!! note + + In order to provide additional flexibility, these DDNs have `:s`→`:o`, `:sp`→`:r` and `:o`→`:r` edges that are typically absent from the DDNs traditionally used in the (PO)MDP literature. Traditional (PO)MDP algorithms are compatible with these DDNs because only ``R(s,a)``, the expectation of ``R(s, a, s', o)`` over all ``s'`` and ``o`` is needed to make optimal decisions. + +The task of defining the dynamics of a (PO)MDP consists of defining a model for each of the nodes in the DDN. Models for each node can either be implemented separately through the [`transition`](@ref), [`observation`](@ref), and [`reward`](@ref) functions, or together with the [`gen`](@ref) function. + +## Separate explicit or generative definition + +- [`transition`](@ref)`(pomdp, s, a)` defines the state transition probability distribution for state `s` and action `a`. This defines an explicit model for the `:sp` DDN node. 
+- [`observation`](@ref)`(pomdp, [s,] a, sp)` defines the observation distribution given that action `a` was taken and the state is now `sp` (The observation can optionally depend on `s` - see docstring). This defines an explicit model for the `:o` DDN node. +- [`reward`](@ref)`(pomdp, s, a[, sp[, o]])` defines the reward, which is a deterministic function of the state and action (and optionally `sp` and `o` - see docstring). This defines an explicit model for the `:r` DDN node. + +[`transition`](@ref) and [`observation`](@ref) should return distribution objects that implement part or all of the [distribution interface](@ref Distributions). Some predefined distributions can be found in [Distributions.jl](https://github.com/JuliaStats/Distributions.jl) or [POMDPModelTools.jl](https://github.com/JuliaPOMDP/POMDPModelTools.jl), or custom types that represent distributions appropriate for the problem may be created. + +!!! note + + There is no requirement that a problem defined using the explicit interface be discrete; it is straightforward to define continuous POMDPs with the explicit interface, provided that the distributions have some finite parameterization. + +## Combined generative definition + +If the state, observation, and reward are generated simultaneously, a new method of the [`gen`](@ref) function should be implemented to return the state, observation and reward in a single `NamedTuple`. 
+ +### Examples + +An example of defining a problem using separate functions can be found at: +[https://github.com/JuliaPOMDP/POMDPExamples.jl/blob/master/notebooks/Defining-a-POMDP-with-the-Explicit-Interface.ipynb](https://github.com/JuliaPOMDP/POMDPExamples.jl/blob/master/notebooks/Defining-a-POMDP-with-the-Explicit-Interface.ipynb) + +An example of defining a problem with a combined `gen` function can be found at: +[https://github.com/JuliaPOMDP/POMDPExamples.jl/blob/master/notebooks/Defining-a-POMDP-with-the-Generative-Interface.ipynb](https://github.com/JuliaPOMDP/POMDPExamples.jl/blob/master/notebooks/Defining-a-POMDP-with-the-Generative-Interface.ipynb) diff --git a/docs/src/explicit.md b/docs/src/explicit.md deleted file mode 100644 index 1f81254b..00000000 --- a/docs/src/explicit.md +++ /dev/null @@ -1,23 +0,0 @@ -# [Explicit (PO)MDP Interface](@id explicit_doc) - -When using the explicit interface, the transition and observation probabilities must be explicitly defined. - -!!! note - - There is no requirement that a problem defined using the explicit interface be discrete; it is straightforward to define continuous POMDPs with the explicit interface, provided that the distributions have some finite parameterization. - -## Explicit (PO)MDP interface - -The explicit interface consists of the following functions: - -- [`initialstate_distribution`](@ref)`(pomdp)` specifies the initial distribution of states for a problem (this is also translated to the initial belief for pomdps). -- [`transition`](@ref)`(pomdp, s, a)` defines the state transition probability distribution for state `s` and action `a`. This defines an explicit model for the [`:sp` DDN node](@ref Dynamic-decision-networks). -- [`observation`](@ref)`(pomdp, [s,] a, sp)` defines the observation distribution given that action `a` was taken and the state is now `sp` (The observation can optionally depend on `s` - see docstring). 
This defines an explicit model for the [`:o` DDN node](@ref Dynamic-decision-networks). -- [`reward`](@ref)`(pomdp, s, a[, sp[, o]])` defines the reward, which is a deterministic function of the state and action (and optionally `sp` and `o` - see docstring). This defines an explicit model for the [`:r` DDN node](@ref Dynamic-decision-networks). - -[`transition`](@ref), [`observation`](@ref), and [`initialstate_distribution`](@ref) should return distribution objects that implement part or all of the [distribution interface](@ref Distributions). Some predefined distributions can be found in [Distributions.jl](https://github.com/JuliaStats/Distributions.jl) or [POMDPModelTools.jl](https://github.com/JuliaPOMDP/POMDPModelTools.jl), or custom types that represent distributions appropriate for the problem may be created. - -### Example - -An example of defining a problem using the explicit interface can be found at: -[https://github.com/JuliaPOMDP/POMDPExamples.jl/blob/master/notebooks/Defining-a-POMDP-with-the-Explicit-Interface.ipynb](https://github.com/JuliaPOMDP/POMDPExamples.jl/blob/master/notebooks/Defining-a-POMDP-with-the-Explicit-Interface.ipynb) diff --git a/docs/src/generative.md b/docs/src/generative.md deleted file mode 100644 index 8bd8c123..00000000 --- a/docs/src/generative.md +++ /dev/null @@ -1,94 +0,0 @@ -# [Generative (PO)MDP Interface](@id generative_doc) - -## Quick Start - -A generative model for most (PO)MDPs can be completely defined with one function: -```julia -function POMDPs.gen(m::YourPOMDPType, s, a, rng) - # do dynamics/transition calculations here - return (sp= #=new state=#, r= #=reward=#, o= #=observation=#) -end -``` -(`o` is not needed for MDPs.) - -## Interface Description - -The *generative* interface consists of two components: -- The [`gen`](@ref) function or the [`@gen`](@ref) macro returns samples (e.g. states, observations and rewards) from a generative POMDP model. -- [`initialstate`](@ref) returns a sampled initial state. 
-The generative interface is typically used when it is easier to return sampled states and observations rather than explicit distributions as in the [Explicit interface](@ref explicit_doc). -This type of model is often referred to as a "black-box" model. - -In some special cases (e.g. reinforcement learning with [RLInterface.jl](https://github.com/JuliaPOMDP/RLInterface.jl)), an initial observation is needed before any actions are taken. In this case, the [`initialobs`](@ref) function will be used. - -## The [`gen`](@ref) function and [`@gen`](@ref) macro - -In most cases solvers and simulators should use the [`@gen`](@ref) macro, while problem-writers should implement new methods of the [`gen`](@ref) function. For example -```julia -sp, o, r = @gen(:sp,:o,:r)(m, s, a, rng) -``` -calls the generative model for POMDP `m` at state `s` and action `a`, and stores the next state, observation, and reward in variables `sp`, `o`, and `r`. `rng` is a [random number generator](@ref Random-number-generators). - -The [`gen`](@ref) function has three versions differentiated by the type of the first argument. - -- `gen(m::Union{POMDP, MDP}, s, a, rng)` provides a way to implement a generative model for an entire (PO)MDP in a single function. It should return values for a subset of [the DDN Nodes](@ref Dynamic-Decision-Networks) as a `NamedTuple`. - - This is typically the quickest and easiest way to implement a new POMDP model or wrap an existing simulator. - - Example (defined by a problem writer): `gen(m::MyPOMDP, s, a, rng) = (sp=s+a, r=s^2, o=s+a+randn(rng))` - - This version should **never** be called by a solver or simulator, since there is no guarantee of which values will be present in the returned object. - - Values for DDN nodes not present in the returned `NamedTuple` will be generated in the normal way with `gen(::DDNNode, ...)` or an explicit representation. 
- -- `gen(::`[`DDNNode`](@ref)`{nodename}, m, parent_values..., rng)` defines the generative model for a **single [DDN node](@ref Dynamic-Decision-Networks)**. Together, a group of these functions can define a problem. - - Example (defined by a problem writer): `gen(::DDNNode{:o}, m::MyPOMDP, s, a, sp, rng) = sp + randn(rng)` - - Solver writers should only directly call this version in very rare cases when it needs to access to values for a particular node of the DDN generated by specific values of its parent nodes. - -- `gen(::`[`DDNOut`](@ref)`{nodenames}, m, s, a, rng)` returns a value or tuple of values for a subset of nodes in the [DDN](@ref Dynamic-Decision-Networks). The arguments are values for the **input nodes** (currently `:s` and `:a`), treating the entire DDN as a single black box. - - Example (called in a solver): `sp, o, r = gen(DDNOut(:sp,:o,:r), m, s, a, rng)` - - This function is automatically synthesized by POMDPs.jl by combining `gen(m, s, a, rng)` and `gen(::DDNNode, ...)` or [explicit model definitions](@ref explicit_doc) for all [DDN nodes](@ref Dynamic-Decision-Networks). - - This version should only be implemented directly by problem writers in very rare cases when they need precise control for efficiency. - -In all versions, `m` is a (PO)MDP model, and `rng` is a [random number generator](@ref Random-number-generators). - -## Examples - -An example of defining a problem with the generative interface can be found [in the POMDPExamples package](https://github.com/JuliaPOMDP/POMDPExamples.jl/blob/master/notebooks/Defining-a-POMDP-with-the-Generative-Interface.ipynb). - -## Random number generators - -The `rng` argument to functions in the generative interface is a random number generator such as `Random.GLOBAL_RNG` or another `MersenneTwister`. It should be used to generate all random numbers within the function (e.g. use `rand(rng)` instead of `rand()`). This will ensure that all simulations are exactly repeatable. 
See the [Julia documentation on random numbers](https://docs.julialang.org/en/v1/stdlib/Random/#Random-Numbers-1) for more information about these objects. - -## Performance considerations - -In general, calling `gen(::DDNOut, ...)` when `gen(::POMDP, ...)` is implemented does not introduce much overhead. In fact, in some cases, the compiler will even optimize out calculations of extra genvars. For example: -```julia -struct M <: MDP{Int, Int} end - -POMDPs.gen(::M, s, a, rng) = (sp=s+a, r=s^2) - -@code_warntype gen(DDNOut(:sp), M(), 1, 1, Random.GLOBAL_RNG) -``` -will yield -``` -Body::Int64 -1 ─ %1 = (Base.add_int)(s, a)::Int64 -│ nothing -└── return %1 -``` -indicating that the compiler will only perform the addition to find the next state and skip the `s^2` calculation for the reward. - -Unfortunately, if random numbers are used in `gen`, the compiler will not be able to optimize out the change in the rng's state, so it may be beneficial to directly implement versions of `gen(::DDNNode, ...)`. -For example -```julia -POMDPs.gen(::DDNNode{:sp}, ::M, s, a, rng) = s+a -POMDPs.reward(::M, s, a) = abs(s) -PODMPs.gen(::DDNNode{:o}, ::M, s, a, sp, rng) = sp+randn(rng) -``` -might be more efficient than -```julia -function POMDPs.gen(::M, s, a, rng) - sp = s + a - return (sp=sp, r=abs(s), o=sp+randn(rng)) -end -``` -in the context of particle filtering. - -As always, though, one should resist the urge towards premature optimization; careful profiling to see what is actually slow is much more effective than speculation. diff --git a/docs/src/index.md b/docs/src/index.md index 31d3fafe..39940ab1 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -16,26 +16,31 @@ observable Markov decision processes (POMDPs) in the Julia programming language. [JuliaPOMDP](https://github.com/JuliaPOMDP) community maintains these packages. The list of solver and support packages is maintained at the [POMDPs.jl Readme](https://github.com/JuliaPOMDP/POMDPs.jl#supported-packages). 
-## Manual Outline +## Documentation Outline + +Documentation comes in four forms: +1. How-to examples are available in the [POMDPExamples package](https://github.com/JuliaPOMDP/POMDPExamples.jl) and in pages in this document with "Example" in the title. +2. An explanatory guide is available in the sections outlined below. +3. Reference docstrings for the entire interface are available in the [API Documentation](@ref) section. When updating these documents, make sure this is synced with [docs/make.jl](https://github.com/JuliaPOMDP/POMDPs.jl/blob/master/docs/make.jl)!! ### Basics ```@contents -Pages = ["index.md", "install.md", "get_started.md", "concepts.md"] +Pages = ["install.md", "get_started.md", "concepts.md"] ``` ### Defining POMDP Models ```@contents -Pages = [ "def_pomdp.md", "explicit.md", "generative.md", "requirements.md", "interfaces.md" ] +Pages = [ "def_pomdp.md", "static.md", "interfaces.md", "dynamics.md"] ``` ### Writing Solvers and Updaters ```@contents -Pages = [ "def_solver.md", "specifying_requirements.md", "def_updater.md" ] +Pages = [ "def_solver.md", "offline_solver.md", "online_solver.md", "def_updater.md" ] ``` ### Analyzing Results diff --git a/docs/src/interfaces.md b/docs/src/interfaces.md index 86db428c..d93d402f 100644 --- a/docs/src/interfaces.md +++ b/docs/src/interfaces.md @@ -9,8 +9,7 @@ A space object should contain the information needed to define the set of all po The following functions may be called on a space object (Click on a function to read its documentation): - [`rand`](@ref) -- [`length`](@ref) -- [`dimensions`](@ref) +- [`iterate`](https://docs.julialang.org/en/v1/manual/interfaces/#man-interface-iteration-1) and [the rest of the iteration interface](https://docs.julialang.org/en/v1/manual/interfaces/#man-interface-iteration-1) for discrete spaces. ## Distributions @@ -18,7 +17,7 @@ A distribution object represents a probability distribution. 
The following functions may be called on a distribution object (Click on a function to read its documentation): -- [`rand`](@ref)`(rng, d)` [^1] +- [`rand`](@ref)`([rng,] d)` [^1] - [`support`](@ref) - [`pdf`](@ref) - [`mode`](@ref) @@ -26,4 +25,4 @@ The following functions may be called on a distribution object (Click on a funct You can find some useful pre-made distribution objects in [Distributions.jl](https://github.com/JuliaStats/Distributions.jl) or [POMDPModelTools.jl](https://juliapomdp.github.io/POMDPModelTools.jl/latest/distributions.html). -[^1] `rand(rng::AbstractRNG, d)` where `d` is a distribution object is the only method of [`rand`](@ref) that is officially part of the POMDPs.jl interface, so it is the only required method for new distributions. However, users may wish to [hook into the official julia rand interface](https://docs.julialang.org/en/v1/stdlib/Random/index.html#Generating-values-from-a-collection-1) to enable more flexible `rand` calls. +[^1]: Distributions should support both `rand(rng::AbstractRNG, d)` and `rand(d)`. The recommended way to do this is by implmenting `Base.rand(rng::AbstractRNG, s::Random.SamplerTrivial{<:YourDistribution})` from the [julia rand interface](https://docs.julialang.org/en/v1/stdlib/Random/index.html#Generating-values-from-a-collection-1). diff --git a/docs/src/offline_solver.md b/docs/src/offline_solver.md new file mode 100644 index 00000000..aa793c34 --- /dev/null +++ b/docs/src/offline_solver.md @@ -0,0 +1,137 @@ +# Example: Defining an offline solver + +In this example, we will define a simple [offline solver](@ref Online-and-Offline-Solvers) that works for both POMDPs and MDPs. In order to focus on the code structure, we will not create an algorithm that finds an optimal policy, but rather a *greedy policy*, that is, one that optimizes the expected immediate reward. For information on using this solver in a simulation, see [Running Simulations](@ref). + +We begin by creating a solver type. 
Since there are no adjustable parameters for the solver, it is an empty type, but for a more complex solver, parameters would usually be included as type fields. + +```jldoctest offline; output=false +using POMDPs + +struct GreedyOfflineSolver <: Solver end + +# output + +``` + +Next, we define the functions that will make the solver work for both MDPs and POMDPs. + +### MDP Case + +Finding a greedy policy for an MDP consists of determining the action that has the best reward for each state. First, we create a simple policy object that holds a greedy action for each state. + +```jldoctest offline; output=false +struct DictPolicy{S,A} <: Policy + actions::Dict{S,A} +end + +POMDPs.action(p::DictPolicy, s) = p.actions[s] + +# output + +``` + +!!! note + A `POMDPPolicies.VectorPolicy` could be used here. We include this example to show how to define a custom policy. + +The solve function calculates the best greedy action for each state and saves it in a policy. To have the widest possible compatibility with POMDP models, we want to use [`reward`](@ref)`(m, s, a, sp)` instead of [`reward`](@ref)`(m, s, a)`, which means we need to calculate the expectation of the reward over transitions to every possible next state. + +```jldoctest offline; output=false +function POMDPs.solve(::GreedyOfflineSolver, m::MDP) + + best_actions = Dict{statetype(m), actiontype(m)}() + + for s in states(m) + if !isterminal(m, s) + best = -Inf + for a in actions(m) + td = transition(m, s, a) + r = 0.0 + for sp in support(td) + r += pdf(td, sp) * reward(m, s, a, sp) + end + if r >= best + best_actions[s] = a + end + end + end + end + + return DictPolicy(best_actions) +end + +# output + +``` + +!!! note + We limited this implementation to using basic POMDPs.jl implementation functions, but tools such as `POMDPModelTools.StateActionReward`, `POMDPModelTools.ordered_states`, and `POMDPModelTools.weighted_iterator` could have been used for a more concise and efficient implementation. 
+ +We can now verify whether the policy produces the greedy action on an example from POMDPModels: + +```jldoctest offline +using POMDPModels + +gw = SimpleGridWorld(size=(2,1), rewards=Dict(GWPos(2,1)=>1.0)) +policy = solve(GreedyOfflineSolver(), gw) + +action(policy, GWPos(1,1)) + +# output + +:right +``` + +### POMDP Case + +For a POMDP, the greedy solution is the action that maximizes the expected immediate reward according to the belief. Since there are an infinite number of possible beliefs, the greedy solution for every belief cannot be calculated online. However, the greedy policy can take the form of an alpha vector policy where each action has an associated alpha vector with each entry corresponding to the immediate reward from taking the action in that state. + +Again, because a POMDP, may have [`reward`](@ref)`(m, s, a, sp, o)` instead of [`reward`](@ref)`(m, s, a)`, we use the former and calculate the expectation over all next states and observations. + +```jldoctest offline; output=false +import POMDPPolicies + +function POMDPs.solve(::GreedyOfflineSolver, m::POMDP) + + alphas = Vector{Float64}[] + + for a in actions(m) + alpha = zeros(length(states(m))) + for s in states(m) + if !isterminal(m, s) + r = 0.0 + td = transition(m, s, a) + for sp in support(td) + tp = pdf(td, sp) + od = observation(m, s, a, sp) + for o in support(od) + r += tp * pdf(od, o) * reward(m, s, a, sp, o) + end + end + alpha[stateindex(m, s)] = r + end + end + push!(alphas, alpha) + end + + return POMDPPolicies.AlphaVectorPolicy(m, alphas, collect(actions(m))) +end + +# output + +``` +We can now verify that a policy created by the solver determines the correct greedy actions: + +```jldoctest offline; output=false +using POMDPModels +using POMDPModelTools # for Deterministic, Uniform + +tiger = TigerPOMDP() +policy = solve(GreedyOfflineSolver(), tiger) + +@assert action(policy, Deterministic(TIGER_LEFT)) == TIGER_OPEN_RIGHT +@assert action(policy, Deterministic(TIGER_RIGHT)) == 
TIGER_OPEN_LEFT +@assert action(policy, Uniform(states(tiger))) == TIGER_LISTEN + +# output + +``` diff --git a/docs/src/online_solver.md b/docs/src/online_solver.md new file mode 100644 index 00000000..2bd49943 --- /dev/null +++ b/docs/src/online_solver.md @@ -0,0 +1,101 @@ +# Example: Defining an online solver + +In this example, we will define a simple [online solver](@ref Online-and-Offline-Solvers) that works for both POMDPs and MDPs. In order to focus on the code structure, we will not create an algorithm that finds an optimal policy, but rather a *greedy policy*, that is, one that optimizes the expected immediate reward. For information on using this solver in a simulation, see [Running Simulations](@ref). + +In order to handle the widest range of problems, we will use [`@gen`](@ref) to generate Mone Carlo samples to estimate the reward even if only a simulator is available. We begin by creating the necessary types and the solve function. The only solver parameter is the number of samples used to estimate the reward at each step. + +```jldoctest online; output=false +using POMDPs + +struct MonteCarloGreedySolver <: Solver + num_samples::Int +end + +struct MonteCarloGreedyPlanner{M} <: Policy + m::M + num_samples::Int +end + +POMDPs.solve(sol::MonteCarloGreedySolver, m) = MonteCarloGreedyPlanner(m, sol.num_samples) + +# output + +``` + +Next, we define the [`action`](@ref) function where the online work takes place. 
+ +### MDP Case + +```jldoctest online; output=false +function POMDPs.action(p::MonteCarloGreedyPlanner{<:MDP}, s) + best_reward = -Inf + local best_action + for a in actions(p.m) + reward_sum = sum(@gen(:r)(p.m, s, a) for _ in 1:p.num_samples) + if reward_sum >= best_reward + best_reward = reward_sum + best_action = a + end + end + return best_action +end + +# output + +``` + +### POMDP Case + +```jldoctest online +function POMDPs.action(p::MonteCarloGreedyPlanner{<:POMDP}, b) + best_reward = -Inf + local best_action + for a in actions(p.m) + s = rand(b) + reward_sum = sum(@gen(:r)(p.m, s, a) for _ in 1:p.num_samples) + if reward_sum >= best_reward + best_reward = reward_sum + best_action = a + end + end + return best_action +end + +# output + +``` + +### Verification + +We can now verify that the online planner works in some simple cases: + +```jldoctest online +using POMDPModels + +gw = SimpleGridWorld(size=(2,1), rewards=Dict(GWPos(2,1)=>1.0)) +solver = MonteCarloGreedySolver(1000) +planner = solve(solver, gw) + +action(planner, GWPos(1,1)) + +# output + +:right +``` + +```jldoctest online; output=false +using POMDPModels +using POMDPModelTools # for Deterministic, Uniform + +tiger = TigerPOMDP() +solver = MonteCarloGreedySolver(1000) + +planner = solve(solver, tiger) + +@assert action(planner, Deterministic(TIGER_LEFT)) == TIGER_OPEN_RIGHT +@assert action(planner, Deterministic(TIGER_RIGHT)) == TIGER_OPEN_LEFT +# note action(planner, Uniform(states(tiger))) is not very reliable with this number of samples + +# output + +``` diff --git a/docs/src/requirements.md b/docs/src/requirements.md deleted file mode 100644 index aedfc56a..00000000 --- a/docs/src/requirements.md +++ /dev/null @@ -1,31 +0,0 @@ -# [Interface Requirements for Problems](@id requirements) - -Due to the large variety of problems that can be expressed as MDPs and POMDPs and the wide variety of solution techniques available, there is considerable variation in which of the POMDPs.jl interface 
functions must be implemented to use each solver. No solver requires all of the functions in the interface, so it is wise to determine which functions are needed before jumping into implementation. - -Solvers can communicate these requirements through the `@requirements_info` and `@show_requirements` macros. `@requirements_info` should give an overview of the requirements for a solver, which is supplied as the first argument, the macro can usually be more informative if a problem is specified as the second arg. For example, if you are implementing a new problem `NewMDP` and want to use the `DiscreteValueIteration` solver, you might run the following: - -![requirements_info for a new problem](figures/requirements_info_new.png) - -Note that a few of the requirements could not be shown because [`actions`](@ref) is not implemented for the new problem. - -If you would like to see a list of all of the requirements for a solver, try running `@requirements_info` with a fully implemented model from `POMDPModels`, for example, - -![requirements_info for a fully-implemented problem](figures/requirements_info_gw.png) - -`@show_requirements` is a lower-level tool that can be used to show the requirements for a specific function call, for example -```julia -@show_requirements solve(ValueIterationSolver(), NewMDP()) -``` -or -```julia -policy = solve(ValueIterationSolver(), GridWorld()) -@show_requirements action(policy, GridWorldState(1,1)) -``` - -In some cases, a solver writer may not have specified the requirements, in which case the requirements query macros will output - -``` -[No requirements specified] -``` - -In this case, please file an issue on the solver's github page to encourage the solver writer to specify requirements. 
diff --git a/docs/src/simulation.md b/docs/src/simulation.md index 50cc4d9b..efde95a0 100644 --- a/docs/src/simulation.md +++ b/docs/src/simulation.md @@ -19,8 +19,8 @@ In general, POMDP simulations take up to 5 inputs (see also the [`simulate`](@re The last three of these inputs are optional. If they are not explicitly provided, they should be inferred using the following POMDPs.jl functions: - `up = `[`updater`](@ref)`(policy)` -- `isd = `[`initialstate_distribution`](@ref)`(pomdp)` -- `s = `[`initialstate`](@ref)`(pomdp, rng)` +- `isd = `[`initialstate`](@ref)`(pomdp)` +- `s = `rand(rng, [`initialstate`](@ref)`(pomdp))` In addition, a random number generator `rng` is assumed to be available. diff --git a/docs/src/specifying_requirements.md b/docs/src/specifying_requirements.md deleted file mode 100644 index c7ccdcf7..00000000 --- a/docs/src/specifying_requirements.md +++ /dev/null @@ -1,102 +0,0 @@ -# [Specifying Requirements](@id specifying_requirements) - -## Purpose - -When a researcher or student wants to use a solver in the POMDPs ecosystem, the first question they will ask is "What do I have to implement to use this solver?". The requirements interface provides a standard way for solver writers to answer this question. - -## Internal interface - -The most important functions in the requirements interface are [`get_requirements`](@ref), [`check_requirements`](@ref), and [`show_requirements`](@ref). - -`get_requirements(f::Function, args::Tuple{...})` should be implemented by a solver or simulator writer for all important functions that use the POMDPs.jl interface. In practice, this function will rarely by implemented directly because the [@POMDP_require](@ref pomdp_require_section) macro automatically creates it. The function should return a `RequirementSet` object containing all of the methods POMDPs.jl functions that need to be implemented for the function to work with the specified arguments. 
- -[`check_requirements`](@ref) returns true if [all of the requirements in a `RequirementSet` are met](@ref implemented_section), and [`show_requirements`](@ref) prints out a list of the requirements in a `RequirementSet` and indicates which ones have been met. - -## [@POMDP_require](@id pomdp_require_section) - -The [`@POMDP_require`](@ref) macro is the main point of interaction with the requirements system for solver writers. It uses a special syntax to automatically implement [`get_requirements`](@ref). This is best shown by example. Consider this `@POMDP_require` block from the [DiscreteValueIteration package](https://github.com/JuliaPOMDP/DiscreteValueIteration.jl): - -```julia -@POMDP_require solve(solver::ValueIterationSolver, mdp::Union{MDP,POMDP}) begin - P = typeof(mdp) - S = statetype(P) - A = actiontype(P) - @req discount(::P) - @req n_states(::P) - @req n_actions(::P) - @subreq ordered_states(mdp) - @subreq ordered_actions(mdp) - @req transition(::P,::S,::A) - @req reward(::P,::S,::A,::S) - @req stateindex(::P,::S) - as = actions(mdp) - ss = states(mdp) - @req iterator(::typeof(as)) - @req iterator(::typeof(ss)) - s = first(iterator(ss)) - a = first(iterator(as)) - dist = transition(mdp, s, a) - D = typeof(dist) - @req iterator(::D) - @req pdf(::D,::S) -end -``` - -The first expression argument to the macro is a function signature specifying what the requirements apply to. The above example implements `get_requirements{P<:Union{POMDP,MDP}}(solve::typeof(solve), args::Tuple{ValueIterationSolver,P})` which will construct a `RequirementSet` containing the requirements for executing the `solve` function with `ValueIterationSolver` and `MDP` or `POMDP` arguments at runtime. - -The second expression is a [`begin`-`end` block](http://docs.julialang.org/en/release-0.5/manual/control-flow/#compound-expressions) that specifies the requirements. The arguments in the function signature (`solver` and `mdp` in this example) may be used within the block. 
- -The [`@req`](@ref) macro is used to specify a required function. Each [`@req`](@ref) should be followed by a function with the argument types specified. The [`@subreq`](@ref) macro is used to denote that the requirements of another function are also required. Each [`@subreq`](@ref) should be followed by a function call. - -## `requirements_info` - -While the `@POMDP_require` macro is used to specify requirements for a specific method, the [`requirements_info`](@ref) function is a more flexible communication tool for a solver writer. [`requirements_info`](@ref) should print out a message describing the requirements for a solver. The exact form of the message is up to the solver writer, but it should be carefully thought-out because problem-writers will be directed to call the function (via the `@requirements_info` macro) as the first step in using a new solver (see [tutorial](def_pomdp.md)). - -By default, `requirements_info` calls [`show_requirements`](@ref) on the `solve` function. This is adequate in many cases, but in some cases, notably for online solvers such as [MCTS](https://github.com/JuliaPOMDP/MCTS.jl), the requirements for [`solve`](@ref) do not give a good indication of the requirements for using the solver. Instead, the requirements for [`action`](@ref) should be displayed. The following example shows a more informative version of `requirements_info` from the MCTS package. Since [`action`](@ref) requires a state argument, `requirements_info` prompts the user to provide one. - -```julia -function POMDPs.requirements_info(solver::AbstractMCTSSolver, problem::Union{POMDP,MDP}) - if statetype(typeof(problem)) <: Number - s = one(statetype(typeof(problem))) - requirements_info(solver, problem, s) - else - println(""" - Since MCTS is an online solver, most of the computation occurs in `action(policy, state)`. In order to view the requirements for this function, please, supply a state as the third argument to `requirements_info`, e.g. 
- - @requirements_info $(typeof(solver))() $(typeof(problem))() $(statetype(typeof(problem)))() - - """) - end -end - -function POMDPs.requirements_info(solver::AbstractMCTSSolver, problem::Union{POMDP,MDP}, s) - policy = solve(solver, problem) - requirements_info(policy, s) -end - -function POMDPs.requirements_info(policy::AbstractMCTSPolicy, s) - @show_requirements action(policy, s) -end -``` - -## `@warn_requirements` - -The `@warn_requirements` macro is a useful tool to improve usability of a solver. It will show a requirements list only if some requirements are not met. It might be used, for example, in the solve function to give a problem writer a useful error if some required methods are missing (assuming the solver writer has already used `@POMDP_require` to specify the requirements for `solve`): - -```julia -function solve(solver::ValueIterationSolver, mdp::Union{POMDP, MDP}) - @warn_requirements solve(solver, mdp) - - # do the work of solving -end -``` - -`@warn_requirements` does perform a runtime check of requirements every time it is called, so it should not be used in code that may be used in fast, high-performance loops. - -## [Determining whether a function is implemented](@id implemented_section) - -When checking requirements in `check_requirements`, or printing in `show_requirements`, the [`implemented`](@ref) function is used to determine whether an implementation for a function is available. For example `implemented(discount, Tuple{NewPOMDP})` should return true if the writer of the `NewPOMDP` problem has implemented discount for their problem. In most cases, the default implementation, -```julia -implemented(f::Function, TT::TupleType) = method_exists(f, TT) -``` -will automatically handle this, but there may be cases in which you want to override the behavior of `implemented`, for example, if the function can be synthesized from other functions. 
Examples of this can be found in the [default implementations of the generative interface funcitons](https://github.com/JuliaPOMDP/POMDPs.jl/blob/master/src/generative_impl.jl.jl). diff --git a/docs/src/basic_properties.md b/docs/src/static.md similarity index 57% rename from docs/src/basic_properties.md rename to docs/src/static.md index 116875e6..29dbdde2 100644 --- a/docs/src/basic_properties.md +++ b/docs/src/static.md @@ -1,6 +1,6 @@ -# [Defining Basic (PO)MDP Properties](@id basic) +# [Defining Static (PO)MDP Properties](@id static) -In addition to the [dynamic decision network (DDN)](@ref Dynamic-decision-networks) that defines the state and observation dynamics, a POMDPs.jl problem definition will include definitions of various other properties. Each of these properties is defined by implementing a new method of an interface function for the problem. +The definition of a (PO)MDP includes several static properties, which are defined with the functions listed in this section. This section is an overview, with links to the docstrings for detailed usage information. To use most solvers, it is only necessary to implement a few of these functions. @@ -18,6 +18,12 @@ It is often important to limit the action space based on the current state, beli This can be accomplished with the [`actions`](@ref)`(m, s)` or [`actions`](@ref)`(m, b)` function. See [Histories associated with a belief](@ref) and the [`history`](@ref) and [`currentobs`](@ref) docstrings for more information. +## Initial Distributions + +[`initialstate`](@ref)`(pomdp)` should return the distribution of the initial state, either as an explicit distribution (e.g. a `POMDPModelTools.SparseCat`) that conforms to the [distribution interface](@ref Distributions) or with a `POMDPModelTools.ImplicitDistribution` to easily specify a function to sample from the space. 
+ +[`initialobs`](@ref)`(pomdp, state)` is used to return the distribution of the initial observation in occasional cases where the policy expects an initial observation rather than an initial belief, e.g. in a reinforcement learning setting. It is not used in a standard POMDP simulation. + ## Discount Factor [`discount`](@ref)`(pomdp)` should return a number between 0 and 1 to define the discount factor. @@ -35,7 +41,13 @@ For discrete problems, some solvers rely on a fast method for finding the index - [`actionindex`](@ref)`(pomdp, a)` - [`obsindex`](@ref)`(pomdp, o)` -Note that the converse mapping (from indices to states) is not part of the POMDPs interface. A solver will typically create a vector containing all the states to define it. +!!! note + + The converse mapping (from indices to states) is not part of the POMDPs interface. A solver will typically create a vector containing all the states to define it. + +!!! note + + There is no requirement that the object returned by the [space functions](@ref Spaces) above respect the same ordering as the `index` functions. The `index` functions are the *sole definition* of ordering of the states. The `POMDPModelTools` package contains convenience functions for constructing a list of states that respects the ordering specified by the `index` functions. For example, `POMDPModelTools.ordered_states` returns an `AbstractVector` of the states in the order specified by `stateindex`. 
## Conversion to vector types diff --git a/src/POMDPs.jl b/src/POMDPs.jl index c2c164f2..e490f91f 100644 --- a/src/POMDPs.jl +++ b/src/POMDPs.jl @@ -4,15 +4,15 @@ Provides a basic interface for defining and solving MDPs/POMDPs module POMDPs using Random -using Base: @pure import Base: rand import Statistics import Distributions: rand, pdf, mode, mean, support import NamedTupleTools import Pkg -import LibGit2 using LightGraphs -using Logging + +# For Deprecated +import POMDPLinter export # Abstract type @@ -28,13 +28,14 @@ export observation, reward, isterminal, + initialstate, + initialobs, # Generative model functions gen, @gen, - initialstate, - initialobs, - + DDNOut, + # Discrete Functions length, stateindex, @@ -46,9 +47,7 @@ export pdf, mode, mean, - dimensions, support, - initialstate_distribution, # Solver types Solver, @@ -72,8 +71,6 @@ export simulate, # Utilities - implemented, - @implemented, convert_s, convert_a, convert_o, @@ -81,28 +78,9 @@ export actiontype, obstype, - # DDNs - DDNNode, - DDNOut, - DDNStructure, - DDNStructure, - DistributionDDNNode, - FunctionDDNNode, - ConstantDDNNode, - InputDDNNode, - GenericDDNNode, - node, - depvars, - depnames, - nodenames, - outputnames, - name, - add_node, - pomdp_ddn, - mdp_ddn, - DistributionNotImplemented, - - # Requirements checking + # Deprecated + implemented, + @implemented, RequirementSet, check_requirements, show_requirements, @@ -116,35 +94,20 @@ export @warn_requirements, @req, @subreq, + initialstate_distribution, + dimensions + - # Deprecated - generate_s, - generate_o, - generate_sr, - generate_so, - generate_or, - generate_sor, - sampletype, - n_states, - n_actions, - n_observations - -include("requirements_internals.jl") -include("requirements_printing.jl") include("pomdp.jl") include("solver.jl") include("simulator.jl") -include("requirements_interface.jl") include("distribution.jl") include("belief.jl") include("space.jl") include("policy.jl") include("type_inferrence.jl") 
-include("ddn_struct.jl") -include("errors.jl") include("generative.jl") include("gen_impl.jl") -include("utils.jl") include("deprecated.jl") end diff --git a/src/belief.jl b/src/belief.jl index 1e06a7a8..dfb889c5 100644 --- a/src/belief.jl +++ b/src/belief.jl @@ -57,4 +57,3 @@ Return the latest observation associated with belief `b`. If a solver or updater implements `history(b)` for a belief type, `currentobs` has a default implementation. """ currentobs(b) = history(b)[end].o -@impl_dep currentobs(::B) where B history(::B) diff --git a/src/ddn_struct.jl b/src/ddn_struct.jl deleted file mode 100644 index 785a5592..00000000 --- a/src/ddn_struct.jl +++ /dev/null @@ -1,235 +0,0 @@ -""" - DDNNode(x::Symbol) - DDNNode{x::Symbol}() - -Reference to a named node in the POMDP or MDP dynamic decision network (DDN). - -Note that `gen(::DDNNode, m, depargs..., rng)` always takes an argument for each dependency whereas `gen(::DDNOut, m, s, a, rng)` only takes `s` and `a` arguments (the inputs to the entire DDN). - -`DDNNode` is a "value type". See [the documentation of `Val`](https://docs.julialang.org/en/v1/manual/types/index.html#%22Value-types%22-1) for more conceptual details about value types. -""" -struct DDNNode{name} end - -@pure DDNNode(name::Symbol) = DDNNode{name}() - -""" -Get the name of a DDNNode. -""" -name(::DDNNode{n}) where n = n -name(::Type{DDNNode{n}}) where n = n - -""" - DDNOut(x::Symbol) - DDNOut{x::Symbol}() - DDNOut(::Symbol, ::Symbol,...) - DDNOut{x::NTuple{N, Symbol}}() - -Reference to one or more named nodes in the POMDP or MDP dynamic decision network (DDN). - -Note that `gen(::DDNOut, m, s, a, rng)` always takes `s` and `a` arguments (the inputs to the entire DDN) while `gen(::DDNNode, m, depargs..., rng)` takes a variable number of arguments (one for each dependency). - -`DDNOut` is a "value type". 
See [the documentation of `Val`](https://docs.julialang.org/en/v1/manual/types/index.html#%22Value-types%22-1) for more conceptual details about value types. -""" -struct DDNOut{names} end - -@pure DDNOut(name::Symbol) = DDNOut{name}() -@pure DDNOut(names...) = DDNOut{names}() -@pure DDNOut(names::Tuple) = DDNOut{names}() - -struct DDNStructure{N<:NamedTuple, D<:NamedTuple} - "Node implementations." - nodes::N - "Dependency tree. NamedTuple full of Tuples of DDNNodes." - deps::D -end - -node(d::DDNStructure, name::Symbol) = d.nodes[name] -depvars(d::DDNStructure, name::Symbol) = d.deps[name] -depnames(d::DDNStructure, n::Symbol) = map(name, depvars(d, n)) - -nodenames(d::DDNStructure) = keys(d.nodes) -nodenames(DDN::Type{D}) where {D <: DDNStructure} = fieldnames(DDN.parameters[1]) -outputnames(d::DDNStructure) = outputnames(typeof(d)) # XXX Port to 0.8 -function outputnames(::Type{D}) where D <: DDNStructure - tuple(Iterators.filter(sym->!(sym in (:s, :a)), nodenames(D))...) -end - -function add_node(d::DDNStructure, n::DDNNode{name}, node, deps) where name - @assert !haskey(d.nodes, name) "DDNStructure already has a node named :$name" - return DDNStructure(merge(d.nodes, NamedTuple{tuple(name)}(tuple(node))), - merge(d.deps, NamedTuple{tuple(name)}(tuple(deps)))) -end - -function add_node(d::DDNStructure, n::Symbol, node, deps::NTuple{N,Symbol}) where N - return add_node(d, DDNNode(n), node, map(DDNNode, deps)) -end - -depstype(DDN::Type{D}) where {D <: DDNStructure} = DDN.parameters[2] - -""" - sorted_deppairs(DDN::Type{D}, symbols) where D <: DDNStructure - -Create a list of name=>deps pairs sorted so that dependencies come before dependents. - -`symbols` is any iterable collection of `Symbol`s. -""" -function sorted_deppairs end # this is implemented below - -""" - DDNStructure(::Type{M}) where M <: Union{MDP, POMDP} - -Trait of an MDP/POMDP type for describing the structure of the dynamic Baysian network. 
- -# Example - - struct MyMDP <: MDP{Int, Int} end - POMDPs.gen(::MyMDP, s, a, rng) = (sp=s+a+rand(rng, [1,2,3]), r=s^2) - - # make a new node, delta_s, that is deterministically equal to sp - s - function POMDPs.DDNStructure(::Type{MyMDP}) - ddn = mdp_ddn() - return add_node(ddn, :delta_s, FunctionDDNNode((m,s,sp)->sp-s), (:s, :sp)) - end - - gen(DDNOut(:delta_s), MyMDP(), 1, 1, Random.GLOBAL_RNG) -""" -function DDNStructure end - -DDNStructure(::Type{M}) where M <: MDP = mdp_ddn() -DDNStructure(::Type{M}) where M <: POMDP = pomdp_ddn() - -DDNStructure(m) = DDNStructure(typeof(m)) - -struct InputDDNNode end # this does nothing for now - -""" -DDN node defined by a function that maps the model and values from the parent nodes to a distribution - -# Example - DistributionDDNNode((m, s, a)->POMDPModelTools.Deterministic(s+a)) -""" -struct DistributionDDNNode{F} - dist_func::F -end - -@generated function gen(n::DistributionDDNNode, m, args...) - # apparently needs to be @generated for type stability - argexpr = (:(args[$i]) for i in 1:length(args)-1) - quote - rand(last(args), n.dist_func(m, $(argexpr...))) - end -end - -function implemented(g::typeof(gen), n::DistributionDDNNode, M, Deps, RNG) - return implemented(n.dist_func, Tuple{M, Deps.parameters...}) -end - - -""" -DDN node defined by a function that determinisitically maps the model and values from the parent nodes to a new value. - -# Example - FunctionDDNNode((m, s, a)->s+a) -""" -struct FunctionDDNNode{F} - f::F -end - -@generated function gen(n::FunctionDDNNode, m, args...) - # apparently this needs to be @generated for type stability - argexpr = (:(args[$i]) for i in 1:length(args)-1) - quote - n.f(m, $(argexpr...)) - end -end - -function implemented(g::typeof(gen), n::FunctionDDNNode, M, Deps, RNG) - return implemented(n.f, Tuple{M, Deps.parameters...}) -end - -""" -DDN node that always takes a deterministic constant value. -""" -struct ConstantDDNNode{T} - val::T -end - -gen(n::ConstantDDNNode, args...) 
= n.val -implemented(g::typeof(gen), n::ConstantDDNNode, M, Deps, RNG) = true - -""" -DDN node that can only have a generative model; `gen(::DDNNode{:x}, ...)` must be implemented for a node of this type. -""" -struct GenericDDNNode end - -gen(::GenericDDNNode, args...) = error("No `gen(::DDNNode, ...)` method implemented for a GenericDDNNode (see stack trace for name)") -implemented(g::typeof(gen), GenericDDNNode, M, Deps, RNG) = false - -# standard DDNs -function mdp_ddn() - DDNStructure((s = InputDDNNode(), - a = InputDDNNode(), - sp = DistributionDDNNode(transition), - r = FunctionDDNNode(reward), - ), - (s = (), - a = (), - sp = map(DDNNode, (:s, :a)), - r = map(DDNNode, (:s, :a, :sp)), - ) - ) -end - -function pomdp_ddn() - DDNStructure((s = InputDDNNode(), - a = InputDDNNode(), - sp = DistributionDDNNode(transition), - o = DistributionDDNNode(observation), - r = FunctionDDNNode(reward), - ), - (s = (), - a = (), - sp = map(DDNNode, (:s, :a)), - o = map(DDNNode, (:s, :a, :sp)), - r = map(DDNNode, (:s, :a, :sp, :o)), - ) - ) -end - -function sorted_deppairs(ddn::Type{D}, symbols) where D <: DDNStructure - depnames = Dict{Symbol, Vector{Symbol}}() - NT = depstype(ddn) - for key in fieldnames(NT) - depnames[key] = collect(map(name, fieldtype(NT, key).parameters)) - end - return sorted_deppairs(depnames, symbols) -end - -sorted_deppairs(ddn::Type{D}, symbol::Symbol) where D <: DDNStructure = sorted_deppairs(ddn, tuple(symbol)) - -function sorted_deppairs(depnames::Dict{Symbol, Vector{Symbol}}, symbols) - dag = SimpleDiGraph(length(depnames)) - labels = Symbol[] - nodemap = Dict{Symbol, Int}() - for sym in symbols - if !haskey(nodemap, sym) - push!(labels, sym) - nodemap[sym] = length(labels) - end - add_dep_edges!(dag, nodemap, labels, depnames, sym) - end - sortednodes = topological_sort_by_dfs(dag) - sortednames = labels[filter(n -> n<=length(labels), sortednodes)] - return [n=>depnames[n] for n in sortednames] -end - -function add_dep_edges!(dag, nodemap, 
labels, depnames, sym) - for dep in depnames[sym] - if !haskey(nodemap, dep) - push!(labels, dep) - nodemap[dep] = length(labels) - end - add_edge!(dag, nodemap[dep], nodemap[sym]) - add_dep_edges!(dag, nodemap, labels, depnames, dep) - end -end diff --git a/src/deprecated.jl b/src/deprecated.jl index 768fbbf9..28be3d70 100644 --- a/src/deprecated.jl +++ b/src/deprecated.jl @@ -1,50 +1,115 @@ -@deprecate generate_s(args...) gen(DDNOut(:sp), args...) -@deprecate generate_o(args...) gen(DDNNode(:o), args...) # this one should be DDNNode because the arguments are not s and a -@deprecate generate_sr(args...) gen(DDNOut(:sp,:r), args...) -@deprecate generate_so(args...) gen(DDNOut(:sp,:o), args...) -@deprecate generate_sor(args...) gen(DDNOut(:sp,:o,:r), args...) -generate_or(args...) = error("POMDPs.jl v0.8 no longer supports generate_or") # there is no equivalent for this in the new system, but AFAIK no one used it. +@deprecate implemented POMDPLinter.implemented +@deprecate RequirementSet POMDPLinter.RequirementSet +@deprecate check_requirements POMDPLinter.check_requirements +@deprecate show_requirements POMDPLinter.show_requirements +@deprecate get_requirements POMDPLinter.get_requirements +@deprecate requirements_info POMDPLinter.requirements_info -const old_generate = Dict(:sp => generate_s, - :o => generate_o, - (:sp,:r) => generate_sr, - (:sp,:o) => generate_so, - (:sp,:o,:r) => generate_sor) -const new_ddnvars = Dict(generate_s => DDNNode{:sp}, - generate_o => DDNNode{:o}, - generate_sr => DDNOut{(:sp,:r)}, - generate_so => DDNOut{(:sp,:o)}, - generate_sor => DDNOut{(:sp,:o,:r)}) +macro implemented(ex) + @warn("POMDPs.@implemented is deprecated, use POMDPLinter.@implemented instead.", maxlog=1) + tplex = POMDPLinter.convert_req(ex) + return quote + POMDPLinter.implemented($(esc(tplex))...) + end +end + +macro POMDP_require(args...) 
+ @warn("POMDPs.@POMDP_require is deprecated, use POMDPLinter.@POMDP_require instead.", maxlog=1) + POMDPLinter.pomdp_require(args...) +end +macro POMDP_requirements(args...) + @warn("POMDPs.@POMDP_requirements is deprecated, use POMDPLinter.@POMDP_requirements instead.", maxlog=1) + POMDPLinter.pomdp_requirements(args...) +end + +macro requirements_info(exprs...) + @warn("POMDPs.@requirements_info is deprecated, use POMDPLinter.@requirements_info instead.", maxlog=1) + quote + requirements_info($([esc(ex) for ex in exprs]...)) + end +end -GenerateFunctions = Union{(typeof(f) for f in values(old_generate))...} +macro get_requirements(call) + @warn("POMDPs.@get_requirements is deprecated, use POMDPLinter.@get_requirements instead.", maxlog=1) + return quote get_requirements($(esc(POMDPLinter.convert_call(call)))...) end +end -function implemented_by_user(g::GenerateFunctions, TT::TupleType) - m = which(g, TT) - return m.module != POMDPs +macro show_requirements(call) + @warn("POMDPs.@show_requirements is deprecated, use POMDPLinter.@show_requirements instead.", maxlog=1) + quote + reqs = get_requirements($(esc(POMDPLinter.convert_call(call)))...) + show_requirements(reqs) + end end -function implemented(g::GenerateFunctions, TT::TupleType) - if implemented_by_user(g, TT) - return true +macro warn_requirements(call) + @warn("POMDPs.@warn_requirements is deprecated, use POMDPLinter.@warn_requirements instead.", maxlog=1) + quote + reqs = get_requirements($(esc(POMDPLinter.convert_call(call)))...) + c = check_requirements(reqs) + if !ismissing(c) && c == false + show_requirements(reqs) + end end - return implemented(gen, Tuple{new_ddnvars[g], TT.parameters...}) end -@deprecate sampletype Random.gentype -@deprecate n_states(m) length(states(m)) -@deprecate n_actions(m) length(actions(m)) -@deprecate n_observations(m) length(observations(m)) +macro req(args...) + :(error("POMDPs.@req no longer exists. Please use POMDPLinter.@req")) +end + +macro subreq(args...) 
+ :(error("POMDPs.@subreq no longer exists. Please use POMDPLinter.@subreq")) +end + +function gen(o::DDNOut{symbols}, m::Union{MDP,POMDP}, s, a, rng) where symbols + if symbols isa Symbol + @warn("gen(DDNOut(:$symbols), m, s, a, rng) is deprecated, use @gen(:$symbols)(m, s, a, rng) instead.", maxlog=1) + else + symbolstring = join([":$s" for s in symbols], ", ") + @warn("gen(DDNOut($symbolstring), m, s, a, rng) is deprecated, use @gen($symbolstring)(m, s, a, rng) instead.", maxlog=1) + end + return genout(DDNOut(symbols), m, s, a, rng) +end + +@deprecate initialstate(m, rng) rand(rng, initialstate(m)) +@deprecate initialstate_distribution initialstate + +# for the case when initialstate is called, but initialstate_distribution is implemented +function initialstate(m::Union{MDP,POMDP}) + method = which(initialstate_distribution, Tuple{typeof(m)}) + if method.module == POMDPs # ignore the @deprecated definition to avoid infinite recursion + throw(MethodError(initialstate, (m,))) + else + @warn("Falling back to using deprecated function initialstate_distribution(::$(typeof(m))). Please implement this as initialstate(::$(typeof(m))) instead.", maxlog=1) + return initialstate_distribution(m) + end +end + +@deprecate initialobs(m, s, rng) rand(rng, initialobs(m, s)) + +dimensions(s::Any) = error("dimensions is no longer part of the POMDPs.jl interface.") """ -The version 0.7 DDNStructure just has the nodenames + available() + +Prints all the available packages in the JuliaPOMDP registry """ -struct DDNStructureV7{nodenames} end +function available() + @warn("POMDPs.available() is deprecated. 
Please see the POMDPs.jl README for a list of packages.") + reg_dict = read_registry(joinpath(Pkg.depots1(), "registries", "JuliaPOMDP", "Registry.toml")) + for (uuid, pkginfo) in reg_dict["packages"] + println(pkginfo["name"]) + end +end + +function read_registry(regfile) + registry = Pkg.TOML.parsefile(regfile) + return registry +end -nodenames(d::DDNStructureV7) = nodenames(typeof(d)) -nodenames(::Type{D}) where D <: DDNStructureV7 = D.parameters[1] -outputnames(d::DDNStructureV7) = outputnames(typeof(d)) -function outputnames(::Type{D}) where D <: DDNStructureV7 - tuple(Iterators.filter(sym->!(sym in (:s, :a)), nodenames(D))...) +function add_registry(;kwargs...) + @warn("""POMDPs.add_registry() is deprecated. Use Pkg.pkg"registry add https://github.com/JuliaPOMDP/Registry" instead.""") + Pkg.pkg"registry add https://github.com/JuliaPOMDP/Registry" end diff --git a/src/errors.jl b/src/errors.jl deleted file mode 100644 index e77c211d..00000000 --- a/src/errors.jl +++ /dev/null @@ -1,117 +0,0 @@ -struct DistributionNotImplemented <: Exception - sym::Symbol - gen_firstarg::Type - func::Function - modeltype::Type - dep_argtypes::AbstractVector -end - -function Base.showerror(io::IO, ex::DistributionNotImplemented) - println(io, """\n - - POMDPs.jl could not find an implementation for DDN Node :$(ex.sym). 
Consider the following options: - """) - - argstring = string("::", ex.modeltype, string((", ::$T" for T in ex.dep_argtypes)...)) - - i = 1 - if ex.gen_firstarg <: DDNOut - M = ex.modeltype - S = statetype(M) - A = actiontype(M) - printstyled(io, "$i) Implement POMDPs.gen(::$M, ::$S, ::$A, ::AbstractRNG) to return a NamedTuple with key :$(ex.sym).\n", bold=true) - gen_analysis(io, ex.sym, M, [S,A]) - println(io) - i += 1 - end - printstyled(io, "$i) Implement POMDPs.gen(::DDNNode{:$(ex.sym)}, $argstring, ::AbstractRNG).\n", - bold=true) - try_show_method_candidates(io, MethodError(gen, Tuple{DDNNode{ex.sym}, ex.modeltype, ex.dep_argtypes..., AbstractRNG})) - i += 1 - printstyled(io, "\n\n$i) Implement $(ex.func)($argstring).\n", bold=true) - try_show_method_candidates(io, MethodError(ex.func, Tuple{ex.modeltype, ex.dep_argtypes...})) - - println(io, "\n\nThis error message uses heuristics to make recommendations for POMDPs.jl problem implementers. If it was misleading or you believe there is an inconsistency, please file an issue: https://github.com/JuliaPOMDP/POMDPs.jl/issues/new") -end - -function distribution_impl_error(sym, func, modeltype, dep_argtypes) - st = stacktrace() - acceptable = (:distribution_impl_error, nameof(func), nameof(gen), nameof(genout)) - gen_firstarg = nothing # The first argument to the `gen` call that is furthest down in the stack trace - - try - for sf in stacktrace() # step up the stack trace - - # if it is a macro from ddn_struct.jl or gen_impl.jl it is ok - if sf.func === Symbol("macro expansion") - bn = basename(String(sf.file)) - if !(bn in ["ddn_struct.jl", "gen_impl.jl", "none"]) - break - # the call stack includes a macro from some other package - end - - # if it is not a function we know about, give up - elseif !(sf.func in acceptable) - break - - # if it is gen, check to see if it's the DDNNode version - elseif sf.func === nameof(gen) - sig = sf.linfo.def.sig - if sig isa UnionAll && - sig.body.parameters[1] == typeof(gen) && 
- sig.body.parameters[2] <: Union{DDNNode, DDNOut} - # bingo! - gen_firstarg = sig.body.parameters[2] - dep_argtypes = [sig.body.parameters[3:end-1]...] - end - end - end - catch ex - @debug("Error throwing DistributionNotImplemented error:\n$(sprint(showerror, ex))") - throw(MethodError(func, Tuple{modeltype, dep_argtypes...})) - end - - if gen_firstarg === nothing - throw(MethodError(func, Tuple{modeltype, dep_argtypes...})) - else - throw(DistributionNotImplemented(sym, gen_firstarg, func, modeltype, dep_argtypes)) - end -end - -function gen_analysis(io, sym::Symbol, modeltype::Type, dep_argtypes) - argtypes = Tuple{modeltype, dep_argtypes..., AbstractRNG} - rts = Base.return_types(gen, argtypes) - if length(rts) <= 0 # there should always be the default NamedTuple() impl. - @debug("Error analyzing the return types for gen. Please submit an issue at https://github.com/JuliaPOMDP/POMDPs.jl/issues/new", argtypes=argtypes, rts=rts) - elseif length(rts) == 1 - rt = first(rts) - if rt == typeof(NamedTuple()) && !implemented(gen, argtypes) - try_show_method_candidates(io, MethodError(gen, argtypes)) - println(io) - else - println(io, "\nThis method was implemented and the return type was inferred to be $rt. Is this type always a NamedTuple with key :$(sym)?") - end - else - println(io, "(POMDPs.jl could not determine if this method was implemented correctly. [Base.return_types(gen, argtypes) = $(rts)])") - end -end - -function try_show_method_candidates(io, args...) - try - Base.show_method_candidates(io, args...) # this isn't exported, so it might break - catch ex - @debug("Unable to show method candidates. 
Please submit an issue at https://github.com/JuliaPOMDP/POMDPs.jl/issues/new.\n$(sprint(showerror, ex))") - end -end - -transition(m, s, a) = distribution_impl_error(:sp, transition, typeof(m), [typeof(s), typeof(a)]) -function implemented(t::typeof(transition), TT::TupleType) - m = which(t, TT) - return m.module != POMDPs # see if this was implemented by a user elsewhere -end - -observation(m, sp) = distribution_impl_error(:o, observation, typeof(m), [typeof(sp)]) -function implemented(o::typeof(observation), TT::Type{Tuple{M, SP}}) where {M<:POMDP, SP} - m = which(o, TT) - return m.module != POMDPs -end diff --git a/src/gen_impl.jl b/src/gen_impl.jl index ee9c2da3..474dac95 100644 --- a/src/gen_impl.jl +++ b/src/gen_impl.jl @@ -1,39 +1,32 @@ -@generated function gen(v::DDNOut{symbols}, m, s, a, rng) where symbols - - # deprecation of old generate_ functions - if symbols isa Tuple && # if it is just one, it will be handled in the DDNNode version - haskey(old_generate, symbols) && - implemented_by_user(old_generate[symbols], Tuple{m, s, a, rng}) - - @warn("""Using user-implemented function - $(old_generate[symbols])(::M, ::S, ::A, ::RNG) - which is deprecated in POMDPs v0.8. Please implement this as - POMDPs.gen(::M, ::S, ::A, ::RNG) or - POMDPs.gen(::DDNOut{$symbols}, ::M, ::S, ::A, ::RNG) - instead. See the POMDPs.gen documentation for more details.""", M=m, S=s, A=a, RNG=rng) - return :($(old_generate[symbols])(m, s, a, rng)) - end - - quote - ddn = DDNStructure(m) - genout(v, ddn, m, s, a, rng) - end -end +gen(m::Union{MDP,POMDP}, s, a, rng) = NamedTuple() """ -Sample values for nodes specified in the first argument by sampling values for all intermediate nodes. + DDNOut(x::Symbol) + DDNOut{x::Symbol}() + DDNOut(::Symbol, ::Symbol,...) + DDNOut{x::NTuple{N, Symbol}}() + +Reference to one or more named nodes in the POMDP or MDP dynamic decision network (DDN). + +`DDNOut` is a "value type". 
See [the documentation of `Val`](https://docs.julialang.org/en/v1/manual/types/index.html#%22Value-types%22-1) for more conceptual details about value types. """ -@inline @generated function genout(v::DDNOut{symbols}, ddn::DDNStructure, m, s, a, rng) where symbols - +struct DDNOut{names} end + +DDNOut(name::Symbol) = DDNOut{name}() +DDNOut(names...) = DDNOut{names}() +DDNOut(names::Tuple) = DDNOut{names}() + +@generated function genout(v::DDNOut{symbols}, m::Union{MDP,POMDP}, s, a, rng) where symbols + # use anything available from gen(m, s, a, rng) expr = quote x = gen(m, s, a, rng) @assert x isa NamedTuple "gen(m::Union{MDP,POMDP}, ...) must return a NamedTuple; got a $(typeof(x))" end - + # add gen for any other variables - for (var, depargs) in sorted_deppairs(ddn, symbols) - if var in (:s, :a) # eventually should look for InputDDNNodes instead of being hardcoded + for (var, depargs) in sorted_deppairs(m, symbols) + if var in (:s, :a) # input nodes continue end @@ -43,7 +36,7 @@ Sample values for nodes specified in the first argument by sampling values for a if haskey(x, $sym) # should be constant at compile time $var = x[$sym] else - $var = gen(DDNNode{$sym}(), m, $(depargs...), rng) + $var = $(node_expr(Val(var), depargs)) end end append!(expr.args, varblock.args) @@ -60,56 +53,57 @@ Sample values for nodes specified in the first argument by sampling values for a return expr end -@generated function gen(::DDNNode{x}, m, args...) where x - # this function is only @generated to deal with deprecation of gen functions - - # deprecation of old generate_ functions - if haskey(old_generate, x) && implemented_by_user(old_generate[x], Tuple{m, args...}) - @warn("""Using user-implemented function - $(old_generate[x])(::M, ::Argtypes...) - which is deprecated in POMDPs v0.8. Please implement this as - POMDPs.gen(::M, ::Argtypes...) or - POMDPs.gen(::DDNNode{:$x}, ::M, ::Argtypes...) - instead. 
See the POMDPs.gen documentation for more details.""", M=m, Argtypes=args) - return :($(old_generate[x])(m, args...)) - end - - quote - gen(node(DDNStructure(m), x), m, args...) - end +function sorted_deppairs(m::Type{<:MDP}, symbols) + deps = Dict(:s => Symbol[], + :a => Symbol[], + :sp => [:s, :a], + :r => [:s, :a, :sp], + :info => Symbol[] + ) + return sorted_deppairs(deps, symbols) end -gen(m::Union{MDP, POMDP}, s, a, rng) = NamedTuple() +function sorted_deppairs(m::Type{<:POMDP}, symbols) + deps = Dict(:s => Symbol[], + :a => Symbol[], + :sp => [:s, :a], + :o => [:s, :a, :sp], + :r => [:s, :a, :sp, :o], + :info => Symbol[] + ) + return sorted_deppairs(deps, symbols) +end -function implemented(g::typeof(gen), TT::TupleType) - m = which(g, TT) - if m.module != POMDPs # implemented by a user elsewhere - return true +function sorted_deppairs(depnames::Dict{Symbol, Vector{Symbol}}, symbols) + dag = SimpleDiGraph(length(depnames)) + labels = Symbol[] + nodemap = Dict{Symbol, Int}() + for sym in symbols + if !haskey(nodemap, sym) + push!(labels, sym) + nodemap[sym] = length(labels) + end + add_dep_edges!(dag, nodemap, labels, depnames, sym) end - v = first(TT.parameters) - if v <: Union{MDP, POMDP} - return false # already checked above for implementation in another module - else - @assert v <: Union{DDNNode, DDNOut} - vp = first(v.parameters) - if haskey(old_generate, vp) && implemented_by_user(old_generate[vp], Tuple{TT.parameters[2:end]...}) # old generate function is implemented - return true + sortednodes = topological_sort_by_dfs(dag) + sortednames = labels[filter(n -> n<=length(labels), sortednodes)] + return [n=>depnames[n] for n in sortednames] +end + +sorted_deppairs(dn::Dict{Symbol, Vector{Symbol}}, sym::Symbol) = sorted_deppairs(dn, tuple(sym)) + +function add_dep_edges!(dag, nodemap, labels, depnames, sym) + for dep in depnames[sym] + if !haskey(nodemap, dep) + push!(labels, dep) + nodemap[dep] = length(labels) end - - return implemented(g, v, 
TT.parameters[2], Tuple{TT.parameters[3:end-1]...}, TT.parameters[end]) + add_edge!(dag, nodemap[dep], nodemap[sym]) + add_dep_edges!(dag, nodemap, labels, depnames, dep) end end -function implemented(g::typeof(gen), Var::Type{D}, M::Type, Deps::TupleType, RNG::Type) where D <: DDNNode - v = first(Var.parameters) - ddn = DDNStructure(M) - return implemented(g, node(ddn, v), M, Deps, RNG) -end - -function implemented(g::typeof(gen), Vars::Type{D}, M::Type, Deps::TupleType, RNG::Type) where D <: DDNOut - if length(Deps.parameters) == 2 && implemented(g, Tuple{M, Deps.parameters..., RNG}) # gen(m, s, a, rng) is implemented - return true # should this be true or missing? - else - return missing # this is complicated because we need to know the types of everything in the ddn - end -end +node_expr(::Val{:sp}, depargs) = :(rand(rng, transition(m, $(depargs...)))) +node_expr(::Val{:o}, depargs) = :(rand(rng, observation(m, $(depargs...)))) +node_expr(::Val{:r}, depargs) = :(reward(m, $(depargs...))) +node_expr(::Val{:info}, depargs) = :(nothing) diff --git a/src/generative.jl b/src/generative.jl index 0d4eec15..07a51967 100644 --- a/src/generative.jl +++ b/src/generative.jl @@ -1,24 +1,9 @@ """ - gen(...) - -Sample from generative model of a POMDP or MDP. - -In most cases solver and simulator writers should use the `@gen` macro. Problem writers may wish to implement one or more new methods of the function for their problem. - -There are three versions of the function: -- The most convenient version to implement is gen(m::Union{MDP,POMDP}, s, a, rng::AbstractRNG), which returns a `NamedTuple`. -- Defining behavior for and sampling from individual nodes of the dynamic decision network can be accomplished using the version with a `DDNNode` argument. -- A version with a `DDNOut` argument is provided by the compiler to sample multiple nodes at once. - -See below for detailed documentation for each type. 
- ---- - gen(m::Union{MDP,POMDP}, s, a, rng::AbstractRNG) -Convenience function for implementing the entire MDP/POMDP generative model in one function by returning a `NamedTuple`. +Function for implementing the entire MDP/POMDP generative model by returning a `NamedTuple`. -The `NamedTuple` version of `gen` is the most convenient for problem writers to implement. However, it should *never* be used directly by solvers or simulators. Instead solvers and simulators should use the version with a `DDNOut` first argument. +Solver and simulator writers should use the `@gen` macro to call a generative model. # Arguments - `m`: an `MDP` or `POMDP` model @@ -27,139 +12,30 @@ The `NamedTuple` version of `gen` is the most convenient for problem writers to - `rng`: a random number generator (Typically a `MersenneTwister`) # Return -The function should return a [`NamedTuple`](https://docs.julialang.org/en/v1/base/base/#Core.NamedTuple). Typically, this `NamedTuple` will be `(sp=, r=)` for an `MDP` or `(sp=, o=, r=) for a `POMDP`. - ---- - - gen(v::DDNNode{name}, m::Union{MDP,POMDP}, depargs..., rng::AbstractRNG) - -Sample a value from a node in the dynamic decision network. - -These functions will be used within gen(::DDNOut, ...) to sample values for all outputs and their dependencies. They may be implemented directly by a problem-writer if they wish to implement a generative model for a particular node in the dynamic decision network, and may be called in solvers to sample a value for a particular node. - -# Arguments -- `v::DDNNode{name}`: which DDN node the function should sample from. -- `depargs`: values for all the dependent nodes. Dependencies are determined by `deps(DDNStructure(m), name)`. -- `rng`: a random number generator (Typically a `MersenneTwister`) - -# Return -A sampled value from the specified node. - -# Examples -Let `m` be a `POMDP`, `s` and `sp` be states of `m`, `a` be an action of `m`, and `rng` be an `AbstractRNG`. 
-- `gen(DDNNode(:sp), m, s, a, rng)` returns the next state. -- `gen(DDNNode(:o), m, s, a, sp, rng)` returns the observation given the previous state, action, and new state. - ---- +The function should return a [`NamedTuple`](https://docs.julialang.org/en/v1/base/base/#Core.NamedTuple). With a subset of following entries: - gen(t::DDNOut{X}, m::Union{MDP,POMDP}, s, a, rng::AbstractRNG) where X +## MDP +- `sp`: the next state +- `r`: the reward for the step +- `info`: extra debugging information, typically in an associative container like a NamedTuple -Sample values from several nodes in the dynamic decision network. X is a symbol or tuple of symbols indicating which nodes to output. +## POMDP +- `sp`: the next state +- `o`: the observation +- `r`: the reward for the step +- `info`: extra debugging information, typically in an associative container like a NamedTuple -An implementation of this method is automatically provided by POMDPs.jl. Solvers and simulators should use this version. Problem writers may implement it directly in special cases (see the POMDPs.jl documentation for more information). - -# Arguments -- `t::DDNOut`: which DDN nodes the function should sample from. -- `m`: an `MDP` or `POMDP` model -- `s`: the current state -- `a`: the action -- `rng`: a random number generator (Typically a `MersenneTwister`) +Some elements can be left out. For instance if `o` is left out of the return, the problem-writer can also implement `observation` and POMDPs.jl will automatically use it when needed. -# Return -If the `DDNOut` parameter, `X`, is a symbol, return a value sample from the corresponding node. If `X` is a tuple of symbols, return a `Tuple` of values sampled from the specified nodes. +# Example +```julia +struct LQRMDP <: MDP{Float64, Float64} end -# Examples -Let `m` be an `MDP` or `POMDP`, `s` be a state of `m`, `a` be an action of `m`, and `rng` be an `AbstractRNG`. 
-- `gen(DDNOut(:sp, :r), m, s, a, rng)` returns a `Tuple` containing the next state and reward. -- `gen(DDNOut(:sp, :o, :r), m, s, a, rng)` returns a `Tuple` containing the next state, observation, and reward. -- `gen(DDNOut(:sp), m, s, a, rng)` returns the next state. +POMDPs.gen(m::LQRMDP, s, a, rng) = (sp = s + a + randn(rng), r = -s^2 - a^2) +``` """ function gen end -""" - initialstate(m::Union{POMDP,MDP}, rng::AbstractRNG) - -Return a sampled initial state for the problem `m`. - -Usually the initial state is sampled from an initial state distribution. The random number generator `rng` should be used to draw this sample (e.g. use `rand(rng)` instead of `rand()`). -""" -function initialstate end - -function implemented(f::typeof(initialstate), TT::Type) - if !hasmethod(f, TT) - return false - end - m = which(f, TT) - if m.module == POMDPs && !implemented(initialstate_distribution, Tuple{TT.parameters[1]}) - return false - else - return true - end -end - -@generated function initialstate(p::Union{POMDP,MDP}, rng) - impl = quote - d = initialstate_distribution(p) - return rand(rng, d) - end - - # it is technically illegal to call this within the generated function - if implemented(initialstate_distribution, Tuple{p}) - return impl - else - return quote - try - $impl # trick to get the compiler to insert the right backedges - catch - throw(MethodError(initialstate, (p, rng))) - end - end - end -end - -""" - initialobs(m::POMDP, s, rng::AbstractRNG) - -Return a sampled initial observation for the problem `m` and state `s`. - -This function is only used in cases where the policy expects an initial observation rather than an initial belief, e.g. in a reinforcement learning setting. It is not used in a standard POMDP simulation. - -By default, it will fall back to `observation(m, s)`. The random number generator `rng` should be used to draw this sample (e.g. use `rand(rng)` instead of `rand()`). 
-""" -function initialobs end - -function implemented(f::typeof(initialobs), TT::Type) - if !hasmethod(f, TT) - return false - end - m = which(f, TT) - if m.module == POMDPs && !implemented(observation, Tuple{TT.parameters[1:2]...}) - return false - else - return true - end -end - -@generated function initialobs(m::POMDP, s, rng) - impl = quote - d = observation(m, s) - return rand(rng, d) - end - - # it is technically illegal to call this within the generated function - if implemented(observation, Tuple{m, s}) - return impl - else - return quote - try - $impl # trick to get the compiler to insert the right backedges - catch - throw(MethodError(initialobs, (m, s, rng))) - end - end - end -end - """ @gen(X)(m, s, a) @gen(X)(m, s, a, rng::AbstractRNG) @@ -186,6 +62,6 @@ Let `m` be an `MDP` or `POMDP`, `s` be a state of `m`, `a` be an action of `m`, macro gen(symbols...) quote # this should be an anonymous function, but there is a bug (https://github.com/JuliaLang/julia/issues/36272) - f(m, s, a, rng=Random.GLOBAL_RNG) = gen(DDNOut($(symbols...)), m, s, a, rng) + f(m, s, a, rng=Random.GLOBAL_RNG) = genout(DDNOut($(symbols...)), m, s, a, rng) end end diff --git a/src/pomdp.jl b/src/pomdp.jl index 0ac60ccd..7629da1c 100644 --- a/src/pomdp.jl +++ b/src/pomdp.jl @@ -21,28 +21,32 @@ Abstract base type for a fully observable Markov decision process. abstract type MDP{S,A} end """ - discount(problem::POMDP) - discount(problem::MDP) + discount(m::POMDP) + discount(m::MDP) Return the discount factor for the problem. """ function discount end """ - transition(problem::POMDP, state, action) - transition(problem::MDP, state, action) + transition(m::POMDP, state, action) + transition(m::MDP, state, action) -Return the transition distribution from the current state-action pair +Return the transition distribution from the current state-action pair. 
+ +If it is difficult to define the probability density or mass function explicitly, consider using `POMDPModelTools.ImplicitDistribution` to define a generative model. """ function transition end """ - observation(problem::POMDP, statep) - observation(problem::POMDP, action, statep) - observation(problem::POMDP, state, action, statep) + observation(m::POMDP, statep) + observation(m::POMDP, action, statep) + observation(m::POMDP, state, action, statep) Return the observation distribution. You need only define the method with the fewest arguments needed to determine the observation distribution. +If it is difficult to define the probability density or mass function explicitly, consider using `POMDPModelTools.ImplicitDistribution` to define a generative model. + # Example ```julia using POMDPModelTools # for SparseCat @@ -54,21 +58,11 @@ observation(p::MyPOMDP, sp::Int) = SparseCat([sp-1, sp, sp+1], [0.1, 0.8, 0.1]) """ function observation end -""" - observation(problem::POMDP, action, statep) - -Return the observation distribution for the a-s' tuple (action and next state) -""" observation(problem::POMDP, a, sp) = observation(problem, sp) -@impl_dep observation(::P,::A,::S) where {P<:POMDP,S,A} observation(::P,::S) - -""" - observation(problem::POMDP, state, action, statep) +POMDPLinter.@impl_dep observation(::P,::A,::S) where {P<:POMDP,S,A} observation(::P,::S) -Return the observation distribution for the s-a-s' tuple (state, action, and next state) -""" observation(problem::POMDP, s, a, sp) = observation(problem, a, sp) -@impl_dep observation(::P,::S,::A,::S) where {P<:POMDP,S,A} observation(::P,::A,::S) +POMDPLinter.@impl_dep observation(::P,::S,::A,::S) where {P<:POMDP,S,A} observation(::P,::A,::S) """ reward(m::POMDP, s, a) @@ -90,10 +84,10 @@ For some problems, it is easier to express `reward(m, s, a, sp)` or `reward(m, s function reward end reward(m::Union{POMDP,MDP}, s, a, sp) = reward(m, s, a) -@impl_dep reward(::P,::S,::A,::S) where 
{P<:Union{POMDP,MDP},S,A} reward(::P,::S,::A) +POMDPLinter.@impl_dep reward(::P,::S,::A,::S) where {P<:Union{POMDP,MDP},S,A} reward(::P,::S,::A) reward(m::Union{POMDP,MDP}, s, a, sp, o) = reward(m, s, a, sp) -@impl_dep reward(::P,::S,::A,::S,::O) where {P<:Union{POMDP,MDP},S,A,O} reward(::P,::S,::A,::S) +POMDPLinter.@impl_dep reward(::P,::S,::A,::S,::O) where {P<:Union{POMDP,MDP},S,A,O} reward(::P,::S,::A,::S) """ isterminal(m::Union{MDP,POMDP}, s) @@ -105,12 +99,25 @@ If a state is terminal, no actions will be taken in it and no additional rewards isterminal(problem::Union{POMDP,MDP}, state) = false """ - initialstate_distribution(pomdp::POMDP) - initialstate_distribution(mdp::MDP) + initialstate(m::Union{POMDP,MDP}) -Return a distribution of the initial state of the pomdp or mdp. +Return a distribution of initial states for (PO)MDP `m`. + +If it is difficult to define the probability density or mass function explicitly, consider using `POMDPModelTools.ImplicitDistribution` to define a model for sampling. +""" +function initialstate end + +""" + initialobs(m::POMDP, s) + +Return a distribution of initial observations for POMDP `m` and state `s`. + +If it is difficult to define the probability density or mass function explicitly, consider using `POMDPModelTools.ImplicitDistribution` to define a model for sampling. + +This function is only used in cases where the policy expects an initial observation rather than an initial belief, e.g. in a reinforcement learning setting. It is not used in a standard POMDP simulation. """ -function initialstate_distribution end +function initialobs end + """ stateindex(problem::POMDP, s) diff --git a/src/requirements_interface.jl b/src/requirements_interface.jl deleted file mode 100644 index 7142d5f6..00000000 --- a/src/requirements_interface.jl +++ /dev/null @@ -1,206 +0,0 @@ -""" - implemented(function, Tuple{Arg1Type, Arg2Type}) - -Check whether there is an implementation available that will return a suitable value. 
-""" -implemented(f::Function, TT::TupleType) = hasmethod(f, TT) -implemented(req::Req) = implemented(first(req), last(req)) - -""" - @implemented function(::Arg1Type, ::Arg2Type) - -Check whether there is an implementation available that will return a suitable value. -""" -macro implemented(ex) - tplex = convert_req(ex) - return quote - implemented($(esc(tplex))...) - end -end - -""" - get_requirements(f::Function, args::Tuple) - -Return a RequirementSet for the function f and arguments args. -""" -get_requirements(f::Function, args::Tuple) = Unspecified((f, typeof(args))) - - -""" - @get_requirements f(arg1, arg2) - -Call get_requirements(f, (arg1,arg2)). -""" -macro get_requirements(call) - return quote get_requirements($(esc(convert_call(call)))...) end -end - - -""" - @POMDP_require solve(s::CoolSolver, p::POMDP) begin - PType = typeof(p) - @req states(::PType) - @req actions(::PType) - @req transition(::PType, ::S, ::A) - s = first(states(p)) - a = first(actions(p)) - t_dist = transition(p, s, a) - @req rand(::AbstractRNG, ::typeof(t_dist)) - end - -Create a get_requirements implementation for the function signature and the requirements block. 
-""" -macro POMDP_require(typedcall, block) - fname, args, types = unpack_typedcall(typedcall) - tconstr = Expr[:($(Symbol(:T,i))<:$(esc(C))) for (i,C) in enumerate(types)] # oh snap - ts = Symbol[Symbol(:T,i) for i in 1:length(types)] - req_spec = :(($fname, Tuple{$(types...)})) - fimpl = quote - function POMDPs.get_requirements(f::typeof($(esc(fname))), args::Tuple{$(ts...)}) where {$(tconstr...)} # dang - ($([esc(a) for a in args]...),) = args # whoah - return $(pomdp_requirements(req_spec, block)) - end - end - return fimpl -end - -""" - reqs = @POMDP_requirements CoolSolver begin - PType = typeof(p) - @req states(::PType) - @req actions(::PType) - @req transition(::PType, ::S, ::A) - s = first(states(p)) - a = first(actions(p)) - t_dist = transition(p, s, a) - @req rand(::AbstractRNG, ::typeof(t_dist)) - end - -Create a RequirementSet object. -""" -macro POMDP_requirements(name, block) - return pomdp_requirements(name, block) -end - - -""" - @warn_requirements solve(solver, problem) - -Print a warning if there are unmet requirements. -""" -macro warn_requirements(call::Expr) - quote - reqs = get_requirements($(esc(convert_call(call)))...) - c = check_requirements(reqs) - if !ismissing(c) && c == false - show_requirements(reqs) - end - end -end - - -""" - @show_requirements solve(solver, problem) - -Print a a list of requirements for a function call. -""" -macro show_requirements(call::Expr) - quote - reqs = get_requirements($(esc(convert_call(call)))...) - show_requirements(reqs) - end -end - -""" - @requirements_info ASolver() [YourPOMDP()] - -Print information about the requirements for a solver. -""" -macro requirements_info(exprs...) - quote - requirements_info($([esc(ex) for ex in exprs]...)) - end -end - -""" - requirements_info(s::Solver, p::Union{POMDP,MDP}, ...) - -Print information about the requirement for solver s. 
-""" -function requirements_info(s::Union{Solver,Simulator}) - stype = typeof(s) - try - stype = stype.name.name - catch ex - # do nothing - end - println("""Please supply a POMDP as a second argument to requirements_info. - e.g. `@requirements_info $(stype)() YourPOMDP()` - """) -end -function requirements_info(s::Union{Solver,Simulator}, p::Union{POMDP,MDP}, args...) - reqs = get_requirements(solve, (s, p)) - show_requirements(reqs) -end - -""" - @req f( ::T1, ::T2) - -Convert a `f( ::T1, ::T2)` expression to a `(f, Tuple{T1,T2})::Req` for pushing to a `RequirementSet`. - -If in a `@POMDP_requirements` or `@POMDP_require` block, marks the requirement for including in the set of requirements. -""" -macro req(ex) - return esc(convert_req(ex)) -end - -""" - @subreq f(arg1, arg2) - -In a `@POMDP_requirements` or `@POMDP_require` block, include the requirements for `f(arg1, arg2)` as a child argument set. -""" -macro subreq(ex) - return quote - get_requirements($(esc(convert_call(ex)))...) - end -end - -""" - check_requirements(r::AbstractRequirementSet) - -Check whether the methods in `r` have implementations with `implemented()`. Return true if all methods have implementations. -""" -function check_requirements(r::AbstractRequirementSet) - analyzed = Set() - return recursively_check(r, analyzed) -end - -""" - show_requirements(r::AbstractRequirementSet) - -Check whether the methods in `r` have implementations with `implemented()` and print out a formatted list showing which are missing. Return true if all methods have implementations. -""" -function show_requirements(r::AbstractRequirementSet) - buf = stdout - reported = Set{Req}() - analyzed = Set() - - show_heading(buf, r.requirer) - println(buf) - - allthere, first_exception = recursively_show(buf, r, analyzed, reported) - - if ismissing(allthere) || !allthere - println("Note: Missing methods are often due to incorrect importing. 
You must explicitly import POMDPs functions to add new methods.") - println() - end - - if first_exception != nothing - print("Throwing the first exception (from processing ") - printstyled(handle_method(first_exception.requirer), color=:blue) - println(" requirements):\n") - throw(first_exception.exception) - end - - return allthere -end diff --git a/src/requirements_internals.jl b/src/requirements_internals.jl deleted file mode 100644 index a72045db..00000000 --- a/src/requirements_internals.jl +++ /dev/null @@ -1,331 +0,0 @@ -const TupleType = Type # should be Tuple{T1,T2,...} -const Req = Tuple{Function, TupleType} - -abstract type AbstractRequirementSet end - -mutable struct Unspecified <: AbstractRequirementSet - requirer - parent::Union{Nothing, Any} -end - -Unspecified(requirer) = Unspecified(requirer, nothing) - -mutable struct RequirementSet <: AbstractRequirementSet - requirer - reqs::Vector{Req} # not actually a set - to preserve intuitive ordering - deps::Vector{AbstractRequirementSet} - parent::Union{Nothing, Any} - exception::Union{Nothing, Exception} -end - -function RequirementSet(requirer, parent=nothing) - return RequirementSet(requirer, - Vector{Tuple{Function, TupleType}}(), - AbstractRequirementSet[], - parent, - nothing) -end - -Base.push!(r::RequirementSet, func::Function, argtypes::TupleType) = push!(r, (func, argtypes)) -Base.push!(r::RequirementSet, t::Tuple{Function, TupleType}) = push!(r.reqs, t) - -function push_dep!(r::RequirementSet, dep::AbstractRequirementSet) - dep.parent = r.requirer - push!(r.deps, dep) -end - -""" -Return an expression that creates a RequirementSet using the code in the block. The resulting code will *always* return a RequirementSet, but it may be incomplete if the exception field is not null. 
-""" -function pomdp_requirements(name::Union{Expr,String}, block::Expr) - block = deepcopy(block) - req_found = handle_reqs!(block, :reqs) - if !req_found - block = esc(block) - @warn("No @req or @subreq found in @POMDP_requirements block.") - end - - newblock = quote - reqs = RequirementSet($(esc(name))) - try - $block - catch exception - reqs.exception = exception - end - reqs - end - return newblock -end - -const CheckedList = Vector{Tuple{Union{Bool,Missing}, Function, TupleType}} - - -""" -Return a `(f, Tuple{T1,T2})::Req` expression given a `f( ::T1, ::T2)` expression. -""" -function convert_req(ex::Expr) - malformed = false - if ex.head == :call - func = ex.args[1] - argtypes = Union{Symbol, Expr}[] - for a in ex.args[2:end] - if isa(a, Expr) - if a.head == :(::) - if length(a.args) == 1 - push!(argtypes, a.args[1]) - elseif length(a.args) == 2 - push!(argtypes, a.args[2]) - else - malformed = true - break - end - else - malformed = true - break - end - else - push!(argtypes, :(typeof($a))) - end - end - else - malformed = true - end - if malformed # throw error at parse time so solver writers will have to deal with this - error(""" - Malformed requirement expression: $ex - Requirements should be expressed in the form `function_name(::Type1, ::Type2)` or `function_name(arg1, arg2)`. 
- """) - else - return quote ($func, Tuple{$(argtypes...)}) end - end -end - -function recursively_show(io::IO, - r::RequirementSet, - analyzed::Set, - reported::Set{Req}) - if r.requirer in analyzed - return true - end - - push!(analyzed, r.requirer) - - checked = CheckedList() - allthere = true - for fp in r.reqs - if !(fp in reported) - push!(reported, fp) - exists = implemented(first(fp), last(fp)) - allthere = exists & allthere - push!(checked, (exists, first(fp), last(fp))) - end - end - - show_requirer(io, r) - if isempty(checked) - println(io, " [No additional requirements]") - else - show_checked_list(io, checked) - end - - if r.exception == nothing # no exception - first_exception = nothing - else - allthere = false - show_incomplete(io, r) - first_exception = r - end - - for dep in r.deps - depcomplete, depexception = recursively_show(io, dep, analyzed, reported) - allthere = allthere & depcomplete - if first_exception == nothing && depexception != nothing - first_exception = depexception - end - end - - return allthere, first_exception -end - -function recursively_show(io::IO, r::Unspecified, analyzed::Set, reported::Set{Req}) - if r.requirer in analyzed - return true, nothing - else - push!(analyzed, r.requirer) - show_requirer(io::IO, r) - println(io, " [No requirements specified]") - return true, nothing - end -end - - -function recursively_check(r::RequirementSet, analyzed::Set) - if r.requirer in analyzed - return true - end - - push!(analyzed, r.requirer) - - allthere = r.exception == nothing - for fp in r.reqs - allthere = allthere & implemented(first(fp), last(fp)) - end - - for dep in r.deps - allthere = allthere & recursively_check(dep, analyzed) - end - - return allthere -end - -function recursively_check(r::Unspecified, analyzed::Set) - push!(analyzed, r.requirer) - return true -end - -""" -Return a tuple (not an Expr) of the function name, arguments, and argument types. - -E.g. 
`f(arg1::T1, arg2::T2)` would be unpacked to (:f, [:arg1, :arg2], [:T1, :T2]) -""" -function unpack_typedcall(typedcall::Expr) - malformed = false - if typedcall.head != :call - malformed = true - end - - args = Union{Symbol,Expr}[] - types = Union{Symbol,Expr}[] - for expr in typedcall.args[2:end] - if isa(expr,Expr) && expr.head == :(::) - push!(args, expr.args[1]) - push!(types, expr.args[2]) - elseif isa(expr,Symbol) - push!(args, expr) - push!(types, :Any) - else - malformed = true - end - end - - if malformed - error(""" - Malformed typed funciton call expression: $typedcall - Expected the form `function_name(arg1::Type1, arg2::Type2)`. - """) - end - - return (typedcall.args[1], args, types) -end - -""" -Return a `(f, (arg1,arg2))` expression given a `f(arg1, arg2)` expression. -""" -function convert_call(call::Expr) - malformed = false - if call.head == :call - func = call.args[1] - args = Union{Symbol, Expr}[] - for a in call.args[2:end] - if isa(a, Expr) && a.head == :(::) - @assert length(args) == 2 - push!(args, a.args[1]) - else - push!(args, a) - end - end - else - malformed = true - end - if malformed # throw error at parse time so solver writers will have to deal with this - error(""" - Malformed call expression: $call - Expected the form `funcion_name(arg1, arg2)` - """) - else - return quote ($func, ($(args...),)) end - end -end - - -# this is where the freaking magic happens. -""" - handle_reqs!(block, reqs_name::Symbol) - -Replace any @req calls with `push!(\$reqs_name, )` - -Returns true if there was a requirement in there and so should not be escaped. 
-""" -function handle_reqs!(node::Expr, reqs_name::Symbol) - - if node.head == :macrocall && node.args[1] == Symbol("@req") - macro_node = copy(node) - node.head = :call - expanded = macroexpand(POMDPs, macro_node) - if isa(expanded, Expr) && expanded.head == :error - throw(expanded.args[1]) - end - node.args = [:push!, reqs_name, esc(expanded)] - return true - elseif node.head == :macrocall && node.args[1] == Symbol("@subreq") - macro_node = copy(node) - node.head = :call - expanded = macroexpand(POMDPs, macro_node) - if isa(expanded, Expr) && expanded.head == :error - throw(expanded.args[1]) - end - node.args = [:push_dep!, reqs_name, esc(macroexpand(POMDPs, expanded))] - return true - else - found = falses(length(node.args)) - for (i, arg) in enumerate(node.args) - found[i] = handle_reqs!(arg, reqs_name) - end - if any(found) - for i in 1:length(node.args) - if !found[i] # && !(isa(node.args[i], Expr) && node.args[i].head == :line) # this would not escape lines (I don't know what implications that has) - node.args[i] = esc(node.args[i]) - end - end - end - return any(found) - end -end - -function handle_reqs!(node::Any, reqs_name::Symbol) - # for anything that's not an Expr - return false -end - -""" - @impl_dep reward(::P,::S,::A,::S) where {P<:POMDP,S,A} reward(::P,::S,::A) - -Declare an implementation dependency and automatically implement `implemented`. 
- -In the example above, `@implemented reward(::P,::S,::A,::S)` will return true if the user has implemented `reward(::P,::S,::A,::S)` OR `reward(::P,::S,::A)` - -THIS IS ONLY INTENDED FOR USE INSIDE POMDPs AND MAY NOT FUNCTION CORRECTLY ELSEWHERE -""" -macro impl_dep(signature, dependency) - if signature.head == :where - sig_req = signature.args[1] - wheres = signature.args[2:end] - else - sig_req = signature - wheres = () - end - tplex = convert_req(sig_req) - deptplex = convert_req(dependency) - impled = quote - function implemented(f::typeof(first($tplex)), TT::Type{last($tplex)}) where {$(wheres...)} - m = which(f,TT) - if m.module == POMDPs && !implemented($deptplex...) - return false - else # a more specific implementation exists - return true - end - return false - end - end - return esc(impled) -end diff --git a/src/requirements_printing.jl b/src/requirements_printing.jl deleted file mode 100644 index 3ae1cb02..00000000 --- a/src/requirements_printing.jl +++ /dev/null @@ -1,99 +0,0 @@ -function show_heading(io::IO, requirer) - print(io, "INFO: POMDPs.jl requirements for ") - printstyled(io, handle_method(requirer), color=:blue) - println(io, " and dependencies. ([✔] = implemented correctly; [X] = not implemented; [?] = could not determine)") -end - -function show_requirer(io::IO, r::AbstractRequirementSet) - print(io, "For ") - printstyled(io, "$(handle_method(r.requirer))", color=:blue) - if r.parent == nothing - println(io, ":") - else - println(io, " (in $(handle_method(r.parent))):") - end -end - -function show_checked_list(io::IO, cl::AbstractVector{T}) where T <: Tuple - for item in cl - if ismissing(first(item)) - printstyled(io, " [?] 
$(format_method(item[2], item[3]))", color=:yellow) - println(io) - elseif first(item) == true - printstyled(io, " [✔] $(format_method(item[2], item[3]))", color=:green) - println(io) - else - @assert first(item) == false - printstyled(io, " [X] $(format_method(item[2], item[3]))", color=:red) - println(io) - end - end -end - -function show_incomplete(io, r::RequirementSet) - @assert r.exception != nothing - extype = typeof(r.exception) - printstyled(io, " WARNING: Some requirements may not be shown because a $(extype) was thrown.", color=:yellow) - println(io) -end - -handle_method(str::Any) = string(str) -handle_method(str::Req) = format_method(str...) -short_method(str::Any) = string(str) -short_method(str::Req) = string(first(str)) - -function format_method(f::Function, argtypes::TupleType; module_names=false, color=nothing) - fname = f - typenames = argtypes.parameters - if !module_names - fname = typeof(f).name.mt.name - mless_typenames = [] - for t in argtypes.parameters - if isa(t, Union) - str = "Union{" - for (i, tt) in enumerate(fieldnames(typeof(t))) - str = string(str, getfield(t, tt), i= v"1.4.0" - Pkg.pkg"registry add https://github.com/JuliaPOMDP/Registry" - else - url = "https://github.com/JuliaPOMDP/Registry" - depot = Pkg.depots1() - # clone to temp dir first - tmp = mktempdir() - Base.shred!(LibGit2.CachedCredentials()) do creds - LibGit2.with(Pkg.GitTools.clone(url, tmp; header = "registry from $(repr(url))", credentials = creds)) do repo - end - end - # verify that the clone looks like a registry - if !isfile(joinpath(tmp, "Registry.toml")) - Pkg.Types.pkgerror("no `Registry.toml` file in cloned registry") - end - - registry = read_registry(joinpath(tmp, "Registry.toml")) - verify_registry(registry) - - # copy to depot - regpath = joinpath(depot, "registries", registry["name"]) - ispath(dirname(regpath)) || mkpath(dirname(regpath)) - if Pkg.Types.isdir_windows_workaround(regpath) - existing_registry = read_registry(joinpath(regpath, 
"Registry.toml")) - @assert registry["uuid"] == existing_registry["uuid"] - @info("registry `$(registry["name"])` already exists in `$(Base.contractuser(dirname(regpath)))`") - else - cp(tmp, regpath) - Pkg.Types.printpkgstyle(stdout, :Added, "registry `$(registry["name"])` to `$(Base.contractuser(dirname(regpath)))`") - end - end -end - -function read_registry(regfile) - registry = Pkg.TOML.parsefile(regfile) - return registry -end - -const REQUIRED_REGISTRY_ENTRIES = ("name", "uuid", "repo", "packages") - -function verify_registry(registry::Dict{String, Any}) - for key in REQUIRED_REGISTRY_ENTRIES - haskey(registry, key) || Pkg.Types.pkgerror("no `$key` entry in `Registry.toml`.") - end -end diff --git a/test/runtests.jl b/test/runtests.jl index 9bb708b5..02c73fd9 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -3,48 +3,10 @@ using Test using POMDPs using Random -POMDPs.logger_context(::Test.TestLogger) = IOContext(stderr) - -mightbemissing(x) = ismissing(x) || x - -# using Logging -# global_logger(ConsoleLogger(stderr, Logging.Debug)) - -mutable struct A <: POMDP{Int,Bool,Bool} end -@testset "implement" begin - - @test_throws MethodError length(states(A())) - @test_throws MethodError stateindex(A(), 1) - - @test !@implemented transition(::A, ::Int, ::Bool) - POMDPs.transition(::A, s, a) = [s+a] - @test @implemented transition(::A, ::Int, ::Bool) - - @test !@implemented discount(::A) - POMDPs.discount(::A) = 0.95 - @test @implemented discount(::A) - - @test !@implemented reward(::A,::Int,::Bool,::Int) - @test !@implemented reward(::A,::Int,::Bool) - POMDPs.reward(::A,::Int,::Bool) = -1.0 - @test @implemented reward(::A,::Int,::Bool,::Int) - @test @implemented reward(::A,::Int,::Bool) - - @test !@implemented observation(::A,::Int,::Bool,::Int) - @test !@implemented observation(::A,::Bool,::Int) - POMDPs.observation(::A,::Bool,::Int) = [true, false] - @test @implemented observation(::A,::Int,::Bool,::Int) - @test @implemented observation(::A,::Bool,::Int) 
-end - @testset "infer" begin include("test_inferrence.jl") end -@testset "require" begin - include("test_requirements.jl") -end - @testset "generative" begin include("test_generative.jl") end @@ -53,14 +15,6 @@ end include("test_generative_backedges.jl") end -@testset "ddn_struct" begin - include("test_ddn_struct.jl") -end - -@testset "gendep" begin - include("test_deprecated_generative.jl") -end - struct CI <: POMDP{Int,Int,Int} end struct CV <: POMDP{Vector{Float64},Vector{Float64},Vector{Float64}} end @@ -82,10 +36,7 @@ struct CV <: POMDP{Vector{Float64},Vector{Float64},Vector{Float64}} end end struct EA <: POMDP{Int, Int, Int} end -@testset "error" begin - @test_throws MethodError transition(EA(), 1, 2) - @test_throws DistributionNotImplemented gen(DDNOut(:sp), EA(), 1, 2, Random.GLOBAL_RNG) -end +struct EB <: POMDP{Int, Int, Int} end @testset "history" begin POMDPs.history(i::Int) = [(o=i,)] @@ -93,4 +44,36 @@ end @test currentobs(4) == 4 end -POMDPs.add_registry() +@testset "deprecated" begin + + POMDPs.add_registry() + + @test !@implemented transition(::EA, ::Int, ::Int) + POMDPs.transition(::EA, ::Int, ::Int) = [0] + @test @implemented transition(::EA, ::Int, ::Int) + + @POMDP_require solve(a::Int, b::Int) begin + @req transition(::EA, ::Int, ::Int) + end + @POMDP_requirements Int begin end + @requirements_info Int + a = 1 + b = 2 + @get_requirements solve(a, b) + @show_requirements solve(a, b) + @warn_requirements solve(a, b) + + @test_throws ErrorException @req + @test_throws ErrorException @subreq + + @test gen(DDNOut(:sp), EA(), 1, 1, MersenneTwister(3)) == 0 + @test_throws MethodError @gen(:sp,:o)(EA(), 1, true, MersenneTwister(4)) + + POMDPs.initialstate(::EA) = [1,2,3] + @test (@test_deprecated initialstate_distribution(EA())) == initialstate(EA()) + @test (@test_deprecated initialstate(EA(), Random.GLOBAL_RNG)) in initialstate(EA()) + + @test_throws MethodError initialstate(EB()) + POMDPs.initialstate_distribution(m::EB) = [1] + @test 
initialstate(EB()) == [1] +end diff --git a/test/test_ddn_struct.jl b/test/test_ddn_struct.jl deleted file mode 100644 index db67cae7..00000000 --- a/test/test_ddn_struct.jl +++ /dev/null @@ -1,51 +0,0 @@ -struct DDNA <: MDP{Int, Int} end -ddn = DDNStructure(DDNA) - -@test all(v in Set((:s, :a, :r, :sp)) for v in nodenames(ddn)) -ns = Set(nodenames(ddn)) -@test all(v in ns for v in [:s, :a, :r, :sp]) - -struct DDNB <: POMDP{Int, Int, Int} end -ddn = DDNStructure(DDNB) - -@test all(v in Set((:s, :a, :r, :sp, :o)) for v in nodenames(ddn)) -ns = Set(nodenames(ddn)) -@test all(v in ns for v in [:s, :a, :r, :sp, :o]) - -@test Set(nodenames(ddn)) == Set(nodenames(typeof(ddn))) - -ddn = DDNStructure(DDNB) -@test node(ddn, :sp) == DistributionDDNNode(transition) -@test Set(depvars(ddn, :sp)) == Set((DDNNode(:s), DDNNode(:a))) -@test Set(depnames(ddn, :sp)) == Set((:s, :a)) - -module InfoModule - using POMDPs - export add_infonode - - add_infonode(ddn) = add_node(ddn, :info, ConstantDDNNode(nothing), (:s, :a)) -end - -using Main.InfoModule -struct DDNC <: POMDP{Int, Int, Int} end -POMDPs.DDNStructure(::Type{DDNC}) = pomdp_ddn() |> add_infonode -@test gen(DDNNode(:info), DDNC(), 1, 1, Random.GLOBAL_RNG) == nothing -@test gen(DDNOut(:info), DDNC(), 1, 1, Random.GLOBAL_RNG) == nothing - -# Example from DDNStructure docstring -struct MyMDP <: MDP{Int, Int} end -POMDPs.gen(::MyMDP, s, a, rng) = (sp=s+a+rand(rng, [1,2,3]), r=s^2) - -# make a new node delta_s that is deterministically sp-s -function POMDPs.DDNStructure(::Type{MyMDP}) - ddn = mdp_ddn() - return add_node(ddn, :delta_s, FunctionDDNNode((m,s,sp)->sp-s), (:s, :sp)) -end - -@test gen(DDNOut(:delta_s), MyMDP(), 1, 1, Random.GLOBAL_RNG) in [2, 3, 4] - -struct DDND <: MDP{Int, Int} end -POMDPs.DDNStructure(::Type{DDND}) = add_node(mdp_ddn(), :x, GenericDDNNode(), (:s, :a)) -@test_throws ErrorException gen(DDNNode(:x), DDND(), 1, 1, Random.GLOBAL_RNG) -POMDPs.gen(::DDNNode{:x}, ::DDND, s, a, rng) = s*a -@test 
gen(DDNOut(:x), DDND(), 1, 1, Random.GLOBAL_RNG) == 1 diff --git a/test/test_deprecated_generative.jl b/test/test_deprecated_generative.jl deleted file mode 100644 index 9b94e91c..00000000 --- a/test/test_deprecated_generative.jl +++ /dev/null @@ -1,36 +0,0 @@ -struct DW <: POMDP{Int, Bool, Int} end -@test_deprecated @test_throws DistributionNotImplemented generate_s(DW(), 1, true, Random.GLOBAL_RNG) -@test_deprecated @test_throws DistributionNotImplemented generate_sr(DW(), 1, true, Random.GLOBAL_RNG) -@test_deprecated @test_throws DistributionNotImplemented generate_o(DW(), 1, true, 2, Random.GLOBAL_RNG) -@test_deprecated @test_throws DistributionNotImplemented generate_so(DW(), 1, true, Random.GLOBAL_RNG) -@test_deprecated @test_throws DistributionNotImplemented generate_sor(DW(), 1, true, Random.GLOBAL_RNG) -@test_throws ErrorException generate_or(DW(), 1, true, 2, Random.GLOBAL_RNG) - -struct DB <: POMDP{Int, Bool, Bool} end - -POMDPs.transition(b::DB, s::Int, a::Bool) = Int[s+a] -@test implemented(generate_s, Tuple{DB, Int, Bool, MersenneTwister}) -@test generate_s(DB(), 1, false, Random.GLOBAL_RNG) == 1 - -@test mightbemissing(!@implemented generate_sor(::DB, ::Int, ::Bool, ::MersenneTwister)) -# don't run this test because it will compile gen(::DDNNode{:o},...) 
and I don't want to deal with the backedges -# @test_throws MethodError generate_sor(DB(), 1, false, Random.GLOBAL_RNG) - -POMDPs.reward(b::DB, s::Int, a::Bool, sp::Int) = -1.0 -POMDPs.generate_o(b::DB, s::Int, a::Bool, sp::Int, rng::AbstractRNG) = sp -@test @implemented generate_o(b::DB, s::Int, a::Bool, sp::Int, rng::MersenneTwister) -@test generate_sr(DB(), 1, false, Random.GLOBAL_RNG) == (1, -1.0) -@test mightbemissing(@implemented generate_so(b::DB, s::Int, a::Bool, rng::MersenneTwister)) -@test mightbemissing(@implemented generate_sor(b::DB, s::Int, a::Bool, rng::MersenneTwister)) -@test generate_sor(DB(), 1, true, Random.GLOBAL_RNG) == (2, 2, -1.0) - -# to exercise deprecation warning -struct DC <: POMDP{Nothing, Nothing, Nothing} end -POMDPs.generate_s(c::DC, s::Nothing, a::Nothing, rng::AbstractRNG) = nothing -@test gen(DDNNode(:sp), DC(), nothing, nothing, Random.GLOBAL_RNG) == nothing -@test gen(DDNOut(:sp), DC(), nothing, nothing, Random.GLOBAL_RNG) == nothing - -# test whether implemented gets DDNOut versions -struct DD <: MDP{Nothing, Nothing} end -POMDPs.generate_sr(m::DD, s, a, rng) = nothing -@test @implemented gen(::DDNOut{(:sp, :r)}, ::DD, ::Nothing, ::Nothing, ::MersenneTwister) diff --git a/test/test_generative.jl b/test/test_generative.jl index e58bc447..010cacff 100644 --- a/test/test_generative.jl +++ b/test/test_generative.jl @@ -1,103 +1,67 @@ -import POMDPs: transition, reward, initialstate_distribution +import POMDPs: transition, observation, reward, initialstate_distribution import POMDPs: gen -macro inferred_in_latest(expr) - if VERSION >= v"1.1" - return :(@inferred($expr)) - else - return expr - end +struct Deterministic{T} + x::T end +Base.rand(rng::AbstractRNG, d::Deterministic) = d.x struct W <: POMDP{Int, Bool, Int} end -@test !@implemented initialstate(::W, ::typeof(Random.GLOBAL_RNG)) -@test !@implemented initialstate(::W, ::typeof(Random.GLOBAL_RNG), ::Nothing) # wrong number args -@test_throws MethodError initialstate(W(), 
Random.GLOBAL_RNG) -@test !@implemented initialobs(::W, ::Int, ::typeof(Random.GLOBAL_RNG)) -@test !@implemented initialobs(::W, ::Int, ::typeof(Random.GLOBAL_RNG), ::Nothing) # wrong number args -@test_throws MethodError initialobs(W(), 1, Random.GLOBAL_RNG) -@test_throws DistributionNotImplemented gen(DDNNode(:sp), W(), 1, true, Random.GLOBAL_RNG) +@test_throws MethodError initialstate(W()) +@test_throws MethodError initialobs(W(), 1) try - gen(DDNNode(:sp), W(), 1, true, Random.GLOBAL_RNG) + @gen(:sp)(W(), 1, true, Random.GLOBAL_RNG) catch ex str = sprint(showerror, ex) - @test occursin(":sp", str) @test occursin("transition", str) end -@test_throws DistributionNotImplemented gen(DDNOut(:sp,:r), W(), 1, true, Random.GLOBAL_RNG) -@test_throws DistributionNotImplemented @gen(:sp,:r)(W(), 1, true, Random.GLOBAL_RNG) -@test_throws DistributionNotImplemented gen(DDNNode(:o), W(), 1, true, 2, Random.GLOBAL_RNG) +@test_throws MethodError @gen(:sp,:r)(W(), 1, true, Random.GLOBAL_RNG) + +@test_throws MethodError @gen(:sp,:o)(W(), 1, true, Random.GLOBAL_RNG) +@test_throws MethodError @gen(:sp,:o,:r)(W(), 1, true, Random.GLOBAL_RNG) +POMDPs.gen(::W, ::Int, ::Bool, ::AbstractRNG) = nothing +@test_throws AssertionError @gen(:sp)(W(), 1, true, Random.GLOBAL_RNG) +@test_throws AssertionError @gen(:sp,:r)(W(), 1, true, Random.GLOBAL_RNG) +POMDPs.gen(::W, ::Int, ::Bool, ::AbstractRNG) = (useless=nothing,) +@test_throws MethodError @gen(:sp,:r)(W(), 1, true, Random.GLOBAL_RNG) + +transition(::W, s, a) = Deterministic(s) +@test_throws MethodError @gen(:o)(W(), 1, true, Random.GLOBAL_RNG) try - gen(DDNNode(:o), W(), 1, true, 2, Random.GLOBAL_RNG) + @gen(:o)(W(), 1, true, Random.GLOBAL_RNG) catch ex str = sprint(showerror, ex) - @test occursin(":o", str) @test occursin("observation", str) end -@test_throws DistributionNotImplemented gen(DDNOut(:sp,:o), W(), 1, true, Random.GLOBAL_RNG) -@test_throws DistributionNotImplemented @gen(:sp,:o)(W(), 1, true, Random.GLOBAL_RNG) 
-@test_throws DistributionNotImplemented gen(DDNOut(:sp,:o,:r), W(), 1, true, Random.GLOBAL_RNG) -@test_throws DistributionNotImplemented @gen(:sp,:o,:r)(W(), 1, true, Random.GLOBAL_RNG) -POMDPs.gen(::W, ::Int, ::Bool, ::AbstractRNG) = nothing -@test_throws AssertionError gen(DDNOut(:sp), W(), 1, true, Random.GLOBAL_RNG) -@test_throws AssertionError @gen(:sp)(W(), 1, true, Random.GLOBAL_RNG) -@test_throws AssertionError gen(DDNOut(:sp,:r), W(), 1, true, Random.GLOBAL_RNG) -@test_throws AssertionError @gen(:sp,:r)(W(), 1, true, Random.GLOBAL_RNG) -POMDPs.gen(::W, ::Int, ::Bool, ::AbstractRNG) = (useless=nothing,) -@test_throws DistributionNotImplemented gen(DDNNode(:sp), W(), 1, true, Random.GLOBAL_RNG) -@test_throws DistributionNotImplemented gen(DDNOut(:sp,:r), W(), 1, true, Random.GLOBAL_RNG) -@test_throws DistributionNotImplemented @gen(:sp,:r)(W(), 1, true, Random.GLOBAL_RNG) struct B <: POMDP{Int, Bool, Bool} end -transition(b::B, s::Int, a::Bool) = Int[s+a] -@test implemented(gen, Tuple{DDNNode{:sp}, B, Int, Bool, MersenneTwister}) -@test @inferred_in_latest(gen(DDNNode(:sp), B(), 1, false, Random.GLOBAL_RNG)) == 1 +transition(b::B, s::Int, a::Bool) = Deterministic(s+a) +@test @inferred(@gen(:sp)(B(), 1, false, Random.GLOBAL_RNG)) == 1 -@test mightbemissing(@implemented(gen(::DDNOut{(:sp,:o,:r)}, ::B, ::Int, ::Bool, ::MersenneTwister))) -@test_throws DistributionNotImplemented gen(DDNOut(:sp,:o,:r), B(), 1, false, Random.GLOBAL_RNG) +@test_throws MethodError @gen(:sp,:o,:r)(B(), 1, false, Random.GLOBAL_RNG) reward(b::B, s::Int, a::Bool, sp::Int) = -1.0 -gen(::DDNNode{:o}, b::B, s::Int, a::Bool, sp::Int, rng::AbstractRNG) = sp -@test @inferred_in_latest(gen(DDNOut(:sp,:r), B(), 1, false, Random.GLOBAL_RNG)) == (1, -1.0) -@test @inferred_in_latest(@gen(:sp,:r)(B(), 1, false, Random.GLOBAL_RNG)) == (1, -1.0) - -@test @implemented gen(::DDNNode{:o}, b::B, s::Int, a::Bool, sp::Int, rng::AbstractRNG) -@test mightbemissing(@implemented(gen(::DDNOut{(:sp,:o)}, b::B, 
s::Int, a::Bool, rng::MersenneTwister))) -@test mightbemissing(@implemented gen(::DDNOut{(:sp,:o,:r)}, b::B, s::Int, a::Bool, rng::MersenneTwister)) -@test @inferred_in_latest(gen(DDNOut(:sp,:o,:r), B(), 1, true, Random.GLOBAL_RNG)) == (2, 2, -1.0) -@test @inferred_in_latest(@gen(:sp,:o,:r)(B(), 1, true, Random.GLOBAL_RNG)) == (2, 2, -1.0) +observation(b::B, s::Int, a::Bool, sp::Int) = Deterministic(sp) +@test @inferred(@gen(:sp,:r)(B(), 1, false, Random.GLOBAL_RNG)) == (1, -1.0) -initialstate_distribution(b::B) = Int[1,2,3] -@test @implemented initialstate(::B, ::MersenneTwister) -@test initialstate(B(), Random.GLOBAL_RNG) in initialstate_distribution(B()) -POMDPs.observation(b::B, s::Int) = Bool[s] -@test @implemented initialobs(::B, ::Int, ::MersenneTwister) -@test initialobs(B(), 1, Random.GLOBAL_RNG) == 1 +@test @inferred(@gen(:sp,:o,:r)(B(), 1, true, Random.GLOBAL_RNG)) == (2, 2, -1.0) mutable struct C <: POMDP{Nothing, Nothing, Nothing} end -gen(::DDNNode{:sp}, c::C, s::Nothing, a::Nothing, rng::AbstractRNG) = nothing -gen(::DDNNode{:o}, c::C, s::Nothing, a::Nothing, sp::Nothing, rng::AbstractRNG) = nothing +transition(c::C, s::Nothing, a::Nothing) = Deterministic(nothing) +observation(c::C, s::Nothing, a::Nothing, sp::Nothing) = Deterministic(nothing) reward(c::C, s::Nothing, a::Nothing) = 0.0 -@test mightbemissing(@implemented gen(::DDNOut{(:sp,:o,:r)}, ::C, ::Nothing, ::Nothing, ::MersenneTwister)) -@test @inferred_in_latest(gen(DDNOut(:sp,:o,:r), C(), nothing, nothing, Random.GLOBAL_RNG)) == (nothing, nothing, 0.0) -@test @inferred_in_latest(@gen(:sp,:o,:r)(C(), nothing, nothing, Random.GLOBAL_RNG)) == (nothing, nothing, 0.0) +@test @inferred(@gen(:sp,:o,:r)(C(), nothing, nothing, Random.GLOBAL_RNG)) == (nothing, nothing, 0.0) struct GD <: MDP{Int, Int} end -struct Deterministic{T} - x::T -end -Base.rand(rng::AbstractRNG, d::Deterministic) = d.x POMDPs.transition(::GD, s, a) = Deterministic(s + a) -@test @inferred_in_latest(gen(DDNNode(:sp), GD(), 1, 1, 
Random.GLOBAL_RNG)) == 2 +@test @inferred(@gen(:sp)(GD(), 1, 1, Random.GLOBAL_RNG)) == 2 POMDPs.reward(::GD, s, a) = s + a -@test @inferred_in_latest(gen(DDNNode(:r), GD(), 1, 1, 2, Random.GLOBAL_RNG)) == 2 +@test @inferred(@gen(:r)(GD(), 1, 1, Random.GLOBAL_RNG)) == 2 struct GE <: MDP{Int, Int} end -@test_throws DistributionNotImplemented gen(DDNNode(:sp), GE(), 1, 1, Random.GLOBAL_RNG) -@test_throws DistributionNotImplemented gen(DDNOut(:sp,:r), GE(), 1, 1, Random.GLOBAL_RNG) +@test_throws MethodError @gen(:sp)(GE(), 1, 1, Random.GLOBAL_RNG) +@test_throws MethodError @gen(:sp,:r)(GE(), 1, 1, Random.GLOBAL_RNG) POMDPs.gen(::GE, s, a, ::AbstractRNG) = (sp=s+a, r=s^2) -@test @inferred_in_latest(gen(DDNOut(:sp), GE(), 1, 1, Random.GLOBAL_RNG)) == 2 -@test @inferred_in_latest(@gen(:sp)(GE(), 1, 1, Random.GLOBAL_RNG)) == 2 -@test @inferred_in_latest(gen(DDNOut(:sp,:r), GE(), 1, 1, Random.GLOBAL_RNG)) == (2, 1) -@test @inferred_in_latest(@gen(:sp, :r)(GE(), 1, 1, Random.GLOBAL_RNG)) == (2, 1) +@test @inferred(@gen(:sp)(GE(), 1, 1, Random.GLOBAL_RNG)) == 2 +@test @inferred(@gen(:sp, :r)(GE(), 1, 1, Random.GLOBAL_RNG)) == (2, 1) diff --git a/test/test_generative_backedges.jl b/test/test_generative_backedges.jl index dc0887ce..7f66ad78 100644 --- a/test/test_generative_backedges.jl +++ b/test/test_generative_backedges.jl @@ -4,17 +4,17 @@ using Random let struct M <: POMDP{Int, Int, Char} end - @test_throws DistributionNotImplemented gen(DDNNode(:sp), M(), 1, 1, MersenneTwister(4)) + @test_throws MethodError @gen(:sp)(M(), 1, 1, MersenneTwister(4)) POMDPs.transition(::M, ::Int, ::Int) = [1] - @test gen(DDNNode(:sp), M(), 1, 1, MersenneTwister(4)) == 1 - @test_throws DistributionNotImplemented gen(DDNOut(:sp,:o,:r), M(), 1, 1, MersenneTwister(4)) - @test_throws DistributionNotImplemented gen(DDNOut(:sp,:r), M(), 1, 1, MersenneTwister(4)) + @test @gen(:sp)(M(), 1, 1, MersenneTwister(4)) == 1 + @test_throws MethodError @gen(:sp,:o,:r)(M(), 1, 1, MersenneTwister(4)) + 
@test_throws MethodError @gen(:sp,:r)(M(), 1, 1, MersenneTwister(4)) POMDPs.reward(::M, ::Int, ::Int, ::Int) = 0.0 - POMDPs.gen(::DDNNode{:o}, ::M, ::Int, ::Int, ::Int, ::AbstractRNG) = 'a' - @test gen(DDNOut(:sp,:r), M(), 1, 1, MersenneTwister(4)) == (1, 0.0) - @test gen(DDNOut(:sp,:o,:r), M(), 1, 1, MersenneTwister(4)) == (1, 'a', 0.0) + POMDPs.observation(::M, ::Int, ::Int, ::Int) = ['a'] + @test @gen(:sp,:r)(M(), 1, 1, MersenneTwister(4)) == (1, 0.0) + @test @gen(:sp,:o,:r)(M(), 1, 1, MersenneTwister(4)) == (1, 'a', 0.0) - @test_throws MethodError initialobs(M(), 1, MersenneTwister(4)) - POMDPs.observation(::M, ::Int) = ['a'] - @test initialobs(M(), 1, MersenneTwister(4)) == 'a' + @test_throws MethodError initialobs(M(), 1) + POMDPs.initialobs(::M, ::Int) = ['a'] + @test rand(MersenneTwister(4), initialobs(M(), 1)) == 'a' end diff --git a/test/test_requirements.jl b/test/test_requirements.jl deleted file mode 100644 index bb11a244..00000000 --- a/test/test_requirements.jl +++ /dev/null @@ -1,97 +0,0 @@ -using Test - -tcall = Meta.parse("f(arg1::T1, arg2::T2)") -@test POMDPs.unpack_typedcall(tcall) == (:f, [:arg1, :arg2], [:T1, :T2]) - -# tests case where types aren't specified -@POMDP_require tan(s) begin - @req sin(::typeof(s)) - @req cos(::typeof(s)) -end - -module MyModule - using POMDPs - using Random - - export CoolSolver, solve - - mutable struct CoolSolver <: Solver end - - p = nothing # to test hygeine - @POMDP_require solve(s::CoolSolver, p::POMDP) begin - PType = typeof(p) - S = statetype(PType) - A = actiontype(PType) - @req states(::PType) - @req actions(::PType) - @req transition(::PType, ::S, ::A) - @subreq util2(p) - s = first(states(p)) - @subreq util1(s) - a = first(actions(p)) - t_dist = transition(p, s, a) - @req rand(::AbstractRNG, ::typeof(t_dist)) - @req gen(::DDNOut{:o}, ::PType, ::S, ::A, ::MersenneTwister) - end - - function POMDPs.solve(s::CoolSolver, problem::POMDP{S,A,O}) where {S,A,O} - @warn_requirements solve(s, problem) - reqs = 
@get_requirements solve(s,problem) - @assert p==nothing - return check_requirements(reqs) - end - - util1(x) = abs(x) - - util2(p::POMDP) = observations(p) - @POMDP_require util2(p::POMDP) begin - P = typeof(p) - @req observations(::P) - end -end - -using POMDPs -using Main.MyModule - -mutable struct SimplePOMDP <: POMDP{Float64, Bool, Int} end -POMDPs.actions(SimplePOMDP) = [true, false] - -POMDPs.discount(::SimplePOMDP) = 0.9 - -let a = 0.0, f(x) = x^2 - @test @req(f(a, 4)) == @req(f(::typeof(a), ::typeof(4))) -end - -reqs = nothing # to check the hygeine of the macro -println("There should be a warning about no @reqs here:") -# 27 minutes has been spent trying to suppress this warning and automate a test for it. If you work more on it, please update this counter. The following things have been tried -# - @test_logs (:warn, "No") @POMDP_requirements ... -# - @capture_err @POMDP_requirements ... # From Suppressor.jl -# - @capture_out @POMDP_requirements ... # From Suppressor.jl -@POMDP_requirements "Warn none" begin - 1+1 -end -@test reqs == nothing -@test_throws LoadError macroexpand(Main, quote @POMDP_requirements "Malformed" begin - @req iterator(typeof(as)) - end -end) - -# solve(CoolSolver(), SimplePOMDP()) -@test_throws MethodError solve(CoolSolver(), SimplePOMDP()) - -POMDPs.states(::SimplePOMDP) = [1.4, 3.2, 5.8] -struct SimpleDistribution - ss::Vector{Float64} - b::Vector{Float64} -end -POMDPs.transition(p::SimplePOMDP, s::Float64, ::Bool) = SimpleDistribution(states(p), [0.2, 0.2, 0.6]) - -@test (solve(CoolSolver(), SimplePOMDP()) & false) == false - -POMDPs.observations(p::SimplePOMDP) = [1,2,3] - -Random.rand(rng::AbstractRNG, d::SimpleDistribution) = sample(rng, d.ss, WeightVec(d.b)) -POMDPs.gen(::DDNOut{:o}, m::SimplePOMDP, s, a, rng) = 1 - -@test solve(CoolSolver(), SimplePOMDP())