# Copyright 2022-2024 MetaOPT Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Typing utilities for OpTree."""

from __future__ import annotations

import types
from collections.abc import Hashable
from typing import (
    Any,
    Callable,
    DefaultDict,
    Deque,
    Dict,
    ForwardRef,
    Generic,
    Iterable,
    List,
    NoReturn,
    Optional,
    Sequence,
    Tuple,
    TypeVar,
    Union,
)
from typing_extensions import NamedTuple  # Generic NamedTuple: Python 3.11+
from typing_extensions import OrderedDict  # Generic OrderedDict: Python 3.7.2+
from typing_extensions import Self  # Python 3.11+
from typing_extensions import TypeAlias  # Python 3.10+
from typing_extensions import Final, Protocol, runtime_checkable  # Python 3.8+

from optree import _C
from optree._C import PyTreeKind, PyTreeSpec


# Names exported by ``from <this module> import *``.
__all__ = [
    # Tree specification types imported from the C extension module
    # (`PyTreeDef` is an alias of `PyTreeSpec`).
    'PyTreeSpec',
    'PyTreeDef',
    'PyTreeKind',
    # Recursive pytree type aliases and the custom node protocol.
    'PyTree',
    'PyTreeTypeVar',
    'CustomTreeNode',
    'Children',
    'MetaData',
    'FlattenFunc',
    'UnflattenFunc',
    # Helpers for namedtuple types and instances.
    'is_namedtuple',
    'is_namedtuple_instance',
    'is_namedtuple_class',
    'namedtuple_fields',
    # Helpers for PyStructSequence types and instances.
    'is_structseq',
    'is_structseq_instance',
    'is_structseq_class',
    'structseq_fields',
    # Type variables defined in this module.
    'T',
    'S',
    'U',
    'KT',
    'VT',
    # Generic container types re-exported from `typing` / `typing_extensions`.
    'Iterable',
    'Sequence',
    'List',
    'Tuple',
    'NamedTuple',
    'Dict',
    'OrderedDict',
    'DefaultDict',
    'Deque',
]


# `PyTreeDef` is a second name for the very same class object as `PyTreeSpec`.
PyTreeDef = PyTreeSpec  # alias

# Unconstrained type variables, exported through `__all__`.
T = TypeVar('T')
S = TypeVar('S')
U = TypeVar('U')
KT = TypeVar('KT')
VT = TypeVar('VT')


# Generic alias: the children of a node, as an iterable of `T`.
Children = Iterable[T]
# Private type variable bounded by `Hashable`, used only to build `MetaData` below.
_MetaData = TypeVar('_MetaData', bound=Hashable)
# Generic alias: a hashable metadata value, or `None`.
MetaData = Optional[_MetaData]


@runtime_checkable
class CustomTreeNode(Protocol[T]):
    """The structural protocol for custom pytree nodes.

    A conforming class provides an instance method ``tree_flatten`` and a class method
    ``tree_unflatten``. The protocol is decorated with ``runtime_checkable``, so it can be
    used as the second argument of ``isinstance()``.
    """

    def tree_flatten(
        self,
    ) -> (
        # Use `range(num_children)` as path entries
        tuple[Children[T], MetaData]
        |
        # With optionally implemented path entries
        tuple[Children[T], MetaData, Iterable[Any] | None]
    ):
        """Flatten the custom pytree node into children and auxiliary data.

        Returns:
            Either a 2-tuple ``(children, metadata)`` or a 3-tuple
            ``(children, metadata, entries)`` where ``entries`` is an iterable of path
            entries or :data:`None`.
        """

    @classmethod
    def tree_unflatten(cls, metadata: MetaData, children: Children[T]) -> CustomTreeNode[T]:
        """Unflatten the children and auxiliary data into the custom pytree node.

        Args:
            metadata: The auxiliary data, annotated with the same type as the second item
                of the ``tree_flatten`` result.
            children: An iterable of children.

        Returns:
            A custom pytree node.
        """


# The runtime class of a subscripted `typing.Union[...]` object. It is used in
# `PyTree.__class_getitem__` to recognize parameters that are already union aliases.
_GenericAlias = type(Union[int, str])


