Skip to content

Commit

Permalink
Merge pull request #1 from invenia/bc/prototype
Browse files Browse the repository at this point in the history
Implement KeyedDistribution and KeyedSampleable
  • Loading branch information
bencottier authored Mar 16, 2021
2 parents 883d9c9 + 1ef92ce commit 92bb274
Show file tree
Hide file tree
Showing 6 changed files with 437 additions and 9 deletions.
8 changes: 4 additions & 4 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
name: CI
# Run on master, tags, or any pull request
# Run on main, tags, or any pull request
on:
schedule:
- cron: '0 2 * * *' # Daily at 2 AM UTC (8 PM CST)
push:
branches: [master]
branches: [main]
tags: ["*"]
pull_request:
jobs:
Expand All @@ -15,8 +15,8 @@ jobs:
fail-fast: false
matrix:
version:
- '1'
- '1.0'
- '1.5' # Invenia Prod version
- '1' # Latest Release
os:
- ubuntu-latest
- macOS-latest
Expand Down
19 changes: 17 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,26 @@ uuid = "2576fb08-064d-4cab-b15d-8dda7fcb9a6d"
authors = ["Invenia Technical Computing Corporation"]
version = "0.1.0"

[deps]
AutoHashEquals = "15f4f7f2-30c1-5605-9d31-71845cf9641f"
AxisKeys = "94b1ba4f-4ee9-5380-92f1-94cde586c3c5"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
IterTools = "c8e1da08-722c-5040-9ed9-7db0dc04731e"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"

[compat]
julia = "1"
AutoHashEquals = "0.2"
AxisKeys = "0.1"
Distributions = "0.24"
IterTools = "1.3"
StableRNGs = "1"
julia = "1.5"

