Skip to content

Commit

Permalink
[3.12] gh-112281: Allow Union with unhashable Annotated metadata (G…
Browse files Browse the repository at this point in the history
…H-112283) (#116213)

Co-authored-by: Nikita Sobolev <[email protected]>
Co-authored-by: Alex Waygood <[email protected]>
  • Loading branch information
3 people authored Mar 1, 2024
1 parent 16be4a3 commit 90f75e1
Show file tree
Hide file tree
Showing 4 changed files with 156 additions and 18 deletions.
20 changes: 20 additions & 0 deletions Lib/test/test_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -709,6 +709,26 @@ def test_hash(self):
self.assertEqual(hash(int | str), hash(str | int))
self.assertEqual(hash(int | str), hash(typing.Union[int, str]))

def test_union_of_unhashable(self):
class UnhashableMeta(type):
__hash__ = None

class A(metaclass=UnhashableMeta): ...
class B(metaclass=UnhashableMeta): ...

self.assertEqual((A | B).__args__, (A, B))
union1 = A | B
with self.assertRaises(TypeError):
hash(union1)

union2 = int | B
with self.assertRaises(TypeError):
hash(union2)

union3 = A | int
with self.assertRaises(TypeError):
hash(union3)

def test_instancecheck_and_subclasscheck(self):
for x in (int | str, typing.Union[int, str]):
with self.subTest(x=x):
Expand Down
107 changes: 103 additions & 4 deletions Lib/test/test_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@
import collections
import collections.abc
from collections import defaultdict
from functools import lru_cache, wraps
from functools import lru_cache, wraps, reduce
import gc
import inspect
import itertools
import operator
import pickle
import re
import sys
Expand Down Expand Up @@ -1770,6 +1771,26 @@ def test_union_union(self):
v = Union[u, Employee]
self.assertEqual(v, Union[int, float, Employee])

def test_union_of_unhashable(self):
class UnhashableMeta(type):
__hash__ = None

class A(metaclass=UnhashableMeta): ...
class B(metaclass=UnhashableMeta): ...

self.assertEqual(Union[A, B].__args__, (A, B))
union1 = Union[A, B]
with self.assertRaises(TypeError):
hash(union1)

union2 = Union[int, B]
with self.assertRaises(TypeError):
hash(union2)

union3 = Union[A, int]
with self.assertRaises(TypeError):
hash(union3)

def test_repr(self):
self.assertEqual(repr(Union), 'typing.Union')
u = Union[Employee, int]
Expand Down Expand Up @@ -5295,10 +5316,8 @@ def some(self):
self.assertFalse(hasattr(WithOverride.some, "__override__"))

def test_multiple_decorators(self):
import functools

def with_wraps(f): # similar to `lru_cache` definition
@functools.wraps(f)
@wraps(f)
def wrapper(*args, **kwargs):
return f(*args, **kwargs)
return wrapper
Expand Down Expand Up @@ -8183,6 +8202,76 @@ def test_flatten(self):
self.assertEqual(A.__metadata__, (4, 5))
self.assertEqual(A.__origin__, int)

def test_deduplicate_from_union(self):
# Regular:
self.assertEqual(get_args(Annotated[int, 1] | int),
(Annotated[int, 1], int))
self.assertEqual(get_args(Union[Annotated[int, 1], int]),
(Annotated[int, 1], int))
self.assertEqual(get_args(Annotated[int, 1] | Annotated[int, 2] | int),
(Annotated[int, 1], Annotated[int, 2], int))
self.assertEqual(get_args(Union[Annotated[int, 1], Annotated[int, 2], int]),
(Annotated[int, 1], Annotated[int, 2], int))
self.assertEqual(get_args(Annotated[int, 1] | Annotated[str, 1] | int),
(Annotated[int, 1], Annotated[str, 1], int))
self.assertEqual(get_args(Union[Annotated[int, 1], Annotated[str, 1], int]),
(Annotated[int, 1], Annotated[str, 1], int))

# Duplicates:
self.assertEqual(Annotated[int, 1] | Annotated[int, 1] | int,
Annotated[int, 1] | int)
self.assertEqual(Union[Annotated[int, 1], Annotated[int, 1], int],
Union[Annotated[int, 1], int])

# Unhashable metadata:
self.assertEqual(get_args(str | Annotated[int, {}] | Annotated[int, set()] | int),
(str, Annotated[int, {}], Annotated[int, set()], int))
self.assertEqual(get_args(Union[str, Annotated[int, {}], Annotated[int, set()], int]),
(str, Annotated[int, {}], Annotated[int, set()], int))
self.assertEqual(get_args(str | Annotated[int, {}] | Annotated[str, {}] | int),
(str, Annotated[int, {}], Annotated[str, {}], int))
self.assertEqual(get_args(Union[str, Annotated[int, {}], Annotated[str, {}], int]),
(str, Annotated[int, {}], Annotated[str, {}], int))

self.assertEqual(get_args(Annotated[int, 1] | str | Annotated[str, {}] | int),
(Annotated[int, 1], str, Annotated[str, {}], int))
self.assertEqual(get_args(Union[Annotated[int, 1], str, Annotated[str, {}], int]),
(Annotated[int, 1], str, Annotated[str, {}], int))

import dataclasses
@dataclasses.dataclass
class ValueRange:
lo: int
hi: int
v = ValueRange(1, 2)
self.assertEqual(get_args(Annotated[int, v] | None),
(Annotated[int, v], types.NoneType))
self.assertEqual(get_args(Union[Annotated[int, v], None]),
(Annotated[int, v], types.NoneType))
self.assertEqual(get_args(Optional[Annotated[int, v]]),
(Annotated[int, v], types.NoneType))

# Unhashable metadata duplicated:
self.assertEqual(Annotated[int, {}] | Annotated[int, {}] | int,
Annotated[int, {}] | int)
self.assertEqual(Annotated[int, {}] | Annotated[int, {}] | int,
int | Annotated[int, {}])
self.assertEqual(Union[Annotated[int, {}], Annotated[int, {}], int],
Union[Annotated[int, {}], int])
self.assertEqual(Union[Annotated[int, {}], Annotated[int, {}], int],
Union[int, Annotated[int, {}]])

def test_order_in_union(self):
expr1 = Annotated[int, 1] | str | Annotated[str, {}] | int
for args in itertools.permutations(get_args(expr1)):
with self.subTest(args=args):
self.assertEqual(expr1, reduce(operator.or_, args))

expr2 = Union[Annotated[int, 1], str, Annotated[str, {}], int]
for args in itertools.permutations(get_args(expr2)):
with self.subTest(args=args):
self.assertEqual(expr2, Union[args])

def test_specialize(self):
L = Annotated[List[T], "my decoration"]
LI = Annotated[List[int], "my decoration"]
Expand All @@ -8203,6 +8292,16 @@ def test_hash_eq(self):
{Annotated[int, 4, 5], Annotated[int, 4, 5], Annotated[T, 4, 5]},
{Annotated[int, 4, 5], Annotated[T, 4, 5]}
)
# Unhashable `metadata` raises `TypeError`:
a1 = Annotated[int, []]
with self.assertRaises(TypeError):
hash(a1)

