Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP View of Eye returns OneElement #238

Open
wants to merge 7 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[compat]
Aqua = "0.5"
Aqua = "0.5, 0.6"
julia = "1.6"

[extras]
Expand Down
8 changes: 2 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,8 @@ as well as identity matrices. This package exports the following types:


The primary purpose of this package is to present a unified way of constructing
matrices. For example, to construct a 5-by-5 `CLArray` of all zeros, one would use
```julia
julia> CLArray(Zeros(5,5))
```
Because `Zeros` is lazy, this can be accomplished on the GPU with no memory transfer.
Similarly, to construct a 5-by-5 `BandedMatrix` of all zeros with bandwidths `(1,2)`, one would use
matrices.
For example, to construct a 5-by-5 `BandedMatrix` of all zeros with bandwidths `(1,2)`, one would use
```julia
julia> BandedMatrix(Zeros(5,5), (1, 2))
```
Expand Down
57 changes: 49 additions & 8 deletions src/FillArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import Base: size, getindex, setindex!, IndexStyle, checkbounds, convert,
+, -, *, /, \, diff, sum, cumsum, maximum, minimum, sort, sort!,
any, all, axes, isone, iterate, unique, allunique, permutedims, inv,
copy, vec, setindex!, count, ==, reshape, _throw_dmrs, map, zero,
show, view, in, mapreduce, one, reverse, promote_op
show, view, in, mapreduce, one, reverse, promote_op, promote_rule

import LinearAlgebra: rank, svdvals!, tril, triu, tril!, triu!, diag, transpose, adjoint, fill!,
dot, norm2, norm1, normInf, normMinusInf, normp, lmul!, rmul!, diagzero, AdjointAbsVec, TransposeAbsVec,
Expand All @@ -18,7 +18,7 @@ import Base.Broadcast: broadcasted, DefaultArrayStyle, broadcast_shape
import Statistics: mean, std, var, cov, cor


export Zeros, Ones, Fill, Eye, Trues, Falses
export Zeros, Ones, Fill, Eye, Trues, Falses, OneElement

import Base: oneto