[extras]
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Test"]
test = ["LinearAlgebra", "StableRNGs", "Statistics", "Test"]
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,6 @@
[![Stable](https://img.shields.io/badge/docs-stable-blue.svg)](https://invenia.github.io/KeyedDistributions.jl/stable)
[![Dev](https://img.shields.io/badge/docs-dev-blue.svg)](https://invenia.github.io/KeyedDistributions.jl/dev)
[![Build Status](https://github.com/invenia/KeyedDistributions.jl/workflows/CI/badge.svg)](https://github.com/invenia/KeyedDistributions.jl/actions)
[![Coverage](https://codecov.io/gh/invenia/KeyedDistributions.jl/branch/master/graph/badge.svg)](https://codecov.io/gh/invenia/KeyedDistributions.jl)
[![Coverage](https://codecov.io/gh/invenia/KeyedDistributions.jl/branch/main/graph/badge.svg)](https://codecov.io/gh/invenia/KeyedDistributions.jl)
[![Code Style: Blue](https://img.shields.io/badge/code%20style-blue-4495d1.svg)](https://github.com/invenia/BlueStyle)
[![ColPrac: Contributor's Guide on Collaborative Practices for Community Packages](https://img.shields.io/badge/ColPrac-Contributor's%20Guide-blueviolet)](https://github.com/SciML/ColPrac)
2 changes: 2 additions & 0 deletions docs/make.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,4 +22,6 @@ makedocs(;

deploydocs(;
repo="github.com/invenia/KeyedDistributions.jl",
devbranch = "main",
push_preview = true,
)
176 changes: 175 additions & 1 deletion src/KeyedDistributions.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,179 @@
module KeyedDistributions

# Write your package code here.
using AutoHashEquals
using AxisKeys
using Distributions
using IterTools
using Random: AbstractRNG

export KeyedDistribution, KeyedSampleable
export axiskeys, distribution


# Constructors

for T in (:Distribution, :Sampleable)
KeyedT = Symbol(:Keyed, T)
@eval begin
"""
$($KeyedT)(d<:$($T), keys::Tuple{Vararg{AbstractVector}})
Stores `keys` for each variate alongside the `$($T)` `d`,
supporting all of the common functions of a `$($T)`.
Common functions that return an `AbstractArray`, such as `rand`,
will return a `KeyedArray` with keys derived from the `$($T)`.
The type of `keys` is restricted to be consistent with
[AxisKeys.jl](https://github.com/mcabbott/AxisKeys.jl).
The length of the `keys` tuple must be the number of dimensions, which is 1 for
univariate and multivariate distributions, and 2 for matrix-variate distributions.
The length of each key vector in must match the length along each dimension.
"""
@auto_hash_equals struct $KeyedT{F<:VariateForm, S<:ValueSupport, D<:$T{F, S}} <: $T{F, S}
d::D
keys::Tuple{Vararg{AbstractVector}}

function $KeyedT(d::$T{F, S}, keys::Tuple{Vararg{AbstractVector}}) where {F, S}
key_lengths = map(length, keys)
key_lengths == _size(d) || throw(ArgumentError(
"lengths of key vectors $key_lengths must match " *
"size of distribution $(_size(d))"))

return new{F, S, typeof(d)}(d, keys)
end
end

"""
$($KeyedT)(d<:$($T), keys::AbstractVector)
Constructor for [`$($KeyedT)`](@ref) with one dimension of variates.
The elements of `keys` correspond to the variates of the distribution.
"""
$KeyedT(d::$T{F, S}, keys::AbstractVector) where {F, S} = $KeyedT(d, (keys, ))
end
end

_size(d) = (length(d),)
_size(d::Sampleable{<:Matrixvariate}) = size(d)

"""
KeyedDistribution(d::Distribution)
Constructs a [`KeyedDistribution`](@ref) using the keys of the first field stored in `d`,
or if there are no keys, `1:n` for the length `n` of each dimension.
"""
function KeyedDistribution(d::Distribution)
first_field = getfield(d, 1)
return KeyedDistribution(d, _keys(first_field))
end

_keys(x::KeyedArray) = axiskeys(x)
_keys(x) = map(Base.OneTo, size(x))

const KeyedDistOrSampleable = Union{KeyedDistribution, KeyedSampleable}

# Access methods

"""
distribution(::KeyedDistribution) -> Distribution
distribution(::KeyedSampleable{F, S, D}) -> D
Return the wrapped distribution.
"""
distribution(d::KeyedDistOrSampleable) = d.d

# AxisKeys functionality

"""
axiskeys(d::Union{KeyedDistribution, KeyedSampleable})
Return the keys for the variates of the `KeyedDistribution` or `KeyedSampleable`.
"""
AxisKeys.axiskeys(d::KeyedDistOrSampleable) = d.keys

# Standard functions to overload for new Distribution and/or Sampleable
# https://juliastats.org/Distributions.jl/latest/extends/#Create-New-Samplers-and-Distributions

function Distributions._rand!(
rng::AbstractRNG,
d::KeyedDistOrSampleable,
x::AbstractVector{<:Real}
)
sample = Distributions._rand!(rng, distribution(d), x)
return KeyedArray(sample, axiskeys(d))
end

Base.length(d::KeyedDistOrSampleable) = length(distribution(d))

Distributions.size(d::KeyedDistribution{<:Matrixvariate}) = size(distribution(d))

Distributions.sampler(d::KeyedDistribution) = sampler(distribution(d))

Base.eltype(d::KeyedDistribution) = eltype(distribution(d))

function Distributions._logpdf(d::KeyedDistribution, x::AbstractArray)
# Workaround when KeyedArray is parameter of Distribution
# https://github.com/mcabbott/AxisKeys.jl/issues/54
dist = distribution(d)
T = typeof(dist)
args = map(_maybe_parent, fieldvalues(dist))
unkeyed_dist = T.name.wrapper(args...)

return Distributions._logpdf(unkeyed_dist, x)
end

_maybe_parent(x) = x
_maybe_parent(x::AbstractArray) = parent(x)

# Also need to overload `rand` methods to return KeyedArrays:

function Base.rand(rng::AbstractRNG, d::KeyedDistOrSampleable)
sample = rand(rng, distribution(d))
ndims(sample) == 0 && return sample # univariate returns a Number
return KeyedArray(sample, axiskeys(d))
end

function Base.rand(rng::AbstractRNG, d::KeyedDistOrSampleable, n::Int)
samples = rand(rng, distribution(d), n)
ndims(samples) == 1 && return KeyedArray(samples, Base.OneTo(n)) # univariate
return KeyedArray(samples, (first(axiskeys(d)), Base.OneTo(n)))
end

function Base.rand(rng::AbstractRNG, d::KeyedDistribution{<:Matrixvariate}, n::Int)
# Distributions.rand returns a vector of matrices
samples = [KeyedArray(x, axiskeys(d)) for x in rand(rng, distribution(d), n)]
return KeyedArray(samples, Base.OneTo(n))
end

# Statistics functions for Distribution

Distributions.mean(d::KeyedDistribution) = KeyedArray(mean(distribution(d)), axiskeys(d))

Distributions.var(d::KeyedDistribution) = KeyedArray(var(distribution(d)), axiskeys(d))

function Distributions.cov(d::KeyedDistribution)
keys = vcat(axiskeys(d)...)
return KeyedArray(cov(distribution(d)), (keys, keys))
end

Distributions.entropy(d::KeyedDistribution) = entropy(distribution(d))
Distributions.entropy(d::KeyedDistribution, b::Real) = entropy(distribution(d), b)

# Univariate Distributions only

for f in (:logpdf, :quantile, :mgf, :cf)
@eval Distributions.$f(d::KeyedDistribution{<:Univariate}, x) = $f(distribution(d), x)
end

for f in (:minimum, :maximum, :modes, :mode, :skewness, :kurtosis)
@eval Distributions.$f(d::KeyedDistribution{<:Univariate}) = $f(distribution(d))
end

# Needed to avoid method ambiguity
Distributions.cdf(d::KeyedDistribution{<:Univariate}, x::Real) = cdf(distribution(d), x)

function Distributions.insupport(d::KeyedDistribution{<:Univariate}, x::Real)
return insupport(distribution(d), x)
end

end
Loading

2 comments on commit 92bb274

@bencottier
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/32083

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.1.0 -m "<description of version>" 92bb274bf79af6919d86ab8e0ceab32ed6661d69
git push origin v0.1.0

Please sign in to comment.