Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Float types as parametrics #14

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,6 @@ bench/cuda/matrix_vector_mult
tests/all
tests/cublas
*.ndb
nimsuggest.log
nimsuggest.log
bin/
.DS_Store
10 changes: 2 additions & 8 deletions linalg/private/ops.nim
Original file line number Diff line number Diff line change
Expand Up @@ -166,10 +166,7 @@ template maxIndexPrivate(N, v: untyped): auto =
m = val
(j, m)

proc maxIndex*[N: static[int]](v: Vector32[N]): tuple[i: int, val: float32] =
maxIndexPrivate(N, v)

proc maxIndex*[N: static[int]](v: Vector64[N]): tuple[i: int, val: float64] =
proc maxIndex*[N: static[int], T](v: Vector[N, T]): tuple[i: int, val: T] =
maxIndexPrivate(N, v)

proc maxIndex*(v: DVector32): tuple[i: int, val: float32] =
Expand All @@ -190,10 +187,7 @@ template minIndexPrivate(N, v: untyped): auto =
m = val
return (j, m)

proc minIndex*[N: static[int]](v: Vector32[N]): tuple[i: int, val: float32] =
minIndexPrivate(N, v)

proc minIndex*[N: static[int]](v: Vector64[N]): tuple[i: int, val: float64] =
proc minIndex*[N: static[int], T](v: Vector[N, T]): tuple[i: int, val: T] =
minIndexPrivate(N, v)

proc minIndex*(v: DVector32): tuple[i: int, val: float32] =
Expand Down
42 changes: 19 additions & 23 deletions linalg/private/types.nim
Original file line number Diff line number Diff line change
Expand Up @@ -13,35 +13,31 @@
# limitations under the License.

type
Vector32*[N: static[int]] = ref array[N, float32]
Vector64*[N: static[int]] = ref array[N, float64]
Matrix32*[M, N: static[int]] = object
Vector*[N: static[int], T: SomeReal] = ref array[N, T]
Matrix*[M, N: static[int], T: SomeReal] = object
order: OrderType
data: ref array[N * M, float32]
Matrix64*[M, N: static[int]] = object
order: OrderType
data: ref array[M * N, float64]
DVector32* = seq[float32]
DVector64* = seq[float64]
DMatrix32* = ref object
order: OrderType
M, N: int
data: seq[float32]
DMatrix64* = ref object
data: ref array[N * M, T]
DVector*[T: SomeReal] = seq[T]
DMatrix*[T: SomeReal] = ref object
order: OrderType
M, N: int
data: seq[float64]
AnyVector = Vector32 or Vector64 or DVector32 or DVector64
AnyMatrix = Matrix32 or Matrix64 or DMatrix32 or DMatrix64
data: seq[T]

Vector32*[N: static[int]] = Vector[N, float32]
Vector64*[N: static[int]] = Vector[N, float64]
Matrix32*[M, N: static[int]] = Matrix[M, N, float32]
Matrix64*[M, N: static[int]] = Matrix[M, N, float64]
DVector32* = DVector[float32]
DVector64* = DVector[float64]
DMatrix32* = DMatrix[float32]
DMatrix64* = DMatrix[float64]
AnyVector = Vector32 or Vector64 or DVector32 or DVector64 or Vector
AnyMatrix = Matrix32 or Matrix64 or DMatrix32 or DMatrix64 or Matrix

# Float pointers
template fp(v: Vector32): ptr float32 = cast[ptr float32](addr(v[]))

template fp(v: Vector64): ptr float64 = cast[ptr float64](addr(v[]))

template fp(m: Matrix32): ptr float32 = cast[ptr float32](addr(m.data[]))
template fp[N,T](v: Vector[N,T]): ptr T = cast[ptr T](addr(v[]))

template fp(m: Matrix64): ptr float64 = cast[ptr float64](addr(m.data[]))
template fp[M,N,T](m: Matrix[M,N,T]): ptr T = cast[ptr T](addr(m.data[]))

template fp(v: DVector32): ptr float32 = cast[ptr float32](unsafeAddr(v[0]))

Expand Down