Expand All @@ -34,6 +34,7 @@ const AbstractFillVecOrMat{T} = Union{AbstractFillVector{T},AbstractFillMatrix{T

==(a::AbstractFill, b::AbstractFill) = axes(a) == axes(b) && getindex_value(a) == getindex_value(b)


@inline function _fill_getindex(F::AbstractFill, kj::Integer...)
@boundscheck checkbounds(F, kj...)
getindex_value(F)
Expand Down Expand Up @@ -147,6 +148,27 @@ Fill{T,0}(x::T, ::Tuple{}) where T = Fill{T,0,Tuple{}}(x, ()) # ambiguity fix
@inline axes(F::Fill) = F.axes
@inline size(F::Fill) = map(length, F.axes)

"""
getindex_value(F::AbstractFill)

Return the value that `F` is filled with.

# Examples

```jldoctest
julia> f = Ones(3);

julia> FillArrays.getindex_value(f)
1.0

julia> g = Fill(2, 10);

julia> FillArrays.getindex_value(g)
2
```
"""
getindex_value

@inline getindex_value(F::Fill) = F.value

AbstractArray{T}(F::Fill{V,N}) where {T,V,N} = Fill{T}(convert(T, F.value)::T, F.axes)
Expand All @@ -155,7 +177,12 @@ AbstractFill{T}(F::AbstractFill) where T = AbstractArray{T}(F)

copy(F::Fill) = Fill(F.value, F.axes)

""" Throws an error if `arr` does not contain one and only one unique value. """
"""
unique_value(arr::AbstractArray)

Return `only(unique(arr))` without intermediate allocations.
Throws an error if `arr` does not contain one and only one unique value.
"""
function unique_value(arr::AbstractArray)
if isempty(arr) error("Cannot convert empty array to Fill") end
val = first(arr)
Expand Down Expand Up @@ -262,6 +289,7 @@ for (Typ, funcs, func) in ((:Zeros, :zeros, :zero), (:Ones, :ones, :one))
@inline $Typ{T,N}(A::AbstractArray{V,N}) where{T,V,N} = $Typ{T,N}(size(A))
@inline $Typ{T}(A::AbstractArray) where{T} = $Typ{T}(size(A))
@inline $Typ(A::AbstractArray) = $Typ{eltype(A)}(A)
@inline $Typ(::Type{T}, m...) where T = $Typ{T}(m...)

@inline axes(Z::$Typ) = Z.axes
@inline size(Z::$Typ) = length.(Z.axes)
Expand All @@ -273,6 +301,14 @@ for (Typ, funcs, func) in ((:Zeros, :zeros, :zero), (:Ones, :ones, :one))
copy(F::$Typ) = F

getindex(F::$Typ{T,0}) where T = getindex_value(F)

promote_rule(::Type{$Typ{T, N, Axes}}, ::Type{$Typ{V, N, Axes}}) where {T,V,N,Axes} = $Typ{promote_type(T,V),N,Axes}
function convert(::Type{$Typ{T,N,Axes}}, A::$Typ{V,N,Axes}) where {T,V,N,Axes}
convert(T, getindex_value(A)) # checks that the types are convertible
$Typ{T,N,Axes}(axes(A))
end
convert(::Type{$Typ{T,N}}, A::$Typ{V,N,Axes}) where {T,V,N,Axes} = convert($Typ{T,N,Axes}, A)
convert(::Type{$Typ{T}}, A::$Typ{V,N,Axes}) where {T,V,N,Axes} = convert($Typ{T,N,Axes}, A)
end
end

Expand All @@ -284,6 +320,8 @@ for TYPE in (:Fill, :AbstractFill, :Ones, :Zeros), STYPE in (:AbstractArray, :Ab
end
end

promote_rule(::Type{<:AbstractFill{T, N, Axes}}, ::Type{<:AbstractFill{V, N, Axes}}) where {T,V,N,Axes} = Fill{promote_type(T,V),N,Axes}

"""
fillsimilar(a::AbstractFill, axes)

Expand Down Expand Up @@ -426,7 +464,8 @@ end


## Array
Base.Array{T,N}(F::AbstractFill{V,N}) where {T,V,N} = fill(convert(T, getindex_value(F)), size(F))
Base.Array{T,N}(F::AbstractFill{V,N}) where {T,V,N} =
convert(Array{T,N}, fill(convert(T, getindex_value(F)), size(F)))

# These are in case `zeros` or `ones` are ever faster than `fill`
for (Typ, funcs, func) in ((:Zeros, :zeros, :zero), (:Ones, :ones, :one))
Expand All @@ -437,7 +476,7 @@ end

# temporary patch. should be a PR(#48895) to LinearAlgebra
Diagonal{T}(A::AbstractFillMatrix) where T = Diagonal{T}(diag(A))
function convert(::Type{T}, A::AbstractFillMatrix) where T<:Diagonal
function convert(::Type{T}, A::AbstractFillMatrix) where T<:Diagonal
checksquare(A)
isdiag(A) ? T(A) : throw(InexactError(:convert, T, A))
end
Expand Down Expand Up @@ -496,14 +535,14 @@ sum(x::AbstractFill) = getindex_value(x)*length(x)
sum(f, x::AbstractFill) = length(x) * f(getindex_value(x))
sum(x::Zeros) = getindex_value(x)

cumsum(x::AbstractFill{<:Any,1}) = range(getindex_value(x); step=getindex_value(x),
length=length(x))
# needed to support infinite case
steprangelen(st...) = StepRangeLen(st...)
cumsum(x::AbstractFill{<:Any,1}) = steprangelen(getindex_value(x), getindex_value(x), length(x))

cumsum(x::ZerosVector) = x
cumsum(x::ZerosVector{Bool}) = x
cumsum(x::OnesVector{II}) where II<:Integer = convert(AbstractVector{II}, oneto(length(x)))
cumsum(x::OnesVector{Bool}) = oneto(length(x))
cumsum(x::AbstractFillVector{Bool}) = cumsum(AbstractFill{Int}(x))


#########
Expand Down Expand Up @@ -717,4 +756,6 @@ Base.@propagate_inbounds function view(A::AbstractFill{<:Any,N}, I::Vararg{Real,
fillsimilar(A)
end

include("oneelement.jl")

end # module
1 change: 0 additions & 1 deletion src/fillalgebra.jl
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,6 @@ end
*(a::ZerosMatrix, b::AbstractMatrix) = mult_zeros(a, b)
*(a::AbstractMatrix, b::ZerosVector) = mult_zeros(a, b)
*(a::AbstractMatrix, b::ZerosMatrix) = mult_zeros(a, b)
*(a::ZerosVector, b::AbstractVector) = mult_zeros(a, b)
*(a::ZerosMatrix, b::AbstractVector) = mult_zeros(a, b)
*(a::AbstractVector, b::ZerosMatrix) = mult_zeros(a, b)

Expand Down
9 changes: 5 additions & 4 deletions src/fillbroadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -102,10 +102,6 @@ _broadcasted_zeros(f, a, b) = Zeros{Base.Broadcast.combine_eltypes(f, (a, b))}(b
_broadcasted_ones(f, a, b) = Ones{Base.Broadcast.combine_eltypes(f, (a, b))}(broadcast_shape(axes(a), axes(b)))
_broadcasted_nan(f, a, b) = Fill(convert(Base.Broadcast.combine_eltypes(f, (a, b)), NaN), broadcast_shape(axes(a), axes(b)))

# TODO: remove at next breaking version
_broadcasted_zeros(a, b) = _broadcasted_zeros(+, a, b)
_broadcasted_ones(a, b) = _broadcasted_ones(+, a, b)

broadcasted(::DefaultArrayStyle, ::typeof(+), a::Zeros, b::Zeros) = _broadcasted_zeros(+, a, b)
broadcasted(::DefaultArrayStyle, ::typeof(+), a::Ones, b::Zeros) = _broadcasted_ones(+, a, b)
broadcasted(::DefaultArrayStyle, ::typeof(+), a::Zeros, b::Ones) = _broadcasted_ones(+, a, b)
Expand Down Expand Up @@ -247,3 +243,8 @@ broadcasted(::DefaultArrayStyle{N}, ::typeof(Base.literal_pow), ::Base.RefValue{
broadcasted(::DefaultArrayStyle{N}, ::typeof(Base.literal_pow), ::Base.RefValue{typeof(^)}, r::Ones{T,N}, ::Base.RefValue{Val{k}}) where {T,N,k} = Ones{T}(axes(r))
broadcasted(::DefaultArrayStyle{N}, ::typeof(Base.literal_pow), ::Base.RefValue{typeof(^)}, r::Zeros{T,N}, ::Base.RefValue{Val{0}}) where {T,N} = Ones{T}(axes(r))
broadcasted(::DefaultArrayStyle{N}, ::typeof(Base.literal_pow), ::Base.RefValue{typeof(^)}, r::Zeros{T,N}, ::Base.RefValue{Val{k}}) where {T,N,k} = Zeros{T}(axes(r))

# supports structured broadcast
if isdefined(LinearAlgebra, :fzero)
LinearAlgebra.fzero(x::Zeros) = zero(eltype(x))
end
58 changes: 58 additions & 0 deletions src/oneelement.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
"""
OneElement(val, ind, axesorsize) <: AbstractArray

Represents an array with the specified axes (if its a tuple of `AbstractUnitRange`s)
or size (if its a tuple of `Integer`s), with a single entry set to `val` and all others equal to zero,
specified by `ind``.
"""
struct OneElement{T,N,I,A} <: AbstractArray{T,N}
val::T
ind::I
axes::A
OneElement(val::T, ind::I, axes::A) where {T<:Number, I<:NTuple{N,Int}, A<:NTuple{N,AbstractUnitRange}} where {N} = new{T,N,I,A}(val, ind, axes)
end

OneElement(val, inds::NTuple{N,Int}, sz::NTuple{N,Integer}) where N = OneElement(val, inds, oneto.(sz))
"""
OneElement(val, ind::Int, n::Int)

Creates a length `n` vector where the `ind` entry is equal to `val`, and all other entries are zero.
"""
OneElement(val, ind::Int, len::Int) = OneElement(val, (ind,), (len,))
"""
OneElement(ind::Int, n::Int)

Creates a length `n` vector where the `ind` entry is equal to `1`, and all other entries are zero.
"""
OneElement(inds::Int, sz::Int) = OneElement(1, inds, sz)
OneElement{T}(val, inds::NTuple{N,Int}, sz::NTuple{N,Integer}) where {T,N} = OneElement(convert(T,val), inds, oneto.(sz))
OneElement{T}(val, inds::Int, sz::Int) where T = OneElement{T}(val, (inds,), (sz,))

"""
OneElement{T}(val, ind::Int, n::Int)

Creates a length `n` vector where the `ind` entry is equal to `one(T)`, and all other entries are zero.
"""
OneElement{T}(inds::Int, sz::Int) where T = OneElement(one(T), inds, sz)

Base.size(A::OneElement) = map(length, A.axes)
Base.axes(A::OneElement) = A.axes
function Base.getindex(A::OneElement{T,N}, kj::Vararg{Int,N}) where {T,N}
@boundscheck checkbounds(A, kj...)
ifelse(kj == A.ind, A.val, zero(T))
end

Base.replace_in_print_matrix(o::OneElement{<:Any,2}, k::Integer, j::Integer, s::AbstractString) =
o.ind == (k,j) ? s : Base.replace_with_centered_mark(s)

function Base.setindex(A::Zeros{T,N}, v, kj::Vararg{Int,N}) where {T,N}
@boundscheck checkbounds(A, kj...)
OneElement(convert(T, v), kj, axes(A))
end


Base.@propagate_inbounds function view(A::RectOrDiagonal{<:Any,<:AbstractFill}, kr::AbstractRange, j::Integer)
@boundscheck checkbounds(A, kr, j)
k = findfirst(isequal(j), kr)
OneElement(getindex_value(A.diag), isnothing(k) ? 0 : something(k), length(kr))
end
Loading