def _tp_cache(func: Callable) -> Callable:
    """Wrap ``func`` with an LRU cache that tolerates unhashable arguments.

    Calls are first routed through ``functools.lru_cache``. If that raises a
    :class:`TypeError` (either because an argument is unhashable or because ``func`` itself
    raised one), ``func`` is called again directly, without caching.

    Args:
        func: The callable to wrap.

    Returns:
        A wrapper with the same metadata as ``func`` (via ``functools.wraps``).

    Raises:
        Whatever the direct call to ``func`` raises. The exception is raised without the
        failed cached attempt attached as its ``__context__``.
    """
    import functools  # pylint: disable=import-outside-toplevel

    cached = functools.lru_cache()(func)

    @functools.wraps(func)
    def inner(*args: Any, **kwds: Any) -> Any:
        try:
            return cached(*args, **kwds)
        except TypeError:
            pass  # All real errors (not unhashable args) are raised below.
        # Retry *after* leaving the handler: a genuine TypeError from `func` is then raised
        # on its own instead of being chained to the identical error from the cached call
        # ("During handling of the above exception, another exception occurred").
        return func(*args, **kwds)

    return inner


class PyTree(Generic[T]):  # pylint: disable=too-few-public-methods
    """Generic PyTree type.

    Subscripting this class does not parameterize it. ``PyTree[T]`` and ``PyTree[T, 'Name']``
    return a ``typing.Union`` alias whose container members refer back to the alias itself
    through a forward reference.

    >>> import torch
    >>> from optree.typing import PyTree
    >>> TensorTree = PyTree[torch.Tensor]
    >>> TensorTree  # doctest: +IGNORE_WHITESPACE
    typing.Union[torch.Tensor,
                 typing.Tuple[ForwardRef('PyTree[torch.Tensor]'), ...],
                 typing.List[ForwardRef('PyTree[torch.Tensor]')],
                 typing.Dict[typing.Any, ForwardRef('PyTree[torch.Tensor]')],
                 typing.Deque[ForwardRef('PyTree[torch.Tensor]')],
                 optree.typing.CustomTreeNode[ForwardRef('PyTree[torch.Tensor]')]]
    """

    @_tp_cache
    def __class_getitem__(
        cls,
        item: T | tuple[T] | tuple[T, str | None],
    ) -> TypeAlias:
        """Build the union alias for ``PyTree[param]`` or ``PyTree[param, name]``.

        Args:
            item: Either a bare parameter, or a 2-tuple of the parameter and the name used in
                the recursive forward reference (a string or :data:`None`).

        Returns:
            The union alias. It carries the normalized ``(param, name)`` pair in its
            ``__pytree_args__`` attribute. A parameter that is already such an alias is
            returned unchanged.

        Raises:
            TypeError: If ``item`` is a tuple that does not have exactly two items, or if the
                name is neither :data:`None` nor a string.
        """
        # Normalize a bare parameter to the `(param, name)` form.
        item = item if isinstance(item, tuple) else (item, None)
        well_formed = len(item) == 2 and (item[1] is None or isinstance(item[1], str))
        if not well_formed:
            raise TypeError(
                f'{cls.__name__}[...] only supports a tuple of 2 items, '
                f'a parameter and a string of type name, got {item!r}.',
            )
        param, name = item

        is_pytree_alias = (
            isinstance(param, _GenericAlias)
            and param.__origin__ is Union  # type: ignore[attr-defined]
            and hasattr(param, '__pytree_args__')
        )
        if is_pytree_alias:
            return param  # PyTree[PyTree[T]] -> PyTree[T]

        # Work out the text of the forward reference used for the recursive members.
        if name is not None:
            ref_text = name
        elif isinstance(param, TypeVar):
            ref_text = f'{cls.__name__}[{param.__name__}]'
        elif isinstance(param, type):
            if param.__module__ == 'builtins':
                typename = param.__qualname__
            else:
                try:
                    typename = f'{param.__module__}.{param.__qualname__}'
                except AttributeError:
                    typename = f'{param.__module__}.{param.__name__}'
            ref_text = f'{cls.__name__}[{typename}]'
        else:
            ref_text = f'{cls.__name__}[{param!r}]'
        recurse_ref = ForwardRef(ref_text)

        members = (
            param,
            Tuple[recurse_ref, ...],  # type: ignore[valid-type] # Tuple, NamedTuple, PyStructSequence
            List[recurse_ref],  # type: ignore[valid-type]
            Dict[Any, recurse_ref],  # type: ignore[valid-type] # Dict, OrderedDict, DefaultDict
            Deque[recurse_ref],  # type: ignore[valid-type]
            CustomTreeNode[recurse_ref],  # type: ignore[valid-type]
        )
        # Subscripting with a tuple is the same as writing the members out separated by commas.
        pytree_alias = Union[members]  # type: ignore[valid-type]
        pytree_alias.__pytree_args__ = item  # type: ignore[attr-defined]
        return pytree_alias

    def __new__(cls) -> NoReturn:  # pylint: disable=arguments-differ
        """Always raise :class:`TypeError`: this class cannot be instantiated."""
        raise TypeError('Cannot instantiate special typing classes.')

    def __init_subclass__(cls, *args: Any, **kwargs: Any) -> NoReturn:
        """Always raise :class:`TypeError`: this class cannot be subclassed."""
        raise TypeError('Cannot subclass special typing classes.')

    def __copy__(self) -> PyTree:
        """Return the object itself instead of a copy."""
        return self

    def __deepcopy__(self, memo: dict[int, Any]) -> PyTree:
        """Return the object itself instead of a deep copy."""
        return self


