
    &Vfh                       U d Z ddlmZ ddlZddlZddlZddlZddlmZm	Z	m
Z
mZ ddlmZ ddlmZ ddlmZmZmZmZmZmZmZ ddlmZ dd	lmZ dd
lmZmZmZm Z m!Z!m"Z"m#Z#m$Z$m%Z%m&Z& ddl'm(Z(m)Z)m*Z* erddl+Z+g dZ,ej-        dk    rddini Z. ej/        dddddde. G d d                      Z0[. G d d          Z1 e1            Z2de3d<    e            Z4de3d<   [1dd!Z5e	 ddd"dd%            Z6edd&            Z6	 ddd"dd)Z6dd*Z7dd.Z8dd2Z9dd6Z:dd:Z;dd>Z<dd@Z=ddDZ>ddEZ?ddGZ@ddKZAddMZBddNZCddQZDddTZEddXZFdd[ZGdd^ZHdd`ZIdddZJddfZK eLd           e0 eLd          e:e;          eM e0eMe<e=          eN e0eNe>e?          eO e0eOe@eA          e e0eeHeI          e e0eeBeC          e	 e0e	eDeE          e
 e0e
eFeG          e& e0e&eJeK          i	ZPdge3dh<   e2d"ddkZQeQe5_R        [Q G dl dm          ZS e6e2"           G dn doejT        ee                               ZU G dp dqe          ZV G dr dse          ZW G dt dueV          ZX G dv dweV          ZY G dx dyeV          ZZee!geeV         f         Z[i Z\dze3d{<   dd~Z] e]eMd             e]eNd             e]eOd             e]ed             e]e	d             e]e
d            e\jR        e]_R        dS )z#OpTree: Optimized PyTree Utilities.    )annotationsN)OrderedDictdefaultdictdeque
namedtuple)methodcaller)Lock)TYPE_CHECKINGAnyCallableIterable
NamedTupleSequenceoverload)Self)_C)
KTVTCustomTreeNodeFlattenFuncPyTreeTUnflattenFuncis_namedtuple_classis_structseq_class	structseq)safe_ziptotal_order_sortedunzip2)register_pytree_noderegister_pytree_node_classunregister_pytree_nodePartialregister_keypathsAttributeKeyPathEntryGetitemKeyPathEntry)   
   slotsT)initrepreqfrozenc                  <    e Zd ZU ded<   ded<   ded<   dZded	<   d
