22
33from __future__ import annotations
44
5- import itertools
65from collections import defaultdict
76from typing import Final , NamedTuple
87
@@ -233,11 +232,12 @@ def visit_singleton_pattern(self, o: SingletonPattern) -> PatternType:
233232
234233 def visit_sequence_pattern (self , o : SequencePattern ) -> PatternType :
235234 #
236- # check for existence of a starred pattern
235+ # Step 1. Check for existence of a starred pattern
237236 #
238237 current_type = get_proper_type (self .type_context [- 1 ])
239238 if not self .can_match_sequence (current_type ):
240239 return self .early_non_match ()
240+
241241 star_positions = [i for i , p in enumerate (o .patterns ) if isinstance (p , StarredPattern )]
242242 star_position : int | None = None
243243 if len (star_positions ) == 1 :
@@ -248,98 +248,88 @@ def visit_sequence_pattern(self, o: SequencePattern) -> PatternType:
248248 if star_position is not None :
249249 required_patterns -= 1
250250
251- # 1. Go through all possible types and filter to only those which are sequences that
252- # could match that number of items
253- # 2. If there is exactly one tuple left with an unpack, then use that type
254- # and the unpack index
255- # 3. Otherwise, take the product of the item types so that each index can have a
256- # unique type. For tuples with unpack fallback to merging all of their types
257- # for each index, since we can't handle multiple unpacked items at once yet.
258-
259- # Whether we have encountered a type that we don't know how to handle in the union
260- unknown_type = False
261- # A list of types that could match any of the items in the sequence.
262- sequence_types : list [Type ] = []
263- # A list of tuple types that could match the sequence, per index
264- tuple_types : list [list [Type ]] = []
265- # A list of all the unpack tuple types that we encountered, each containing the
266- # tuple type, unpack index, and union index
267- unpack_tuple_types : list [tuple [TupleType , int , int ]] = []
268- for i , t in enumerate (
269- current_type .items if isinstance (current_type , UnionType ) else [current_type ]
270- ):
271- t = get_proper_type (t )
272- if isinstance (t , TupleType ):
273- tuple_items = list (t .items )
274- unpack_index = find_unpack_in_list (tuple_items )
275- if unpack_index is None :
276- size_diff = len (tuple_items ) - required_patterns
277- if size_diff < 0 :
278- continue
279- if size_diff > 0 and star_position is None :
280- continue
281- if not size_diff and star_position is not None :
282- # Above we subtract from required_patterns if star_position is not None
283- tuple_items .append (UninhabitedType ())
284- tuple_types .append (tuple_items )
285- else :
286- normalized_inner_types = []
287- for it in tuple_items :
288- # Unfortunately, it is not possible to "split" the TypeVarTuple
289- # into individual items, so we just use its upper bound for the whole
290- # analysis instead.
291- if isinstance (it , UnpackType ) and isinstance (it .type , TypeVarTupleType ):
292- it = UnpackType (it .type .upper_bound )
293- normalized_inner_types .append (it )
294- if (
295- len (normalized_inner_types ) - 1 > required_patterns
296- and star_position is None
297- ):
298- continue
299- t = t .copy_modified (items = normalized_inner_types )
300- unpack_tuple_types .append ((t , unpack_index , i ))
301- # In case we have multiple unpacks we want to combine them all, so add
302- # the combined tuple type to the sequence types.
303- sequence_types .append (self .chk .iterable_item_type (tuple_fallback (t ), o ))
304- elif isinstance (t , AnyType ):
305- sequence_types .append (AnyType (TypeOfAny .from_another_any , t ))
306- elif self .chk .type_is_iterable (t ) and isinstance (t , Instance ):
307- sequence_types .append (self .chk .iterable_item_type (t , o ))
251+ #
252+ # Step 2. If we have a union, recurse and return the combined result
253+ #
254+ if isinstance (current_type , UnionType ):
255+ match_types : list [Type ] = []
256+ rest_types : list [Type ] = []
257+ captures_list : dict [Expression , list [Type ]] = {}
258+
259+ if star_position is not None :
260+ star_pattern = o .patterns [star_position ]
261+ assert isinstance (star_pattern , StarredPattern )
262+ star_expr = star_pattern .capture
308263 else :
309- unknown_type = True
310-
311- inner_types : list [Type ]
264+ star_expr = None
265+
266+ for t in current_type .items :
267+ match_type , rest_type , captures = self .accept (o , t )
268+ match_types .append (match_type )
269+ rest_types .append (rest_type )
270+ if not is_uninhabited (match_type ):
271+ for expr , typ in captures .items ():
272+ p_typ = get_proper_type (typ )
273+ if expr not in captures_list :
274+ captures_list [expr ] = []
275+ # Avoid adding in a list[Never] for empty list captures
276+ if (
277+ expr == star_expr
278+ and isinstance (p_typ , Instance )
279+ and p_typ .type .fullname == "builtins.list"
280+ and is_uninhabited (p_typ .args [0 ])
281+ ):
282+ continue
283+ captures_list [expr ].append (typ )
284+
285+ return PatternType (
286+ make_simplified_union (match_types ),
287+ make_simplified_union (rest_types ),
288+ {expr : make_simplified_union (types ) for expr , types in captures_list .items ()},
289+ )
312290
313- # If we only got one unpack tuple type, we can use that
291+ #
292+ # Step 3. Get inner types of original type
293+ #
314294 unpack_index = None
315- if len (unpack_tuple_types ) == 1 and len (sequence_types ) == 1 and not tuple_types :
316- update_tuple_type , unpack_index , union_index = unpack_tuple_types [0 ]
317- inner_types = update_tuple_type .items
318- if isinstance (current_type , UnionType ):
319- union_items = list (current_type .items )
320- union_items [union_index ] = update_tuple_type
321- current_type = get_proper_type (UnionType .make_union (items = union_items ))
295+ if isinstance (current_type , TupleType ):
296+ inner_types = current_type .items
297+ unpack_index = find_unpack_in_list (inner_types )
298+ if unpack_index is None :
299+ size_diff = len (inner_types ) - required_patterns
300+ if size_diff < 0 :
301+ return self .early_non_match ()
302+ elif size_diff > 0 and star_position is None :
303+ return self .early_non_match ()
322304 else :
323- current_type = update_tuple_type
324- # If we only got tuples we can't match, then exit early
325- elif not tuple_types and not sequence_types and not unknown_type :
326- return self .early_non_match ()
327- elif tuple_types :
328- inner_types = [
329- make_simplified_union ([* sequence_types , * [t for t in group if t is not None ]])
330- for group in itertools .zip_longest (* tuple_types )
331- ]
332- elif sequence_types :
333- inner_types = [make_simplified_union (sequence_types )] * len (o .patterns )
305+ normalized_inner_types = []
306+ for it in inner_types :
307+ # Unfortunately, it is not possible to "split" the TypeVarTuple
308+ # into individual items, so we just use its upper bound for the whole
309+ # analysis instead.
310+ if isinstance (it , UnpackType ) and isinstance (it .type , TypeVarTupleType ):
311+ it = UnpackType (it .type .upper_bound )
312+ normalized_inner_types .append (it )
313+ inner_types = normalized_inner_types
314+ current_type = current_type .copy_modified (items = normalized_inner_types )
315+ if len (inner_types ) - 1 > required_patterns and star_position is None :
316+ return self .early_non_match ()
317+ elif isinstance (current_type , AnyType ):
318+ inner_type : Type = AnyType (TypeOfAny .from_another_any , current_type )
319+ inner_types = [inner_type ] * len (o .patterns )
320+ elif isinstance (current_type , Instance ) and self .chk .type_is_iterable (current_type ):
321+ inner_type = self .chk .iterable_item_type (current_type , o )
322+ inner_types = [inner_type ] * len (o .patterns )
334323 else :
335- inner_types = [self .chk .named_type ("builtins.object" )] * len (o .patterns )
324+ inner_type = self .chk .named_type ("builtins.object" )
325+ inner_types = [inner_type ] * len (o .patterns )
336326
337327 #
338- # match inner patterns
328+ # Step 4. Match inner patterns
339329 #
340330 contracted_new_inner_types : list [Type ] = []
341331 contracted_rest_inner_types : list [Type ] = []
342- captures : dict [Expression , Type ] = {}
332+ captures = {} # dict[Expression, Type]
343333
344334 contracted_inner_types = self .contract_starred_pattern_types (
345335 inner_types , star_position , required_patterns
@@ -359,10 +349,10 @@ def visit_sequence_pattern(self, o: SequencePattern) -> PatternType:
359349 )
360350
361351 #
362- # Calculate new type
352+ # Step 5. Calculate new type
363353 #
364354 new_type : Type
365- rest_type : Type = current_type
355+ rest_type = current_type
366356 if isinstance (current_type , TupleType ) and unpack_index is None :
367357 narrowed_inner_types = []
368358 inner_rest_types = []
0 commit comments