Skip to content

Commit

Permalink
Fix strassen test
Browse files Browse the repository at this point in the history
  • Loading branch information
AntonReinhard committed Nov 7, 2024
1 parent 4631d87 commit 4929fce
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 33 deletions.
7 changes: 5 additions & 2 deletions src/task/compute.jl
Original file line number Diff line number Diff line change
Expand Up @@ -78,14 +78,17 @@ function result_type(fc::FunctionCall, known_res_types::Dict{Symbol,Type})

if length(types) > 1
throw(
"failure during type inference: function call $fc is type unstable, possible return types: $types",
"failure during type inference: function call $fc with argument types $(argument_types) is type unstable, possible return types: $types",
)
end
if isempty(types)
throw(
"failure during type inference: function call $fc has no return types, this is likely because no method matches the arguments",
"failure during type inference: function call $fc with argument types $(argument_types) has no return types, this is likely because no method matches the arguments",
)
end
if types[1] == Any
@warn "inferred return type 'Any' in task $fc with argument types $(argument_types)"
end

return types[1]
end
28 changes: 10 additions & 18 deletions test/strassen/impl.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ module MatrixMultiplicationImpl

using ComputableDAGs

const STRASSEN_MIN_SIZE = 64 # minimum matrix size to use Strassen algorithm instead of naive algorithm
const STRASSEN_MIN_SIZE = 32 # minimum matrix size to use Strassen algorithm instead of naive algorithm
const DEFAULT_TYPE = Float64 # default type of matrix multiplication assumed

# problem model definition
Expand Down Expand Up @@ -37,23 +37,15 @@ ComputableDAGs.compute_effort(::ComputeTask_MultStrassen) = 0
A + B
@inline ComputableDAGs.compute(::ComputeTask_Sub, A::AbstractMatrix, B::AbstractMatrix) =
A - B
@inline ComputableDAGs.compute(
::ComputeTask_MultBase, A::AbstractMatrix, B::AbstractMatrix
) = A * B
@inline function ComputableDAGs.compute(
::ComputeTask_MultBase, A::MATRIX_T, B::MATRIX_T
)::MATRIX_T where {MATRIX_T<:AbstractMatrix}
return A * B
end

function ComputableDAGs.compute(
::ComputeTask_MultStrassen,
C11::AbstractMatrix,
C12::AbstractMatrix,
C21::AbstractMatrix,
C22::AbstractMatrix,
)
# One Strassen step from precomputed matrices M1...M7
# result will be [C11 C12; C21 C22], where
# C11 = M1 + M4 - M5 + M7
# C12 = M3 + M5
# C21 = M2 + M4
# C22 = M1 - M2 + M3 + M6
::ComputeTask_MultStrassen, C11::MATRIX_T, C12::MATRIX_T, C21::MATRIX_T, C22::MATRIX_T
) where {MATRIX_T<:AbstractMatrix}
return [
C11 C12
C21 C22
Expand Down Expand Up @@ -334,9 +326,9 @@ function ComputableDAGs.input_expr(
end

function ComputableDAGs.input_type(mm::MatrixMultiplication{T}) where {T}
return Tuple{<:AbstractMatrix{T},<:AbstractMatrix{T}}
return Tuple{Matrix{T},Matrix{T}}
end

export MatrixMultiplication

end
end
22 changes: 9 additions & 13 deletions test/strassen_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@ include("strassen/impl.jl")
using .MatrixMultiplicationImpl

TEST_TYPES = (Int32, Int64, Float32, Float64)
TEST_SIZES = (32, 64, 128, 256, 512)
NODE_NUMBERS = (4, 70, 532, 3766, 26404)
EDGE_NUMBERS = (3, 96, 747, 5304, 37203)
TEST_SIZES = (16, 32, 64, 128) #, 256
NODE_NUMBERS = (4, 70, 532, 3766) #, 26404
EDGE_NUMBERS = (3, 96, 747, 5304) #, 37203

@testset "Strassen Matrix Type $M_T Size $(TEST_SIZES[M_SIZE_I])" for (M_T, M_SIZE_I) in
Iterators.product(
Expand All @@ -25,7 +25,7 @@ EDGE_NUMBERS = (3, 96, 747, 5304, 37203)

@testset "Construction" begin
@test mm.size == M_SIZE
@test input_type(mm) == Tuple{AbstractMatrix{M_T},AbstractMatrix{M_T}}
@test input_type(mm) == Tuple{Matrix{M_T},Matrix{M_T}}
@test input isa input_type(mm)
@test_throws "unknown data node name C" input_expr(mm, "C", :input)
end
Expand All @@ -49,18 +49,14 @@ EDGE_NUMBERS = (3, 96, 747, 5304, 37203)
end

@testset "Execution" begin
@test Base.infer_return_type(f, (typeof(input),)) == typeof(input[1])
@test Base.return_types(f, (typeof(input),))[1] == typeof(input[1])
@test isapprox(f(input), input[1] * input[2])
end

if (M_SIZE > 64)
# skip SMatrix tests for larger matrix sizes
continue
end
@testset "Execution with closures" begin
f_closures = get_compute_function(g, mm, cpu_st(), @__MODULE__; closures_size=100)

s_input = (rand(SMatrix{M_SIZE,M_SIZE,M_T}), rand(SMatrix{M_SIZE,M_SIZE,M_T}))
@testset "Execution on SMatrix" begin
@test Base.infer_return_type(f, (typeof(s_input),)) == typeof(s_input[1])
@test isapprox(f(s_input), s_input[1] * s_input[2])
@test Base.return_types(f_closures, (typeof(input),))[1] == typeof(input[1])
@test isapprox(f_closures(input), input[1] * input[2])
end
end

0 comments on commit 4929fce

Please sign in to comment.