Skip to content

Commit

Permalink
Stop using pybind's implicit __hash__ (closes gh-102)
Browse files Browse the repository at this point in the history
  • Loading branch information
inducer committed Jan 1, 2023
1 parent c3eb182 commit aba6644
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 5 deletions.
16 changes: 15 additions & 1 deletion gen_wrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -1370,6 +1370,10 @@ def write_wrapper(outf, meth):

# {{{ exposer generator

def is_hash(meth):
return meth.name == "get_hash" and len(meth.args) == 1


def write_exposer(outf, meth, arg_names, doc_str):
func_name = f"isl::{meth.cls}_{meth.name}"
py_name = meth.name
Expand All @@ -1385,7 +1389,7 @@ def write_exposer(outf, meth, arg_names, doc_str):
if meth.name == "size" and len(meth.args) == 1:
py_name = "__len__"

if meth.name == "get_hash" and len(meth.args) == 1:
if is_hash(meth):
py_name = "__hash__"

extra_py_names = []
Expand Down Expand Up @@ -1509,6 +1513,16 @@ def gen_wrapper(include_dirs, include_barvinok=False, isl_version=None):
for cls in classes
for meth in fdata.classes_to_methods.get(cls, [])])

for cls in classes:
has_isl_hash = any(
is_hash(meth) for meth in fdata.classes_to_methods.get(cls, []))

if not has_isl_hash:
# pybind11's C++ object base class has an object identity
# __hash__ that everyone inherits automatically. We don't
# want that.
expf.write(f'wrap_{cls}.attr("__hash__") = py::none();\n')

expf.close()
wrapf.close()

Expand Down
30 changes: 26 additions & 4 deletions islpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,9 +313,6 @@ def generic_repr(self):
cls.__str__ = generic_str
cls.__repr__ = generic_repr

if not hasattr(cls, "__hash__"):
raise AssertionError(f"not hashable: {cls}")

# }}}

# {{{ Python set-like behavior
Expand Down Expand Up @@ -349,6 +346,13 @@ def obj_sub(self, other):

# {{{ Space

def space_hash(self):
return hash((type(self),
self.dim(dim_type.param),
self.dim(dim_type.in_),
self.dim(dim_type.out),
self.dim(dim_type.div)))

def space_get_id_dict(self, dimtype=None):
"""Return a dictionary mapping variable :class:`Id` instances to tuples
of (:class:`dim_type`, index).
Expand Down Expand Up @@ -446,6 +450,7 @@ def space_create_from_names(ctx, set=None, in_=None, out=None, params=()):

return result

Space.__hash__ = space_hash
Space.create_from_names = staticmethod(space_create_from_names)
Space.get_var_dict = space_get_var_dict
Space.get_id_dict = space_get_id_dict
Expand Down Expand Up @@ -908,6 +913,12 @@ def val_to_python(self):
# note: automatic upcasts for method arguments are provided through
# 'implicitly_convertible' on the C++ side of the wrapper.

def make_upcasting_hash(special_method, upcast_method):
def wrapper(basic_instance):
return hash((type(basic_instance), upcast_method(basic_instance)))

return wrapper

def make_new_upcast_wrapper(method, upcast):
# This function provides a scope in which method and upcast
# are not changed from one iteration of the enclosing for
Expand All @@ -916,7 +927,6 @@ def make_new_upcast_wrapper(method, upcast):
def wrapper(basic_instance, *args, **kwargs):
special_instance = upcast(basic_instance)
return method(special_instance, *args, **kwargs)

return wrapper

def make_existing_upcast_wrapper(basic_method, special_method, upcast):
Expand All @@ -938,6 +948,18 @@ def wrapper(basic_instance, *args, **kwargs):
def add_upcasts(basic_class, special_class, upcast_method):
from functools import update_wrapper

# {{{ implicitly upcast __hash__

# We don't use hasattr() here because in the C++ part of the wrapper
# we overwrite pybind's unwanted default __hash__ implementation
# with None.
if (getattr(basic_class, "__hash__", None) is None
and getattr(special_class, "__hash__", None) is not None):
wrapper = make_upcasting_hash(special_class.__hash__, upcast_method)
basic_class.__hash__ = update_wrapper(wrapper, basic_class.__hash__)

# }}}

def my_ismethod(class_, method_name):
if method_name.startswith("_"):
return False
Expand Down
38 changes: 38 additions & 0 deletions test/test_isl.py
Original file line number Diff line number Diff line change
Expand Up @@ -457,6 +457,44 @@ def test_sched_constraints_set_validity():
assert str(validity) == str(validity2)


def test_space_hash():
# Direct from-string Space constructors are recent and broke barvinok CI
# which involves older isl versions.
s1 = isl.Set("[n] -> {[i]: 1=1}").space
s2 = isl.Set("[n] -> {[j]}: 1=1").space
s3 = isl.Set("[n] -> {[j, k]: 1=1}").space
s4 = isl.Set("[m, n] -> {[j, k]: 1=1}").space

i_id = isl.Id("i", context=isl.DEFAULT_CONTEXT)
j_id = isl.Id("j", context=isl.DEFAULT_CONTEXT)

s1_i = s1.set_dim_id(isl.dim_type.set, 0, i_id)
s2_i = s2.set_dim_id(isl.dim_type.set, 0, j_id)

def assert_equal(a, b):
assert a == b
assert hash(a) == hash(b)

def assert_not_equal(a, b):
assert a != b
# not guaranteed, but highly likely
assert hash(a) != hash(b)

assert_equal(s1, s2)
assert_equal(s1_i, s2_i)
assert_not_equal(s3, s1)
assert_not_equal(s4, s1)


def test_basicset_hash():
# https://github.com/inducer/islpy/issues/102
# isl does not currently (2022-12-30) offer hashing for BasicSet.

a1 = isl.BasicSet("{[i]: 0<=i<512}")
a2 = isl.BasicSet("{[i]: 0<=i<512}")
assert hash(a1) == hash(a2)


if __name__ == "__main__":
import sys
if len(sys.argv) > 1:
Expand Down

0 comments on commit aba6644

Please sign in to comment.