Skip to content

Commit 9752e19

Browse files
authored
Further improve match statement narrowing against unions (#20744)
I realised there is a simpler and more complete approach than the one taken in #19600. This adds the new "Step 2" to the original code. Best way to review the diff is probably to check it out and review it squashed with #19600 , but really the net logic change is just "Step 2" Fixes #15190 Fixes #17549 Fixes #17600 Mostly fixes #18039 Helps with things in #19081
1 parent fbd4cb6 commit 9752e19

File tree

2 files changed

+100
-106
lines changed

2 files changed

+100
-106
lines changed

mypy/checkpattern.py

Lines changed: 76 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
from __future__ import annotations
44

5-
import itertools
65
from collections import defaultdict
76
from typing import Final, NamedTuple
87

@@ -233,11 +232,12 @@ def visit_singleton_pattern(self, o: SingletonPattern) -> PatternType:
233232

234233
def visit_sequence_pattern(self, o: SequencePattern) -> PatternType:
235234
#
236-
# check for existence of a starred pattern
235+
# Step 1. Check for existence of a starred pattern
237236
#
238237
current_type = get_proper_type(self.type_context[-1])
239238
if not self.can_match_sequence(current_type):
240239
return self.early_non_match()
240+
241241
star_positions = [i for i, p in enumerate(o.patterns) if isinstance(p, StarredPattern)]
242242
star_position: int | None = None
243243
if len(star_positions) == 1:
@@ -248,98 +248,88 @@ def visit_sequence_pattern(self, o: SequencePattern) -> PatternType:
248248
if star_position is not None:
249249
required_patterns -= 1
250250

251-
# 1. Go through all possible types and filter to only those which are sequences that
252-
# could match that number of items
253-
# 2. If there is exactly one tuple left with an unpack, then use that type
254-
# and the unpack index
255-
# 3. Otherwise, take the product of the item types so that each index can have a
256-
# unique type. For tuples with unpack fallback to merging all of their types
257-
# for each index, since we can't handle multiple unpacked items at once yet.
258-
259-
# Whether we have encountered a type that we don't know how to handle in the union
260-
unknown_type = False
261-
# A list of types that could match any of the items in the sequence.
262-
sequence_types: list[Type] = []
263-
# A list of tuple types that could match the sequence, per index
264-
tuple_types: list[list[Type]] = []
265-
# A list of all the unpack tuple types that we encountered, each containing the
266-
# tuple type, unpack index, and union index
267-
unpack_tuple_types: list[tuple[TupleType, int, int]] = []
268-
for i, t in enumerate(
269-
current_type.items if isinstance(current_type, UnionType) else [current_type]
270-
):
271-
t = get_proper_type(t)
272-
if isinstance(t, TupleType):
273-
tuple_items = list(t.items)
274-
unpack_index = find_unpack_in_list(tuple_items)
275-
if unpack_index is None:
276-
size_diff = len(tuple_items) - required_patterns
277-
if size_diff < 0:
278-
continue
279-
if size_diff > 0 and star_position is None:
280-
continue
281-
if not size_diff and star_position is not None:
282-
# Above we subtract from required_patterns if star_position is not None
283-
tuple_items.append(UninhabitedType())
284-
tuple_types.append(tuple_items)
285-
else:
286-
normalized_inner_types = []
287-
for it in tuple_items:
288-
# Unfortunately, it is not possible to "split" the TypeVarTuple
289-
# into individual items, so we just use its upper bound for the whole
290-
# analysis instead.
291-
if isinstance(it, UnpackType) and isinstance(it.type, TypeVarTupleType):
292-
it = UnpackType(it.type.upper_bound)
293-
normalized_inner_types.append(it)
294-
if (
295-
len(normalized_inner_types) - 1 > required_patterns
296-
and star_position is None
297-
):
298-
continue
299-
t = t.copy_modified(items=normalized_inner_types)
300-
unpack_tuple_types.append((t, unpack_index, i))
301-
# In case we have multiple unpacks we want to combine them all, so add
302-
# the combined tuple type to the sequence types.
303-
sequence_types.append(self.chk.iterable_item_type(tuple_fallback(t), o))
304-
elif isinstance(t, AnyType):
305-
sequence_types.append(AnyType(TypeOfAny.from_another_any, t))
306-
elif self.chk.type_is_iterable(t) and isinstance(t, Instance):
307-
sequence_types.append(self.chk.iterable_item_type(t, o))
251+
#
252+
# Step 2. If we have a union, recurse and return the combined result
253+
#
254+
if isinstance(current_type, UnionType):
255+
match_types: list[Type] = []
256+
rest_types: list[Type] = []
257+
captures_list: dict[Expression, list[Type]] = {}
258+
259+
if star_position is not None:
260+
star_pattern = o.patterns[star_position]
261+
assert isinstance(star_pattern, StarredPattern)
262+
star_expr = star_pattern.capture
308263
else:
309-
unknown_type = True
310-
311-
inner_types: list[Type]
264+
star_expr = None
265+
266+
for t in current_type.items:
267+
match_type, rest_type, captures = self.accept(o, t)
268+
match_types.append(match_type)
269+
rest_types.append(rest_type)
270+
if not is_uninhabited(match_type):
271+
for expr, typ in captures.items():
272+
p_typ = get_proper_type(typ)
273+
if expr not in captures_list:
274+
captures_list[expr] = []
275+
# Avoid adding in a list[Never] for empty list captures
276+
if (
277+
expr == star_expr
278+
and isinstance(p_typ, Instance)
279+
and p_typ.type.fullname == "builtins.list"
280+
and is_uninhabited(p_typ.args[0])
281+
):
282+
continue
283+
captures_list[expr].append(typ)
284+
285+
return PatternType(
286+
make_simplified_union(match_types),
287+
make_simplified_union(rest_types),
288+
{expr: make_simplified_union(types) for expr, types in captures_list.items()},
289+
)
312290

313-
# If we only got one unpack tuple type, we can use that
291+
#
292+
# Step 3. Get inner types of original type
293+
#
314294
unpack_index = None
315-
if len(unpack_tuple_types) == 1 and len(sequence_types) == 1 and not tuple_types:
316-
update_tuple_type, unpack_index, union_index = unpack_tuple_types[0]
317-
inner_types = update_tuple_type.items
318-
if isinstance(current_type, UnionType):
319-
union_items = list(current_type.items)
320-
union_items[union_index] = update_tuple_type
321-
current_type = get_proper_type(UnionType.make_union(items=union_items))
295+
if isinstance(current_type, TupleType):
296+
inner_types = current_type.items
297+
unpack_index = find_unpack_in_list(inner_types)
298+
if unpack_index is None:
299+
size_diff = len(inner_types) - required_patterns
300+
if size_diff < 0:
301+
return self.early_non_match()
302+
elif size_diff > 0 and star_position is None:
303+
return self.early_non_match()
322304
else:
323-
current_type = update_tuple_type
324-
# If we only got tuples we can't match, then exit early
325-
elif not tuple_types and not sequence_types and not unknown_type:
326-
return self.early_non_match()
327-
elif tuple_types:
328-
inner_types = [
329-
make_simplified_union([*sequence_types, *[t for t in group if t is not None]])
330-
for group in itertools.zip_longest(*tuple_types)
331-
]
332-
elif sequence_types:
333-
inner_types = [make_simplified_union(sequence_types)] * len(o.patterns)
305+
normalized_inner_types = []
306+
for it in inner_types:
307+
# Unfortunately, it is not possible to "split" the TypeVarTuple
308+
# into individual items, so we just use its upper bound for the whole
309+
# analysis instead.
310+
if isinstance(it, UnpackType) and isinstance(it.type, TypeVarTupleType):
311+
it = UnpackType(it.type.upper_bound)
312+
normalized_inner_types.append(it)
313+
inner_types = normalized_inner_types
314+
current_type = current_type.copy_modified(items=normalized_inner_types)
315+
if len(inner_types) - 1 > required_patterns and star_position is None:
316+
return self.early_non_match()
317+
elif isinstance(current_type, AnyType):
318+
inner_type: Type = AnyType(TypeOfAny.from_another_any, current_type)
319+
inner_types = [inner_type] * len(o.patterns)
320+
elif isinstance(current_type, Instance) and self.chk.type_is_iterable(current_type):
321+
inner_type = self.chk.iterable_item_type(current_type, o)
322+
inner_types = [inner_type] * len(o.patterns)
334323
else:
335-
inner_types = [self.chk.named_type("builtins.object")] * len(o.patterns)
324+
inner_type = self.chk.named_type("builtins.object")
325+
inner_types = [inner_type] * len(o.patterns)
336326

337327
#
338-
# match inner patterns
328+
# Step 4. Match inner patterns
339329
#
340330
contracted_new_inner_types: list[Type] = []
341331
contracted_rest_inner_types: list[Type] = []
342-
captures: dict[Expression, Type] = {}
332+
captures = {} # dict[Expression, Type]
343333

344334
contracted_inner_types = self.contract_starred_pattern_types(
345335
inner_types, star_position, required_patterns
@@ -359,10 +349,10 @@ def visit_sequence_pattern(self, o: SequencePattern) -> PatternType:
359349
)
360350

361351
#
362-
# Calculate new type
352+
# Step 5. Calculate new type
363353
#
364354
new_type: Type
365-
rest_type: Type = current_type
355+
rest_type = current_type
366356
if isinstance(current_type, TupleType) and unpack_index is None:
367357
narrowed_inner_types = []
368358
inner_rest_types = []

test-data/unit/check-python310.test

Lines changed: 24 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1773,42 +1773,38 @@ match m4:
17731773
reveal_type(a4) # N: Revealed type is "builtins.str"
17741774
reveal_type(b4) # N: Revealed type is "builtins.str"
17751775

1776-
# properly handles unpack when all other patterns are not sequences
17771776
m5: tuple[int, Unpack[tuple[float, ...]]] | None
17781777
match m5:
17791778
case (a5, b5):
17801779
reveal_type(a5) # N: Revealed type is "builtins.int"
17811780
reveal_type(b5) # N: Revealed type is "builtins.float"
17821781

1783-
# currently can't handle combing unpacking with other sequence patterns, if this happens revert to worst case
1784-
# of combing all types
17851782
m6: tuple[int, Unpack[tuple[float, ...]]] | list[str]
17861783
match m6:
17871784
case (a6, b6):
1788-
reveal_type(a6) # N: Revealed type is "builtins.int | builtins.float | builtins.str"
1789-
reveal_type(b6) # N: Revealed type is "builtins.int | builtins.float | builtins.str"
1785+
reveal_type(a6) # N: Revealed type is "builtins.int | builtins.str"
1786+
reveal_type(b6) # N: Revealed type is "builtins.float | builtins.str"
17901787

1791-
# but do still separate types from non unpacked types
17921788
m7: tuple[int, Unpack[tuple[float, ...]]] | tuple[str, str]
17931789
match m7:
17941790
case (a7, b7, *rest7):
1795-
reveal_type(a7) # N: Revealed type is "builtins.int | builtins.float | builtins.str"
1796-
reveal_type(b7) # N: Revealed type is "builtins.int | builtins.float | builtins.str"
1797-
reveal_type(rest7) # N: Revealed type is "builtins.list[builtins.int | builtins.float]"
1791+
reveal_type(a7) # N: Revealed type is "builtins.int | builtins.str"
1792+
reveal_type(b7) # N: Revealed type is "builtins.float | builtins.str"
1793+
reveal_type(rest7) # N: Revealed type is "builtins.list[builtins.float]"
17981794

17991795
# verify that if we are unpacking, it will get the type of the sequence if the tuple is too short
18001796
m8: tuple[int, str] | list[float]
18011797
match m8:
18021798
case (a8, b8, *rest8):
1803-
reveal_type(a8) # N: Revealed type is "builtins.float | builtins.int"
1804-
reveal_type(b8) # N: Revealed type is "builtins.float | builtins.str"
1799+
reveal_type(a8) # N: Revealed type is "builtins.int | builtins.float"
1800+
reveal_type(b8) # N: Revealed type is "builtins.str | builtins.float"
18051801
reveal_type(rest8) # N: Revealed type is "builtins.list[builtins.float]"
18061802

18071803
m9: tuple[str, str, int] | tuple[str, str]
18081804
match m9:
18091805
case (a9, *rest9):
18101806
reveal_type(a9) # N: Revealed type is "builtins.str"
1811-
reveal_type(rest9) # N: Revealed type is "builtins.list[builtins.str | builtins.int]"
1807+
reveal_type(rest9) # N: Revealed type is "builtins.list[builtins.str | builtins.int] | builtins.list[builtins.str]"
18121808

18131809
[builtins fixtures/tuple.pyi]
18141810

@@ -2261,15 +2257,23 @@ match foo:
22612257
reveal_type(x) # N: Revealed type is "builtins.int"
22622258
[builtins fixtures/tuple.pyi]
22632259

2264-
[case testMatchUnionTwoTuplesNoCrash]
2265-
var: tuple[int, int] | tuple[str, str]
2260+
[case testMatchUnionTwoTuples]
2261+
# flags: --strict-equality --warn-unreachable
2262+
2263+
def main(var: tuple[int, int] | tuple[str, str]):
2264+
match var:
2265+
case (42, a):
2266+
reveal_type(a) # N: Revealed type is "builtins.int"
2267+
case ("yes", b):
2268+
reveal_type(b) # N: Revealed type is "builtins.str"
22662269

2267-
# TODO: we can infer better here.
2268-
match var:
2269-
case (42, a):
2270-
reveal_type(a) # N: Revealed type is "builtins.int | builtins.str"
2271-
case ("yes", b):
2272-
reveal_type(b) # N: Revealed type is "builtins.int | builtins.str"
2270+
2271+
def main2(var: tuple[int, int] | tuple[str, str] | tuple[str, int]):
2272+
match var:
2273+
case (42, a):
2274+
reveal_type(a) # N: Revealed type is "builtins.int"
2275+
case ("yes", b):
2276+
reveal_type(b) # N: Revealed type is "builtins.str | builtins.int"
22732277
[builtins fixtures/tuple.pyi]
22742278

22752279
[case testMatchNamedAndKeywordsAreTheSame]

0 commit comments

Comments
 (0)