class PyTreeTypeVar:
    """Type variable for PyTree.

    Calling ``PyTreeTypeVar(name, param)`` returns ``PyTree[param, name]``, so the recursive
    members of the resulting alias refer to ``name``.

    >>> import torch
    >>> from optree.typing import PyTreeTypeVar
    >>> TensorTree = PyTreeTypeVar('TensorTree', torch.Tensor)
    >>> TensorTree  # doctest: +IGNORE_WHITESPACE
    typing.Union[torch.Tensor,
                 typing.Tuple[ForwardRef('TensorTree'), ...],
                 typing.List[ForwardRef('TensorTree')],
                 typing.Dict[typing.Any, ForwardRef('TensorTree')],
                 typing.Deque[ForwardRef('TensorTree')],
                 optree.typing.CustomTreeNode[ForwardRef('TensorTree')]]
    """

    @_tp_cache
    def __new__(cls, name: str, param: type) -> TypeAlias:
        """Return the PyTree alias for ``param`` whose recursive references use ``name``.

        Raises:
            TypeError: If ``name`` is not a string.
        """
        if isinstance(name, str):
            return PyTree[param, name]  # type: ignore[misc,valid-type]
        raise TypeError(f'{cls.__name__} only supports a string of type name, got {name!r}.')

    def __init_subclass__(cls, *args: Any, **kwargs: Any) -> NoReturn:
        """Always raise :class:`TypeError`: this class cannot be subclassed."""
        raise TypeError('Cannot subclass special typing classes.')

    def __copy__(self) -> TypeAlias:
        """Return the object itself instead of a copy."""
        return self

    def __deepcopy__(self, memo: dict[int, Any]) -> TypeAlias:
        """Return the object itself instead of a deep copy."""
        return self


# Signature of a flatten function: it takes a custom node and returns either
# `(children, metadata)` or `(children, metadata, entries)`, where `entries` is an optional
# iterable. This is the same shape as the return annotation of `CustomTreeNode.tree_flatten`.
FlattenFunc = Callable[
    [CustomTreeNode[T]],
    Union[
        Tuple[Children[T], MetaData],
        Tuple[Children[T], MetaData, Optional[Iterable[Any]]],
    ],
]
# Signature of an unflatten function: it takes `(metadata, children)` and returns a custom node.
UnflattenFunc = Callable[[MetaData, Children[T]], CustomTreeNode[T]]