S )PyTreeNodeRegistryEntryzbuiltins.typetyper   flatten_funcr   unflatten_func str	namespaceN)__name__
__module____qualname____annotations__r5        L/var/www/html/software/conda/lib/python3.11/site-packages/optree/registry.pyr/   r/   ?   sG         !!!!Ir;   r/   c                      e Zd ZddZdS )GlobalNamespacereturnr4   c                    dS )Nz<GLOBAL NAMESPACE>r:   selfs    r<   __repr__zGlobalNamespace.__repr__L   s    ##r;   Nr?   r4   )r6   r7   r8   rC   r:   r;   r<   r>   r>   K   s(        $ $ $ $ $ $r;   r>   r4   __GLOBAL_NAMESPACEr	   __REGISTRY_LOCKclstype[CustomTreeNode[T]]r1   r   r2   r   r5   r?   c                   t          j        |           st          d|  d          |t          ur(t	          |t
                    st          d| d          |dk    rt          d          |t          u r| }d}n|| f}t          5  t          j	        | |||           t          | |||          t          |<   ddd           n# 1 swxY w Y   | S )a   Extend the set of types that are considered internal nodes in pytrees.

    See also :func:`register_pytree_node_class` and :func:`unregister_pytree_node`.

    The ``namespace`` argument is used to avoid collisions that occur when different libraries
    register the same Python type with different behaviors. It is recommended to add a unique prefix
    to the namespace to avoid conflicts with other libraries. Namespaces can also be used to specify
    the same class in different namespaces for different use cases.

    .. warning::
        For safety reasons, a ``namespace`` must be specified while registering a custom type. It is
        used to isolate the behavior of flattening and unflattening a pytree node type. This is to
        prevent accidental collisions between different libraries that may register the same type.

    Args:
        cls (type): A Python type to treat as an internal pytree node.
        flatten_func (callable): A function to be used during flattening, taking an instance of ``cls``
            and returning a triple or optionally a pair, with (1) an iterable for the children to be
            flattened recursively, and (2) some hashable auxiliary data to be stored in the treespec
            and to be passed to the ``unflatten_func``, and (3) (optional) an iterable for the tree
            path entries to the corresponding children. If the entries are not provided or given by
            :data:`None`, then `range(len(children))` will be used.
        unflatten_func (callable): A function taking two arguments: the auxiliary data that was
            returned by ``flatten_func`` and stored in the treespec, and the unflattened children.
            The function should return an instance of ``cls``.
        namespace (str): A non-empty string that uniquely identifies the namespace of the type registry.
            This is used to isolate the registry from other modules that might register a different
            custom behavior for the same type.

    Returns:
        The same type as the input ``cls``.

    Raises:
        TypeError: If the input type is not a class.
        TypeError: If the namespace is not a string.
        ValueError: If the namespace is an empty string.
        ValueError: If the type is already registered in the registry.

    Examples:
        >>> # Registry a Python type with lambda functions
        >>> register_pytree_node(
        ...     set,
        ...     lambda s: (sorted(s), None, None),
        ...     lambda _, children: set(children),
        ...     namespace='set',
        ... )
        <class 'set'>

        >>> # Register a Python type into a namespace
        >>> import torch
        >>> register_pytree_node(
        ...     torch.Tensor,
        ...     flatten_func=lambda tensor: (
        ...         (tensor.cpu().detach().numpy(),),
        ...         {'dtype': tensor.dtype, 'device': tensor.device, 'requires_grad': tensor.requires_grad},
        ...     ),
        ...     unflatten_func=lambda metadata, children: torch.tensor(children[0], **metadata),
        ...     namespace='torch2numpy',
        ... )
        <class 'torch.Tensor'>

        >>> # doctest: +SKIP
        >>> tree = {'weight': torch.ones(size=(1, 2)).cuda(), 'bias': torch.zeros(size=(2,))}
        >>> tree
        {'weight': tensor([[1., 1.]], device='cuda:0'), 'bias': tensor([0., 0.])}

        >>> # Flatten without specifying the namespace
        >>> tree_flatten(tree)  # `torch.Tensor`s are leaf nodes
        ([tensor([0., 0.]), tensor([[1., 1.]], device='cuda:0')], PyTreeSpec({'bias': *, 'weight': *}))

        >>> # Flatten with the namespace
        >>> tree_flatten(tree, namespace='torch2numpy')
        (
            [array([0., 0.], dtype=float32), array([[1., 1.]], dtype=float32)],
            PyTreeSpec(
                {
                    'bias': CustomTreeNode(Tensor[{'dtype': torch.float32, 'device': device(type='cpu'), 'requires_grad': False}], [*]),
                    'weight': CustomTreeNode(Tensor[{'dtype': torch.float32, 'device': device(type='cuda', index=0), 'requires_grad': False}], [*])
                },
                namespace='torch2numpy'
            )
        )

        >>> # Register the same type with a different namespace for different behaviors
        >>> def tensor2flatparam(tensor):
        ...     return [torch.nn.Parameter(tensor.reshape(-1))], tensor.shape, None
        ...
        ... def flatparam2tensor(metadata, children):
        ...     return children[0].reshape(metadata)
        ...
        ... register_pytree_node(
        ...     torch.Tensor,
        ...     flatten_func=tensor2flatparam,
        ...     unflatten_func=flatparam2tensor,
        ...     namespace='tensor2flatparam',
        ... )
        <class 'torch.Tensor'>

        >>> # Flatten with the new namespace
        >>> tree_flatten(tree, namespace='tensor2flatparam')
        (
            [
                Parameter containing: tensor([0., 0.], requires_grad=True),
                Parameter containing: tensor([1., 1.], device='cuda:0', requires_grad=True)
            ],
            PyTreeSpec(
                {
                    'bias': CustomTreeNode(Tensor[torch.Size([2])], [*]),
                    'weight': CustomTreeNode(Tensor[torch.Size([1, 2])], [*])
                },
                namespace='tensor2flatparam'
            )
        )
    Expected a class, got .$The namespace must be a string, got r3   (The namespace cannot be an empty string.r5   N)inspectisclass	TypeErrorrE   
