Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Strassen Matrix Multiplication Implementation for unit tests #42

Merged
merged 4 commits into from
Nov 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions src/graph/mute.jl
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,18 @@ function _insert_node!(graph::DAG, node::Node; track=true, invalidate_cache=true
return node
end

function _insert_edge!(
::DAG, ::DataTaskNode, ::DataTaskNode, ::Int=0; track=true, invalidate_cache=true
)
throw("trying to create an edge between two data nodes which is not allowed")
end

function _insert_edge!(
::DAG, ::ComputeTaskNode, ::ComputeTaskNode, ::Int=0; track=true, invalidate_cache=true
)
throw("trying to create an edge between two compute nodes which is not allowed")
end

"""
_insert_edge!(graph::DAG, node1::Node, node2::Node, index::Int=0; track = true, invalidate_cache = true)

Expand Down
9 changes: 6 additions & 3 deletions src/task/compute.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ function get_function_call(node::ComputeTaskNode)
end

function get_function_call(node::DataTaskNode)
@assert length(children(node)) == 1 "trying to call get_function_call on a data task node that has $(length(node.children)) children instead of 1"
@assert length(children(node)) == 1 "trying to call get_function_call on a data task node that has $(length(node.children)) children instead of 1\nchildren: $(node.children)"

# TODO: dispatch to device implementations generating the copy commands
return [
Expand Down Expand Up @@ -78,14 +78,17 @@ function result_type(fc::FunctionCall, known_res_types::Dict{Symbol,Type})

if length(types) > 1
throw(
"failure during type inference: function call $fc is type unstable, possible return types: $types",
"failure during type inference: function call $fc with argument types $(argument_types) is type unstable, possible return types: $types",
)
end
if isempty(types)
throw(
"failure during type inference: function call $fc has no return types, this is likely because no method matches the arguments",
"failure during type inference: function call $fc with argument types $(argument_types) has no return types, this is likely because no method matches the arguments",
)
end
if types[1] == Any
@warn "inferred return type 'Any' in task $fc with argument types $(argument_types)"
end

return types[1]
end
7 changes: 5 additions & 2 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
using SafeTestsets

@safetestset "Utility Unit Tests " begin
@safetestset "Utility Unit Tests " begin
include("unit_tests_utility.jl")
end
# TODO: Make a new simple test model and rewrite tests here

@safetestset "Strassen Matrix Multiplication Tests " begin
include("strassen_test.jl")
end
334 changes: 334 additions & 0 deletions test/strassen/impl.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,334 @@
module MatrixMultiplicationImpl

using ComputableDAGs

const STRASSEN_MIN_SIZE = 32 # minimum matrix size to use Strassen algorithm instead of naive algorithm
const DEFAULT_TYPE = Float64 # default type of matrix multiplication assumed

# problem model definition
struct MatrixMultiplication{T}
size::Int # size of multiplication
end

MatrixMultiplication(size::Int) = MatrixMultiplication{DEFAULT_TYPE}(size)

struct ComputeTask_Slice{X_SLICE,Y_SLICE} <: AbstractComputeTask end # perform a matrix slicing operation, subindexing the matrix with the ranges X_SLICE and Y_SLICE
struct ComputeTask_Add <: AbstractComputeTask end # perform matrix addition
struct ComputeTask_Sub <: AbstractComputeTask end # perform matrix subtraction
struct ComputeTask_MultBase <: AbstractComputeTask end # perform matrix multiplication on two matrices using naive algorithm
struct ComputeTask_MultStrassen <: AbstractComputeTask end # perform one strassen assembly step from C11, C12, C21, C22

ComputableDAGs.children(::ComputeTask_Slice) = 1 # A
ComputableDAGs.children(::ComputeTask_Add) = 2 # A, B
ComputableDAGs.children(::ComputeTask_Sub) = 2 # A, B
ComputableDAGs.children(::ComputeTask_MultBase) = 2 # A, B
ComputableDAGs.children(::ComputeTask_MultStrassen) = 7 # M1...M7

ComputableDAGs.compute_effort(::ComputeTask_Slice) = 0
ComputableDAGs.compute_effort(::ComputeTask_Add) = 0
ComputableDAGs.compute_effort(::ComputeTask_Sub) = 0
ComputableDAGs.compute_effort(::ComputeTask_MultBase) = 0
ComputableDAGs.compute_effort(::ComputeTask_MultStrassen) = 0

@inline ComputableDAGs.compute(
::ComputeTask_Slice{UR_X,UR_Y}, A::AbstractMatrix
) where {UR_X,UR_Y} = A[UR_X, UR_Y]
@inline ComputableDAGs.compute(::ComputeTask_Add, A::AbstractMatrix, B::AbstractMatrix) =
A + B
@inline ComputableDAGs.compute(::ComputeTask_Sub, A::AbstractMatrix, B::AbstractMatrix) =
A - B
@inline function ComputableDAGs.compute(
::ComputeTask_MultBase, A::MATRIX_T, B::MATRIX_T
)::MATRIX_T where {MATRIX_T<:AbstractMatrix}
return A * B
end

function ComputableDAGs.compute(
::ComputeTask_MultStrassen, C11::MATRIX_T, C12::MATRIX_T, C21::MATRIX_T, C22::MATRIX_T
) where {MATRIX_T<:AbstractMatrix}
return [
C11 C12
C21 C22
]
end

# return data node with the result
function _dag_build_helper!(
g::DAG,
mm::MatrixMultiplication{T},
A::DataTaskNode, # Data node that contains matrix A
B::DataTaskNode, # Data node that contains matrix B
) where {T}
mm_size = mm.size
@assert iseven(mm_size) "matrix size is not even: $mm_size"
mm_half_size = div(mm_size, 2)

data_size = mm.size^2 * sizeof(T)
data_size_half = div(data_size, 4)

if mm_size < STRASSEN_MIN_SIZE
mb_task = insert_node!(g, ComputeTask_MultBase())
mb_data = insert_node!(g, DataTask(data_size))

insert_edge!(g, A, mb_task, 1)
insert_edge!(g, B, mb_task, 2)
insert_edge!(g, mb_task, mb_data)
return mb_data
end

# STRASSEN step
h1 = 1:mm_half_size
h2 = (mm_half_size + 1):mm_size

# -- Subindexing of A and B to prepare A_11, A_12, and so on
A11_t = insert_node!(g, ComputeTask_Slice{h1,h1}())
insert_edge!(g, A, A11_t)
A12_t = insert_node!(g, ComputeTask_Slice{h1,h2}())
insert_edge!(g, A, A12_t)
A21_t = insert_node!(g, ComputeTask_Slice{h2,h1}())
insert_edge!(g, A, A21_t)
A22_t = insert_node!(g, ComputeTask_Slice{h2,h2}())
insert_edge!(g, A, A22_t)

A11_d = insert_node!(g, DataTask(data_size_half))
insert_edge!(g, A11_t, A11_d)
A12_d = insert_node!(g, DataTask(data_size_half))
insert_edge!(g, A12_t, A12_d)
A21_d = insert_node!(g, DataTask(data_size_half))
insert_edge!(g, A21_t, A21_d)
A22_d = insert_node!(g, DataTask(data_size_half))
insert_edge!(g, A22_t, A22_d)

B11_t = insert_node!(g, ComputeTask_Slice{h1,h1}())
insert_edge!(g, B, B11_t)
B12_t = insert_node!(g, ComputeTask_Slice{h1,h2}())
insert_edge!(g, B, B12_t)
B21_t = insert_node!(g, ComputeTask_Slice{h2,h1}())
insert_edge!(g, B, B21_t)
B22_t = insert_node!(g, ComputeTask_Slice{h2,h2}())
insert_edge!(g, B, B22_t)

B11_d = insert_node!(g, DataTask(data_size_half))
insert_edge!(g, B11_t, B11_d)
B12_d = insert_node!(g, DataTask(data_size_half))
insert_edge!(g, B12_t, B12_d)
B21_d = insert_node!(g, DataTask(data_size_half))
insert_edge!(g, B21_t, B21_d)
B22_d = insert_node!(g, DataTask(data_size_half))
insert_edge!(g, B22_t, B22_d)

# M1 = (A11 + A22) x (B11 + B22)
local M1_d::DataTaskNode
begin
A_sum_t = insert_node!(g, ComputeTask_Add()) # A11 + A22
B_sum_t = insert_node!(g, ComputeTask_Add()) # B11 + B22
insert_edge!(g, A11_d, A_sum_t, 1)
insert_edge!(g, A22_d, A_sum_t, 2)
insert_edge!(g, B11_d, B_sum_t, 1)
insert_edge!(g, B22_d, B_sum_t, 2)
A_sum_d = insert_node!(g, DataTask(data_size_half))
B_sum_d = insert_node!(g, DataTask(data_size_half))
insert_edge!(g, A_sum_t, A_sum_d)
insert_edge!(g, B_sum_t, B_sum_d)

M1_d = _dag_build_helper!(
g, MatrixMultiplication{T}(mm_half_size), A_sum_d, B_sum_d
)
end

# M2 = (A21 + A22) x B11
local M2_d::DataTaskNode
begin
A_sum_t = insert_node!(g, ComputeTask_Add()) # A21 + A22
insert_edge!(g, A21_d, A_sum_t, 1)
insert_edge!(g, A22_d, A_sum_t, 2)
A_sum_d = insert_node!(g, DataTask(data_size_half))
insert_edge!(g, A_sum_t, A_sum_d)

M2_d = _dag_build_helper!(g, MatrixMultiplication{T}(mm_half_size), A_sum_d, B11_d)
end

# M3 = A11 x (B12 - B22)
local M3_d::DataTaskNode
begin
B_dif_t = insert_node!(g, ComputeTask_Sub()) # B12 - B22
insert_edge!(g, B12_d, B_dif_t, 1)
insert_edge!(g, B22_d, B_dif_t, 2)
B_dif_d = insert_node!(g, DataTask(data_size_half))
insert_edge!(g, B_dif_t, B_dif_d)

M3_d = _dag_build_helper!(g, MatrixMultiplication{T}(mm_half_size), A11_d, B_dif_d)
end

# M4 = A22 x (B21 - B11)
local M4_d::DataTaskNode
begin
B_dif_t = insert_node!(g, ComputeTask_Sub()) # B21 - B11
insert_edge!(g, B21_d, B_dif_t, 1)
insert_edge!(g, B11_d, B_dif_t, 2)
B_dif_d = insert_node!(g, DataTask(data_size_half))
insert_edge!(g, B_dif_t, B_dif_d)

M4_d = _dag_build_helper!(g, MatrixMultiplication{T}(mm_half_size), A22_d, B_dif_d)
end

# M5 = (A11 + A12) x B22
local M5_d::DataTaskNode
begin
A_sum_t = insert_node!(g, ComputeTask_Add()) # A11 + A12
insert_edge!(g, A11_d, A_sum_t, 1)
insert_edge!(g, A12_d, A_sum_t, 2)
A_sum_d = insert_node!(g, DataTask(data_size_half))
insert_edge!(g, A_sum_t, A_sum_d)

M5_d = _dag_build_helper!(g, MatrixMultiplication{T}(mm_half_size), A_sum_d, B22_d)
end

# M6 = (A21 - A11) x (B11 + B12)
local M6_d::DataTaskNode
begin
A_dif_t = insert_node!(g, ComputeTask_Sub()) # A21 - A11
B_sum_t = insert_node!(g, ComputeTask_Add()) # B11 + B12
insert_edge!(g, A21_d, A_dif_t, 1)
insert_edge!(g, A11_d, A_dif_t, 2)
insert_edge!(g, B11_d, B_sum_t, 1)
insert_edge!(g, B12_d, B_sum_t, 2)
A_dif_d = insert_node!(g, DataTask(data_size_half))
B_sum_d = insert_node!(g, DataTask(data_size_half))
insert_edge!(g, A_dif_t, A_dif_d)
insert_edge!(g, B_sum_t, B_sum_d)

M6_d = _dag_build_helper!(
g, MatrixMultiplication{T}(mm_half_size), A_dif_d, B_sum_d
)
end

# M7 = (A12 - A22) x (B21 + B22)
local M7_d::DataTaskNode
begin
A_dif_t = insert_node!(g, ComputeTask_Sub()) # A12 - A22
B_sum_t = insert_node!(g, ComputeTask_Add()) # B21 + B22
insert_edge!(g, A12_d, A_dif_t, 1)
insert_edge!(g, A22_d, A_dif_t, 2)
insert_edge!(g, B21_d, B_sum_t, 1)
insert_edge!(g, B22_d, B_sum_t, 2)
A_dif_d = insert_node!(g, DataTask(data_size_half))
B_sum_d = insert_node!(g, DataTask(data_size_half))
insert_edge!(g, A_dif_t, A_dif_d)
insert_edge!(g, B_sum_t, B_sum_d)

M7_d = _dag_build_helper!(
g, MatrixMultiplication{T}(mm_half_size), A_dif_d, B_sum_d
)
end

# C11 = M1 + M4 - M5 + M7
local C11_d::DataTaskNode
begin
s1_t = insert_node!(g, ComputeTask_Add()) # M1 + M4
s2_t = insert_node!(g, ComputeTask_Sub()) # M7 - M5
C11_t = insert_node!(g, ComputeTask_Add()) # s1 + s2
insert_edge!(g, M1_d, s1_t, 1) # +M1
insert_edge!(g, M4_d, s1_t, 2) # +M4
insert_edge!(g, M7_d, s2_t, 1) # +M7
insert_edge!(g, M5_d, s2_t, 2) # -M5

s1_d = insert_node!(g, DataTask(data_size_half))
s2_d = insert_node!(g, DataTask(data_size_half))
insert_edge!(g, s1_t, s1_d)
insert_edge!(g, s2_t, s2_d)

insert_edge!(g, s1_d, C11_t, 1)
insert_edge!(g, s2_d, C11_t, 2)

C11_d = insert_node!(g, DataTask(data_size_half))
insert_edge!(g, C11_t, C11_d)
end

# C12 = M3 + M5
local C12_d::DataTaskNode
begin
C12_t = insert_node!(g, ComputeTask_Add())
C12_d = insert_node!(g, DataTask(data_size_half))
insert_edge!(g, M3_d, C12_t, 1)
insert_edge!(g, M5_d, C12_t, 2)
insert_edge!(g, C12_t, C12_d)
end

# C21 = M2 + M4
local C21_d::DataTaskNode
begin
C21_t = insert_node!(g, ComputeTask_Add())
C21_d = insert_node!(g, DataTask(data_size_half))
insert_edge!(g, M2_d, C21_t, 1)
insert_edge!(g, M4_d, C21_t, 2)
insert_edge!(g, C21_t, C21_d)
end

# C22 = M1 - M2 + M3 + M6
local C22_d::DataTaskNode
begin
s1_t = insert_node!(g, ComputeTask_Sub()) # M1 - M2
s2_t = insert_node!(g, ComputeTask_Add()) # M3 + M6
C22_t = insert_node!(g, ComputeTask_Add()) # s1 + s2
insert_edge!(g, M1_d, s1_t, 1) # +M1
insert_edge!(g, M2_d, s1_t, 2) # -M2
insert_edge!(g, M3_d, s2_t, 1) # +M3
insert_edge!(g, M6_d, s2_t, 2) # +M6

s1_d = insert_node!(g, DataTask(data_size_half))
s2_d = insert_node!(g, DataTask(data_size_half))
insert_edge!(g, s1_t, s1_d)
insert_edge!(g, s2_t, s2_d)

insert_edge!(g, s1_d, C22_t, 1)
insert_edge!(g, s2_d, C22_t, 2)

C22_d = insert_node!(g, DataTask(data_size_half))
insert_edge!(g, C22_t, C22_d)
end

# Assemble new Matrix C

assemble_t = insert_node!(g, ComputeTask_MultStrassen())
insert_edge!(g, C11_d, assemble_t, 1)
insert_edge!(g, C12_d, assemble_t, 2)
insert_edge!(g, C21_d, assemble_t, 3)
insert_edge!(g, C22_d, assemble_t, 4)

C_d = insert_node!(g, DataTask(data_size))
insert_edge!(g, assemble_t, C_d)

return C_d
end

function ComputableDAGs.graph(mm::MatrixMultiplication{T}) where {T}
g = DAG()

A_d = insert_node!(g, DataTask(mm.size^2 * sizeof(T)), "A")
B_d = insert_node!(g, DataTask(mm.size^2 * sizeof(T)), "B")

C_d = _dag_build_helper!(g, mm, A_d, B_d)

return g
end

function ComputableDAGs.input_expr(
::MatrixMultiplication, name::String, input_symbol::Symbol
)
if name == "A"
return Meta.parse("$input_symbol[1]")
elseif name == "B"
return Meta.parse("$input_symbol[2]")
else
throw("unknown data node name $name")
end
end

function ComputableDAGs.input_type(mm::MatrixMultiplication{T}) where {T}
return Tuple{Matrix{T},Matrix{T}}
end

export MatrixMultiplication

end
Binary file added test/strassen/strassen.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Loading