class A:
__hash__ = None
a2 = Annotated[int, A()]
with self.assertRaises(TypeError):
hash(a2)

def test_instantiate(self):
class C:
Expand Down
45 changes: 31 additions & 14 deletions Lib/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,19 +314,33 @@ def _unpack_args(args):
newargs.append(arg)
return newargs

def _deduplicate(params):
def _deduplicate(params, *, unhashable_fallback=False):
# Weed out strict duplicates, preserving the first of each occurrence.
all_params = set(params)
if len(all_params) < len(params):
new_params = []
for t in params:
if t in all_params:
new_params.append(t)
all_params.remove(t)
params = new_params
assert not all_params, all_params
return params

try:
return dict.fromkeys(params)
except TypeError:
if not unhashable_fallback:
raise
# Happens for cases like `Annotated[dict, {'x': IntValidator()}]`
return _deduplicate_unhashable(params)

def _deduplicate_unhashable(unhashable_params):
new_unhashable = []
for t in unhashable_params:
if t not in new_unhashable:
new_unhashable.append(t)
return new_unhashable

def _compare_args_orderless(first_args, second_args):
first_unhashable = _deduplicate_unhashable(first_args)
second_unhashable = _deduplicate_unhashable(second_args)
t = list(second_unhashable)
try:
for elem in first_unhashable:
t.remove(elem)
except ValueError:
return False
return not t

def _remove_dups_flatten(parameters):
"""Internal helper for Union creation and substitution.
Expand All @@ -341,7 +355,7 @@ def _remove_dups_flatten(parameters):
else:
params.append(p)

return tuple(_deduplicate(params))
return tuple(_deduplicate(params, unhashable_fallback=True))


def _flatten_literal_params(parameters):
Expand Down Expand Up @@ -1548,7 +1562,10 @@ def copy_with(self, params):
def __eq__(self, other):
if not isinstance(other, (_UnionGenericAlias, types.UnionType)):
return NotImplemented
return set(self.__args__) == set(other.__args__)
try: # fast path
return set(self.__args__) == set(other.__args__)
except TypeError: # not hashable, slow path
return _compare_args_orderless(self.__args__, other.__args__)

def __hash__(self):
return hash(frozenset(self.__args__))
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Allow creating :ref:`union of types<types-union>` for
:class:`typing.Annotated` with unhashable metadata.

0 comments on commit 90f75e1

Please sign in to comment.