isinstancer4   
ValueErrorrF   r   register_noder/   _NODETYPE_REGISTRY)rG   r1   r2   r5   registration_keys        r<   r    r    U   sD   p ?3 97777888***:i3M3M*KyKKKLLLBCDDD &&&		%s+	 
 

lNIFFF/F	0
 0
 0
+,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 Js   3CC
C
rN   
str | None<Callable[[type[CustomTreeNode[T]]], type[CustomTreeNode[T]]]c                   d S Nr:   rG   r5   s     r<   r!   r!      	     Cr;   c                   d S rZ   r:   r[   s     r<   r!   r!      r\   r;   $type[CustomTreeNode[T]] | str | NoneVtype[CustomTreeNode[T]] | Callable[[type[CustomTreeNode[T]]], type[CustomTreeNode[T]]]c               @   | t           u st          | t                    rA|t          d          | dk    rt          d          t	          j        t          |           S |t          d          |t           ur't          |t                    st          d|           |dk    rt          d          | t	          j        t          |          S t          j	        |           st          d|  d	          t          | t          d
          | j        |           | S )a	  Extend the set of types that are considered internal nodes in pytrees.

    See also :func:`register_pytree_node` and :func:`unregister_pytree_node`.

    The ``namespace`` argument is used to avoid collisions that occur when different libraries
    register the same Python type with different behaviors. It is recommended to add a unique prefix
    to the namespace to avoid conflicts with other libraries. Namespaces can also be used to specify
    the same class in different namespaces for different use cases.

    .. warning::
        For safety reasons, a ``namespace`` must be specified while registering a custom type. It is
        used to isolate the behavior of flattening and unflattening a pytree node type. This is to
        prevent accidental collisions between different libraries that may register the same type.

    Args:
        cls (type, optional): A Python type to treat as an internal pytree node.
        namespace (str, optional): A non-empty string that uniquely identifies the namespace of the
            type registry. This is used to isolate the registry from other modules that might
            register a different custom behavior for the same type.

    Returns:
        The same type as the input ``cls`` if the argument presents. Otherwise, return a decorator
        function that registers the class as a pytree node.

    Raises:
        TypeError: If the namespace is not a string.
        ValueError: If the namespace is an empty string.
        ValueError: If the type is already registered in the registry.

    This function is a thin wrapper around :func:`register_pytree_node`, and provides a
    class-oriented interface::

        @register_pytree_node_class(namespace='foo')
        class Special:
            def __init__(self, x, y):
                self.x = x
                self.y = y

            def tree_flatten(self):
                return ((self.x, self.y), None)

            @classmethod
            def tree_unflatten(cls, metadata, children):
                return cls(*children)

        @register_pytree_node_class('mylist')
        class MyList(UserList):
            def tree_flatten(self):
                return self.data, None, None

            @classmethod
            def tree_unflatten(cls, metadata, children):
                return cls(*children)
    Nz?Cannot specify `namespace` when the first argument is a string.r3   rM   rN   z<Must specify `namespace` when the first argument is a class.rL   rJ   rK   tree_flatten)rE   rR   r4   rS   	functoolspartialr!   rQ   rO   rP   r    r   tree_unflattenr[   s     r<   r!   r!      s0   v    JsC$8$8  ^___"99GHHH !;sKKKKWXXX***:i3M3M*JyJJKKKBCDDD
{ !;yQQQQ?3 97777888l>::C<NPYZZZJr;   c                  t          j        |           st          d|  d          |t          ur(t	          |t
                    st          d| d          |dk    rt          d          |t          u r| }d}n|| f}t          5  t          j	        | |           t                              |          cddd           S # 1 swxY w Y   dS )a  Remove a type from the pytree node registry.

    See also :func:`register_pytree_node` and :func:`register_pytree_node_class`.

    This function is the inverse operation of function :func:`register_pytree_node`.

    Args:
        cls (type): A Python type to remove from the pytree node registry.
        namespace (str): The namespace of the pytree node registry to remove the type from.

    Returns:
        The removed registry entry.

    Raises:
        TypeError: If the input type is not a class.
        TypeError: If the namespace is not a string.
        ValueError: If the namespace is an empty string.
        ValueError: If the type is a built-in type that cannot be unregistered.
        ValueError: If the type is not found in the registry.

    Examples:
        >>> # Register a Python type with lambda functions
        >>> register_pytree_node(
        ...     set,
        ...     lambda s: (sorted(s), None, None),
        ...     lambda _, children: set(children),
        ...     namespace='temp',
        ... )
        <class 'set'>

        >>> # Unregister the Python type
        >>> unregister_pytree_node(set, namespace='temp')
    rJ   rK   rL   r3   rM   N)rO   rP   rQ   rE   rR   r4   rS   rF   r   unregister_noderU   pop)rG   r5   rV   s      r<   r"   r"   I  s3   L ?3 97777888***:i3M3M*KyKKKLLLBCDDD &&&		%s+	 8 8