def is_namedtuple(obj: object | type) -> bool:
    """Return whether ``obj`` is a namedtuple class or an instance of one.

    A class is checked directly; for any other object its type is checked.
    """
    if isinstance(obj, type):
        return is_namedtuple_class(obj)
    return is_namedtuple_class(type(obj))


def is_namedtuple_instance(obj: object) -> bool:
    """Return whether the type of ``obj`` is a namedtuple class."""
    obj_type = type(obj)
    return is_namedtuple_class(obj_type)


def is_namedtuple_class(cls: type) -> bool:
    """Return whether ``cls`` looks like a namedtuple class.

    The check is structural: ``cls`` must be a class deriving from ``tuple`` whose ``_fields``
    attribute is a tuple of exact ``str`` objects, and whose ``_make`` and ``_asdict``
    attributes are callable.
    """
    if not (isinstance(cls, type) and issubclass(cls, tuple)):
        return False

    field_names = getattr(cls, '_fields', None)
    if not isinstance(field_names, tuple):
        return False
    for field in field_names:
        # Exact type match on purpose: instances of `str` subclasses are rejected.
        if type(field) is not str:  # noqa: E721 # pylint: disable=unidiomatic-typecheck
            return False

    has_make = callable(getattr(cls, '_make', None))
    return has_make and callable(getattr(cls, '_asdict', None))


def namedtuple_fields(obj: tuple | type[tuple]) -> tuple[str, ...]:
    """Return the field names of a namedtuple class or instance.

    Args:
        obj: A namedtuple class, or an instance of one.

    Returns:
        The ``_fields`` attribute of the namedtuple class.

    Raises:
        TypeError: If ``obj`` is neither a namedtuple class nor an instance of one.
    """
    given_class = isinstance(obj, type)
    cls = obj if given_class else type(obj)
    if not is_namedtuple_class(cls):
        if given_class:
            raise TypeError(f'Expected a collections.namedtuple type, got {cls!r}.')
        raise TypeError(f'Expected an instance of collections.namedtuple type, got {obj!r}.')
    return cls._fields  # type: ignore[attr-defined]


# Covariant type variable; it is the type parameter of the generic `structseq` stub below.
_T_co = TypeVar('_T_co', covariant=True)


class _StructSequenceMeta(type):
    """Metaclass that routes ``isinstance()`` / ``issubclass()`` to the structseq helpers."""

    def __instancecheck__(cls, instance: Any) -> bool:
        """Return the result of ``is_structseq_instance(instance)``.

        >>> import sys
        >>> isinstance(sys.float_info, structseq)
        True
        >>> isinstance((1, 2), structseq)
        False
        """
        matches = is_structseq_instance(instance)
        return matches

    def __subclasscheck__(cls, subclass: type) -> bool:
        """Return the result of ``is_structseq_class(subclass)``.

        >>> import time
        >>> issubclass(time.struct_time, structseq)
        True
        >>> class MyTuple(tuple):
        ...     n_fields = 2
        ...     n_sequence_fields = 2
        ...     n_unnamed_fields = 0
        >>> issubclass(MyTuple, structseq)
        False
        """
        matches = is_structseq_class(subclass)
        return matches


# Reference: https://github.com/python/typeshed/blob/main/stdlib/_typeshed/__init__.pyi
# This is an internal CPython type that is like, but subtly different from, a NamedTuple.
# `structseq` classes are unsubclassable, so are all decorated with `@final`.
# pylint: disable-next=invalid-name,missing-class-docstring
class structseq(tuple, Generic[_T_co], metaclass=_StructSequenceMeta):  # type: ignore[misc] # noqa: N801
    """A generic type stub for CPython's ``PyStructSequence`` type.

    The metaclass answers ``isinstance(obj, structseq)`` and ``issubclass(cls, structseq)``
    with :func:`is_structseq_instance` and :func:`is_structseq_class` respectively. The stub
    itself can neither be subclassed nor instantiated: ``__init_subclass__`` raises
    :class:`TypeError` and ``__new__`` raises :class:`NotImplementedError`.
    """

    # Annotation-only declarations (no value is assigned on the stub). These are the three
    # integer attributes that `is_structseq_class` looks for on a candidate class.
    n_fields: Final[int]  # type: ignore[misc] # pylint: disable=invalid-name
    n_sequence_fields: Final[int]  # type: ignore[misc] # pylint: disable=invalid-name
    n_unnamed_fields: Final[int]  # type: ignore[misc] # pylint: disable=invalid-name

    def __init_subclass__(cls) -> NoReturn:
        """Prohibit subclassing."""
        raise TypeError("type 'structseq' is not an acceptable base type")

    # pylint: disable-next=unused-argument,redefined-builtin
    def __new__(cls, sequence: Iterable[_T_co], dict: dict[str, Any] = ...) -> Self:
        """Always raise :class:`NotImplementedError`; the signature exists for typing only."""
        raise NotImplementedError


