Skip to content

Commit

Permalink
Merge pull request #30 from nschloe/tensordot
Browse files Browse the repository at this point in the history
tensordot
  • Loading branch information
nschloe authored Feb 2, 2022
2 parents 3dc43cb + 9242c5e commit 0a76896
Show file tree
Hide file tree
Showing 10 changed files with 56 additions and 65 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ repos:
- id: isort

- repo: https://github.com/psf/black
rev: 21.10b0
rev: 22.1.0
hooks:
- id: black
language_version: python3
Expand Down
2 changes: 1 addition & 1 deletion LICENSE
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
Copyright 2021 Nico Schlömer
Copyright 2021-2022 Nico Schlömer

Redistribution and use in source and binary forms, with or without modification, are
permitted provided that the following conditions are met:
Expand Down
17 changes: 4 additions & 13 deletions justfile
Original file line number Diff line number Diff line change
@@ -1,21 +1,12 @@
version := `python3 -c "from configparser import ConfigParser; p = ConfigParser(); p.read('setup.cfg'); print(p['metadata']['version'])"`
name := `python3 -c "from configparser import ConfigParser; p = ConfigParser(); p.read('setup.cfg'); print(p['metadata']['name'])"`

version := `python3 -c "from src.npx.__about__ import __version__; print(__version__)"`

default:
@echo "\"just publish\"?"

tag:
@if [ "$(git rev-parse --abbrev-ref HEAD)" != "main" ]; then exit 1; fi
curl -H "Authorization: token `cat ~/.github-access-token`" -d '{"tag_name": "v{{version}}"}' https://api.github.com/repos/nschloe/{{name}}/releases

upload: clean
publish:
@if [ "$(git rev-parse --abbrev-ref HEAD)" != "main" ]; then exit 1; fi
# https://stackoverflow.com/a/58756491/353337
python3 -m build --sdist --wheel .
twine upload dist/*

publish: tag upload
gh release create "v{{version}}"
flit publish

clean:
@find . | grep -E "(__pycache__|\.pyc|\.pyo$)" | xargs rm -rf
Expand Down
36 changes: 34 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,3 +1,35 @@
[build-system]
requires = ["setuptools>=42", "wheel"]
build-backend = "setuptools.build_meta"
requires = ["flit_core >=3.2,<4"]
build-backend = "flit_core.buildapi"

[tool.isort]
profile = "black"

[project]
name = "npx"
authors = [{name = "Nico Schlömer", email = "nico.schloemer@gmail.com"}]
description = "Some useful extensions for NumPy"
readme = "README.md"
license = {file = "LICENSE"}
classifiers = [
"Development Status :: 4 - Beta",
"Intended Audience :: Science/Research",
"License :: OSI Approved :: BSD License",
"Operating System :: OS Independent",
"Programming Language :: Python",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.7",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Topic :: Scientific/Engineering",
"Topic :: Utilities",
]
dynamic = ["version"]
requires-python = ">=3.7"
dependencies = ["numpy >= 1.20.0"]

[project.urls]
Code = "https://github.com/nschloe/npx"
Issues = "https://github.com/nschloe/npx/issues"
Funding = "https://github.com/sponsors/nschloe"
39 changes: 0 additions & 39 deletions setup.cfg

This file was deleted.

1 change: 1 addition & 0 deletions src/npx/__about__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
__version__ = "0.1.0"
2 changes: 2 additions & 0 deletions src/npx/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from .__about__ import __version__
from ._isin import isin_rows
from ._main import add_at, dot, outer, solve, subtract_at, sum_at
from ._mean import mean
from ._unique import unique, unique_rows

__all__ = [
"__version__",
"dot",
"outer",
"solve",
Expand Down
4 changes: 1 addition & 3 deletions src/npx/_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,7 @@ def dot(a: ArrayLike, b: ArrayLike) -> np.ndarray:
"""Take arrays `a` and `b` and form the dot product between the last axis of `a` and
the first of `b`.
"""
a = np.asarray(a)
b = np.asarray(b)
return np.dot(a, b.reshape(b.shape[0], -1)).reshape(a.shape[:-1] + b.shape[1:])
return np.tensordot(a, b, 1)


def outer(a: ArrayLike, b: ArrayLike) -> np.ndarray:
Expand Down
16 changes: 11 additions & 5 deletions src/npx/_mean.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@ def _logsumexp(x: ArrayLike):


def mean(x: ArrayLike, p: float = 1) -> np.ndarray:
"""Generalized mean.
See <https://github.com/numpy/numpy/issues/19341> for the numpy issue.
"""
x = np.asarray(x)

n = len(x)
Expand All @@ -24,15 +28,17 @@ def mean(x: ArrayLike, p: float = 1) -> np.ndarray:
if np.any(x < 0.0):
raise ValueError("p=0 only works with nonnegative x.")
return np.prod(np.power(x, 1 / n))
# alternative:
# return np.exp(np.mean(np.log(x)))
elif p == np.inf:
return np.max(np.abs(x))

if not isinstance(p, int) and np.any(x < 0.0):
raise ValueError("Non-integer p only work with nonnegative x.")

if np.all(x > 0.0):
# logsumexp trick to avoid overflow for large p
# only works for positive x though
return np.exp((_logsumexp(p * np.log(x)) - np.log(n)) / p)
else:
return (np.sum(x ** p) / n) ** (1.0 / p)

if not isinstance(p, (int, np.integer)):
raise ValueError(f"Non-integer p (={p}) only work with nonnegative x.")

return (np.sum(x**p) / n) ** (1.0 / p)
2 changes: 1 addition & 1 deletion tests/speedtest.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,6 @@ def npx_sum_at(data):
b = perfplot.bench(
setup=setup,
kernels=[np_add_at, npx_add_at, npx_sum_at],
n_range=[2 ** k for k in range(23)],
n_range=[2**k for k in range(23)],
)
b.save("perf-add-at.svg")

0 comments on commit 0a76896

Please sign in to comment.