3	***!%%&6778 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8s   /CC
CitemsIterable[tuple[KT, VT]]list[tuple[KT, VT]]c                &    t          | d           S )Nc                    | d         S )Nr   r:   )kvs    r<   <lambda>z_sorted_items.<locals>.<lambda>  s
    BqE r;   keyr   )rh   s    r<   _sorted_itemsrr     s    e)9)9::::r;   dctdict[KT, VT]list[KT]c                     t          |           S rZ   rq   rs   s    r<   _sorted_keysrx     s    c"""r;   noneNonetuple[tuple[()], None]c                    dS )N)r:   Nr:   )ry   s    r<   _none_flattenr}     s    8r;   _childrenIterable[Any]c                ~    t                      }t          t          |          |          |urt          d          d S )NzExpected no children.)objectnextiterrS   )r~   r   sentinels      r<   _none_unflattenr     s:    xxHDNNH%%X5501114r;   tuptuple[T, ...]tuple[tuple[T, ...], None]c                
    | d fS rZ   r:   r   s    r<   _tuple_flattenr         9r;   Iterable[T]c                     t          |          S rZ   )tupler~   r   s     r<   _tuple_unflattenr     s    ??r;   lstlist[T]tuple[list[T], None]c                
    | d fS rZ   r:   r   s    r<   _list_flattenr     r   r;   c                     t          |          S rZ   )listr   s     r<   _list_unflattenr     s    >>r;   /tuple[tuple[VT, ...], list[KT], tuple[KT, ...]]c                    t          t          |                                                     \  }}|t          |          |fS rZ   )r   rr   rh   r   rs   keysvaluess      r<   _dict_flattenr     s6    -		4455LD&4::t##r;   r   r   Iterable[VT]c                <    t          t          | |                    S rZ   )dictr   r   r   s     r<   _dict_unflattenr     s    v&&'''r;   OrderedDict[KT, VT]c                n    t          |                                           \  }}|t          |          |fS rZ   )r   rh   r   r   s      r<   _ordereddict_flattenr     s0     #))++&&LD&4::t##r;   c                <    t          t          | |                    S rZ   )r   r   r   s     r<   _ordereddict_unflattenr     s    xf--...r;   defaultdict[KT, VT]Otuple[tuple[VT, ...], tuple[Callable[[], VT] | None, list[KT]], tuple[KT, ...]]c                Z    t          |           \  }}}|| j        t          |          f|fS rZ   )r   default_factoryr   )rs   r   r   entriess       r<   _defaultdict_flattenr     s3     *#..FD'C'd4g==r;   metadata!tuple[Callable[[], VT], list[KT]]c                H    | \  }}t          |t          ||                    S rZ   )r   r   )r   r   r   r   s       r<   _defaultdict_unflattenr     s(     %OTv(>(>???r;   deqdeque[T]tuple[deque[T], int | None]c                    | | j         fS rZ   maxlen)r   s    r<   _deque_flattenr     s    
?r;   r   
int | Nonec                $    t          ||           S )Nr   )r   )r   r   s     r<   _deque_unflattenr     s    &))))r;   NamedTuple[T])tuple[tuple[T, ...], type[NamedTuple[T]]]c                $    | t          |           fS rZ   r0   r   s    r<   _namedtuple_flattenr         S		>r;   type[NamedTuple[T]]c                     | | S rZ   r:   rG   r   s     r<   _namedtuple_unflattenr     s    3>r;   seqstructseq[T](tuple[tuple[T, ...], type[structseq[T]]]c                $    | t          |           fS rZ   r   )r   s    r<   _structseq_flattenr     r   r;   type[structseq[T]]c                     | |          S rZ   r:   r   s     r<   _structseq_unflattenr     s    3x==r;   z6dict[type | tuple[str, type], PyTreeNodeRegistryEntry]rU   r0   PyTreeNodeRegistryEntry | Nonec               :   t                               |           }||S t                               || f          }||S t          |           rt                               t                    S t	          |           rt                               t
                    S d S rZ   )rU   getr   r   r   r   )rG   r5   handlers      r<   _pytree_node_registry_getr     s    
 /A.D.DS.I.IG $$i%566G# 1!%%i0003 2!%%j1114r;   c                  R    e Zd ZU dZded<   ded<   ded<   ddZddZddZddZdS )_HashablePartialShimz_Object that delegates :meth:`__call__`, :meth:`__hash__`, and :meth:`__eq__` to another object.Callable[..., Any]functuple[Any, ...]argsdict[str, Any]keywordspartial_funcfunctools.partialr?   rz   c                    || _         d S rZ   r   )rB   r   s     r<   __init__z_HashablePartialShim.__init__  s    /;r;   r   kwargsc                     | j         |i |S rZ   r   )rB   r   r   s      r<   __call__z_HashablePartialShim.__call__  s     t $1&111r;   intc                *    t          | j                  S rZ   )hashr   rA   s    r<   __hash__z_HashablePartialShim.__hash__  s    D%&&&r;   otherr   boolc                b    t          |t                    r| j        |j        k    S | j        |k    S rZ   )rR   r   r   rB   r   s     r<   __eq__z_HashablePartialShim.__eq__  s4    e122 	;$(::: E))r;   N)r   r   r?   rz   )r   r   r   r   r?   r   )r?   r   r   r   r?   r   )	r6   r7   r8   __doc__r9   r   r   r   r   r:   r;   r<   r   r     s         ii< < < <2 2 2 2' ' ' '* * * * * *r;   r   c                  d     e Zd ZU dZded<   ded<   ded<   d fdZddZedd            Z xZ	S )r#   a  A version of :func:`functools.partial` that works in pytrees.

    Use it for partial function evaluation in a way that is compatible with transformations,
    e.g., ``Partial(func, *args, **kwargs)``.

    (You need to explicitly opt-in to this behavior because we did not want to give
    :func:`functools.partial` different semantics than normal function closures.)

    For example, here is a basic usage of :class:`Partial` in a manner similar to
    :func:`functools.partial`:

    >>> import operator
    >>> import torch
    >>> add_one = Partial(operator.add, torch.ones(()))
    >>> add_one(torch.tensor([[1, 2], [3, 4]]))
    tensor([[2., 3.],
            [4., 5.]])

    Pytree compatibility means that the resulting partial function can be passed as an argument
    within tree-map functions, which is not possible with a standard :func:`functools.partial`
    function:

    >>> def call_func_on_cuda(f, *args, **kwargs):
    ...     f, args, kwargs = tree_map(lambda t: t.cuda(), (f, args, kwargs))
    ...     return f(*args, **kwargs)
    ...
    >>> # doctest: +SKIP
    >>> tree_map(lambda t: t.cuda(), add_one)
    Partial(<built-in function add>, tensor(1., device='cuda:0'))
    >>> call_func_on_cuda(add_one, torch.tensor([[1, 2], [3, 4]]))
    tensor([[2., 3.],
            [4., 5.]], device='cuda:0')

    Passing zero arguments to :class:`Partial` effectively wraps the original function, making it a
    valid argument in tree-map functions:

    >>> # doctest: +SKIP
    >>> call_func_on_cuda(Partial(torch.add), torch.tensor(1), torch.tensor(2))
    tensor(3, device='cuda:0')

    Had we passed :func:`operator.add` to ``call_func_on_cuda`` directly, it would have resulted in
    a :class:`TypeError` or :class:`AttributeError`.
    r   r   r   r   r   r   r   r?   r   c                R   t          |t          j                  ro|}t          |          }t	          |d          r
J d             t                      j        | |g|R i |}|j        |_        |j        |_        |j	        |_	        |S  t                      j        | |g|R i |S )z'Create a new :class:`Partial` instance.r   z3shimmed function should not have a `func` attribute)
rR   rb   rc   r   hasattrsuper__new__r   r   r   )rG   r   r   r   original_funcout	__class__s         r<   r   zPartial.__new__M  s     dI-.. 	 M'66DtV,,cc.cccc!%''/#t?d???h??C%*DI%*DI)2DMJuwwsD<4<<<8<<<r;   Atuple[tuple[tuple[Any, ...], dict[str, Any]], Callable[..., Any]]c                ,    | j         | j        f| j        fS )zEFlatten the :class:`Partial` instance to children and auxiliary data.)r   r   r   rA   s    r<   ra   zPartial.tree_flatten_  s    	4=)4944r;   r   r   &tuple[tuple[Any, ...], dict[str, Any]]c                $    |\  }} | |g|R i |S )zKUnflatten the children and auxiliary data into a :class:`Partial` instance.r:   )rG   r   r   r   r   s        r<   rd   zPartial.tree_unflattenc  s-     "hs8/d///h///r;   )r   r   r   r   r   r   r?   r   )r?   r   )r   r   r   r   r?   r   )
r6   r7   r8   r   r9   r   ra   classmethodrd   __classcell__)r   s   @r<   r#   r#     s         * *X = = = = = =$5 5 5 5 0 0 0 [0 0 0 0 0r;   r#   c                  2    e Zd ZU ded<   ddZdd	ZddZdS )KeyPathEntryr   rp   r   r   r?   KeyPathc                    t          |t                    rt          | |f          S t          |t                    rt          | g|j        R           S t          S rZ   rR   r  r  r   NotImplementedr   s     r<   __add__zKeyPathEntry.__add__q  sY    e\** 	*D%=)))eW%% 	0D.5:..///r;   r   c                L    t          || j                  o| j        |j        k    S rZ   )rR   r   rp   r   s     r<   r   zKeyPathEntry.__eq__x  s"    %00JTX5JJr;   r4   c                    t           )"Pretty name of the key path entry.)NotImplementedErrorrA   s    r<   pprintzKeyPathEntry.pprint{  s    !!r;   Nr   r   r?   r  r   rD   )r6   r7   r8   r9   r  r   r  r:   r;   r<   r  r  n  s_         HHH   K K K K" " " " " "r;   r  c                  6    e Zd ZU dZded<   ddZdd	ZddZdS )r  r:   ztuple[KeyPathEntry, ...]r   r   r   r?   c                    t          |t                    rt          g | j        |R           S t          |t                    rt          | j        |j        z             S t          S rZ   r  r   s     r<   r  zKeyPath.__add__  sa    e\** 	0.TY...///eW%% 	349uz1222r;   r   c                L    t          |t                    o| j        |j        k    S rZ   )rR   r  r   r   s     r<   r   zKeyPath.__eq__  s     %))Edi5:.EEr;   r4   c                \    | j         sdS d                    d | j         D                       S )zPretty name of the key path.z
 tree rootr3   c              3  >   K   | ]}|                                 V  d S rZ   )r  ).0ks     r<   	<genexpr>z!KeyPath.pprint.<locals>.<genexpr>  s*      55aqxxzz555555r;   )r   joinrA   s    r<   r  zKeyPath.pprint  s4    y 	 <ww5549555555r;   Nr  r   rD   )r6   r7   r8   r   r9   r  r   r  r:   r;   r<   r  r    sg         %'D''''   F F F F6 6 6 6 6 6r;   r  c                      e Zd ZdZddZdS )r&   z8The key path entry class for sequences and dictionaries.r?   r4   c                    d| j         dS )r	  []ro   rA   s    r<   r  zGetitemKeyPathEntry.pprint  s     48    r;   NrD   r6   r7   r8   r   r  r:   r;   r<   r&   r&     s.        BB! ! ! ! ! !r;   r&   c                      e Zd ZdZddZdS )r%   z)The key path entry class for namedtuples.r?   r4   c                    d| j          S )r	  rK   ro   rA   s    r<   r  zAttributeKeyPathEntry.pprint  s    48~~r;   NrD   r  r:   r;   r<   r%   r%     s.        33     r;   r%   c                      e Zd ZdZddZdS )FlattenedKeyPathEntryz"The fallback key path entry class.r?   r4   c                    d| j          dS )r	  z[<flat index z>]ro   rA   s    r<   r  zFlattenedKeyPathEntry.pprint  s    +tx++++r;   NrD   r  r:   r;   r<   r  r    s.        ,,, , , , , ,r;   r  z*dict[type[CustomTreeNode], KeyPathHandler]_KEYPATH_REGISTRYr   KeyPathHandlerc                    t          j        |           st          d|  d          | t          v rt	          d|  d          |t          | <   |S )z:Register a key path handler for a custom pytree node type.rJ   rK   zKey path handler for z has already been registered.)rO   rP   rQ   r   rS   )rG   r   s     r<   r$   r$     sf    
 ?3 97777888
SSSSTTT$cNr;   c           
     z    t          t          t          t          t	          |                                         S rZ   r   mapr&   rangelenr   s    r<   rn   rn     s%    T#.A5S??*S*S%T%T r;   c           
     z    t          t          t          t          t	          |                                         S rZ   r$  r   s    r<   rn   rn     s%    D-@%C//)R)R$S$S r;   c                `    t          t          t          t          |                               S rZ   r   r%  r&   rx   rw   s    r<   rn   rn     s!    D-@,sBSBS)T)T$U$U r;   c                F    t          t          t          |                     S rZ   )r   r%  r&   )odcts    r<   rn   rn     s    D5H$1O1O,P,P r;   c                `    t          t          t          t          |                               S rZ   r*  )ddcts    r<   rn   rn     s"    D5H,W[J\J\1]1],^,^ r;   c           
     z    t          t          t          t          t	          |                                         S rZ   r$  )dqs    r<   rn   rn     s%    D-@%B..)Q)Q$R$R r;   r:   )
rG   rH   r1   r   r2   r   r5   r4   r?   rH   rZ   )rG   rW   r5   rW   r?   rX   )rG   rH   r5   r4   r?   rH   )rG   r^   r5   rW   r?   r_   )rG   rH   r5   r4   r?   r/   )rh   ri   r?   rj   )rs   rt   r?   ru   )ry   rz   r?   r{   )r~   rz   r   r   r?   rz   )r   r   r?   r   )r~   rz   r   r   r?   r   )r   r   r?   r   )r~   rz   r   r   r?   r   )rs   rt   r?   r   )r   ru   r   r   r?   rt   )rs   r   r?   r   )r   ru   r   r   r?   r   )rs   r   r?   r   )r   r   r   r   r?   r   )r   r   r?   r   )r   r   r   r   r?   r   )r   r   r?   r   )rG   r   r   r   r?   r   )r   r   r?   r   )rG   r   r   r   r?   r   )rG   r0   r5   r4   r?   r   )rG   rH   r   r!  r?   r!  )^r   
__future__r   dataclassesrb   rO   syscollectionsr   r   r   r   operatorr   	threadingr	   typingr
   r   r   r   r   r   r   typing_extensionsr   optreer   optree.typingr   r   r   r   r   r   r   r   r   r   optree.utilsr   r   r   builtins__all__version_infoSLOTS	dataclassr/   r>   rE   r9   rF   r    r!   r"   rr   rx   r}   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r0   r   r   r   rU   r   r   r   rc   r#   r  r  r&   r%   r  r!  r   r$   r:   r;   r<   <module>rA     s!   * ) ) " " " " " "          



 C C C C C C C C C C C C ! ! ! ! ! !       Y Y Y Y Y Y Y Y Y Y Y Y Y Y Y Y Y Y " " " " " "                              > = = = = = = = = =  OOO   +w66$B KDtTKKUKK       LK 
$ $ $ $ $ $ $ $
 */++  + + + +    N N N Nb 
 !     
 
   
 15N !N N N N N Nb68 68 68 68r; ; ; ;# # # #                  $ $ $ $
( ( ( ($ $ $ $/ / / /> > > >@ @ @ @   * * * *             	DJJ''T

M?SS	""5.:JKK
!
!$
G
G
!
!$
G
G''
4GI^__((6JLbcc((6JLbcc	""5.:JKK&&y2DFZ[[
N  
 
 
 
" (     $ 5  * * * * * * * *, &8999O0 O0 O0 O0 O0i!4 O0 O0 :9O0d" " " " ": " " "$6 6 6 6 6j 6 6 6(! ! ! ! !, ! ! !    L   , , , , ,L , , , 6(H\$::;@B  B B B B     %TT U U U  $SS T T T  $UU V V V  +PP Q Q Q  +^^ _ _ _  %RR S S S)-    r;   