# Remove the metaclass name from the module namespace. The `structseq` class keeps its own
# reference to the metaclass, so the instance and subclass checks continue to work.
del _StructSequenceMeta


def is_structseq(obj: object | type) -> bool:
    """Return whether ``obj`` is a PyStructSequence class or an instance of one.

    A class is checked directly; for any other object its type is checked.
    """
    if isinstance(obj, type):
        return is_structseq_class(obj)
    return is_structseq_class(type(obj))


def is_structseq_instance(obj: object) -> bool:
    """Return whether the type of ``obj`` is a PyStructSequence class."""
    obj_type = type(obj)
    return is_structseq_class(obj_type)


# Type flag bit that is set if the type allows subclassing (see CPython's Include/object.h).
# The value is read from the C extension module instead of being hard-coded here.
Py_TPFLAGS_BASETYPE = _C.Py_TPFLAGS_BASETYPE  # (1UL << 10)


def is_structseq_class(cls: type) -> bool:
    """Return whether ``cls`` is a PyStructSequence class.

    The check is structural: ``cls`` must be a class whose only direct base is ``tuple``,
    must expose integer ``n_fields``, ``n_sequence_fields`` and ``n_unnamed_fields``
    attributes, and must not have the ``Py_TPFLAGS_BASETYPE`` bit set in its ``__flags__``.
    """
    # Check direct inheritance from `tuple` rather than `issubclass(cls, tuple)`
    if not isinstance(cls, type) or cls.__bases__ != (tuple,):
        return False

    # Check PyStructSequence members
    for member in ('n_fields', 'n_sequence_fields', 'n_unnamed_fields'):
        if not isinstance(getattr(cls, member, None), int):
            return False

    # Check the type does not allow subclassing
    return not (cls.__flags__ & Py_TPFLAGS_BASETYPE)


def structseq_fields(obj: tuple | type[tuple]) -> tuple[str, ...]:
    """Return the field names of a PyStructSequence class or instance.

    The names are those of the member descriptors found in the class namespace, in namespace
    order, limited to the first ``n_sequence_fields`` of them.

    Args:
        obj: A PyStructSequence class, or an instance of one.

    Returns:
        A tuple of field names.

    Raises:
        TypeError: If ``obj`` is neither a PyStructSequence class nor an instance of one.
    """
    given_class = isinstance(obj, type)
    cls = obj if given_class else type(obj)
    if not is_structseq_class(cls):
        if given_class:
            raise TypeError(f'Expected a PyStructSequence type, got {cls!r}.')
        raise TypeError(f'Expected an instance of PyStructSequence type, got {obj!r}.')

    n_sequence_fields: int = cls.n_sequence_fields  # type: ignore[attr-defined]
    descriptor_names: list[str] = [
        name
        for name, member in vars(cls).items()
        if isinstance(member, types.MemberDescriptorType)
    ]
    # Clamp at zero so that a non-positive count selects no names at all.
    return tuple(descriptor_names[: max(n_sequence_fields, 0)])


# Ensure that the behavior is consistent with the C++ implementation: the imports below
# rebind the names of the pure-Python implementations defined above.
# pylint: disable-next=wrong-import-position,ungrouped-imports
from optree._C import (
    is_namedtuple,
    is_namedtuple_class,
    is_namedtuple_instance,
    is_structseq,
    is_structseq_class,
    is_structseq_instance,
    namedtuple_fields,
    structseq_fields,
)
