import numbers
from collections import defaultdict

import numpy as np

from astropy.utils import isiterable
from astropy.utils.decorators import lazyproperty

from .base import BaseWCSWrapper

# Public names of this module (``combine_slices`` is intentionally not exported).
__all__ = ["sanitize_slices", "SlicedLowLevelWCS"]


def sanitize_slices(slices, ndim):
    """
    Given a slice as input sanitise it to an easier to parse format.

    This function returns a list ``ndim`` long containing slice objects (or ints).

    Parameters
    ----------
    slices : int, slice, Ellipsis, or tuple/list of these
        The array slice to sanitise. A single item that is not a tuple or a
        list is treated as a one-element tuple.
    ndim : int
        The number of (array) dimensions of the WCS being sliced.

    Returns
    -------
    list
        ``ndim`` items, each either a `slice` (with a step of `None` or 1) or
        an integer. A single ``Ellipsis`` is expanded into as many
        ``slice(None)`` as needed (possibly none), and dimensions not covered
        by ``slices`` are filled with ``slice(None)``.

    Raises
    ------
    ValueError
        If ``slices`` indexes more than ``ndim`` dimensions.
    IndexError
        If an item is iterable, is neither an integer nor a slice, is a slice
        with a step other than 1, or if more than one Ellipsis is given.
    """
    if not isinstance(slices, (tuple, list)):  # We just have a single int
        slices = (slices,)

    # An Ellipsis stands for zero or more full slices, so it does not consume
    # a dimension itself: ``wcs[0, ..., 0]`` is valid for a 2-D WCS, just as it
    # is for a 2-D numpy array. Identity is used so iterables are never compared.
    if sum(s is not Ellipsis for s in slices) > ndim:
        raise ValueError(
            f"The dimensionality of the specified slice {slices} can not be greater "
            f"than the dimensionality ({ndim}) of the wcs."
        )

    if any(isiterable(s) for s in slices):
        raise IndexError(
            "This slice is invalid, only integer or range slices are supported."
        )

    slices = list(slices)

    if Ellipsis in slices:
        if slices.count(Ellipsis) > 1:
            raise IndexError("an index can only have a single ellipsis ('...')")

        # Replace the Ellipsis with however many slice(None)s are needed so
        # that every dimension is covered (zero is allowed).
        e_ind = slices.index(Ellipsis)
        n_fill = ndim - (len(slices) - 1)
        slices[e_ind : e_ind + 1] = [slice(None)] * n_fill

    for slc in slices:
        if isinstance(slc, slice):
            # ``is not None`` (rather than truthiness) so that a step of 0 is
            # rejected too; only the default step or an explicit 1 is allowed.
            if slc.step is not None and slc.step != 1:
                raise IndexError("Slicing WCS with a step is not supported.")
        elif not isinstance(slc, numbers.Integral):
            raise IndexError("Only integer or range slices are accepted.")

    # Any trailing dimensions that were not indexed are kept in full.
    slices.extend([slice(None)] * (ndim - len(slices)))

    return slices


def combine_slices(slice1, slice2):
    """
    Given two slices that can be applied to a 1-d array, find the resulting
    slice that corresponds to the combination of both slices. We assume that
    slice2 can be an integer, but slice1 cannot.

    ``slice2`` is interpreted relative to the result of applying ``slice1``:
    its start/stop (or its integer value) are offset by ``slice1.start``, and
    the resulting stop is clipped to ``slice1.stop``.

    Parameters
    ----------
    slice1 : slice
        The slice applied first.
    slice2 : slice or int
        The slice (or index) applied to the result of ``slice1``.

    Returns
    -------
    slice or int
        An integer if ``slice2`` is an integer, otherwise a slice without a step.

    Raises
    ------
    ValueError
        If either argument is a slice with a step other than `None` or 1.

    Notes
    -----
    Negative values in ``slice2`` are not resolved against the length of
    ``slice1``; they are offset in the same way as non-negative values.
    """
    for slc in (slice1, slice2):
        # A step of 1 is equivalent to no step. ``sanitize_slices`` accepts it,
        # so it must be accepted here as well when a sliced WCS is re-sliced.
        if isinstance(slc, slice) and slc.step is not None and slc.step != 1:
            raise ValueError("Only slices with steps of 1 are supported")

    start1 = slice1.start

    if isinstance(slice2, numbers.Integral):
        return slice2 if start1 is None else slice2 + start1

    if slice2.start is None:
        start = start1
    elif start1 is None:
        start = slice2.start
    else:
        start = start1 + slice2.start

    if slice2.stop is None:
        stop = slice1.stop
    else:
        stop = slice2.stop if start1 is None else slice2.stop + start1
        # slice2 can never extend past the end of slice1.
        if slice1.stop is not None:
            stop = min(slice1.stop, stop)

    return slice(start, stop)


class SlicedLowLevelWCS(BaseWCSWrapper):
    """
    A Low Level WCS wrapper which applies an array slice to a WCS.

    This class does not modify the underlying WCS object and can therefore drop
    coupled dimensions as it stores which pixel and world dimensions have been
    sliced out (or modified) in the underlying WCS and returns the modified
    results on all the Low Level WCS methods.

    Parameters
    ----------
    wcs : `~astropy.wcs.wcsapi.BaseLowLevelWCS`
        The WCS to slice.
    slices : `slice` or `tuple` or `int`
        A valid array slice to apply to the WCS.

    """

    def __init__(self, wcs, slices):
        slices = sanitize_slices(slices, wcs.pixel_n_dim)

        if isinstance(wcs, SlicedLowLevelWCS):
            # Here we combine the current slices with the previous slices
            # to avoid ending up with many nested WCSes
            self._wcs = wcs._wcs
            slices_original = wcs._slices_array.copy()
            for ipixel in range(wcs.pixel_n_dim):
                # Slice lists are stored in array order (reversed pixel order),
                # so convert both the underlying pixel axis kept by ``wcs`` and
                # this pixel axis of ``wcs`` into array positions.
                ipixel_orig = wcs._wcs.pixel_n_dim - 1 - wcs._pixel_keep[ipixel]
                ipixel_new = wcs.pixel_n_dim - 1 - ipixel
                slices_original[ipixel_orig] = combine_slices(
                    slices_original[ipixel_orig], slices[ipixel_new]
                )
            self._slices_array = slices_original
        else:
            self._wcs = wcs
            self._slices_array = slices

        # The same slices, ordered by pixel axis instead of array axis.
        self._slices_pixel = self._slices_array[::-1]

        # figure out which pixel dimensions have been kept (those not indexed by
        # an integer), then use axis correlation matrix to figure out which
        # world dims are kept
        self._pixel_keep = np.nonzero(
            [not isinstance(slc, numbers.Integral) for slc in self._slices_pixel]
        )[0]

        # axis_correlation_matrix[world, pixel]
        self._world_keep = np.nonzero(
            self._wcs.axis_correlation_matrix[:, self._pixel_keep].any(axis=1)
        )[0]

        if len(self._pixel_keep) == 0 or len(self._world_keep) == 0:
            raise ValueError(
                "Cannot slice WCS: the resulting WCS should have "
                "at least one pixel and one world dimension."
            )

    @lazyproperty
    def dropped_world_dimensions(self):
        """
        Information describing the dropped world dimensions.

        Returns an empty dict when no world dimension was dropped. Otherwise the
        list-valued entries ``value``, ``world_axis_names``,
        ``world_axis_physical_types``, ``world_axis_units`` and
        ``world_axis_object_components`` have one item per dropped world
        dimension, in the order of the underlying WCS. ``value`` is that world
        coordinate evaluated at pixel 0 of every kept pixel dimension.
        ``world_axis_object_classes`` holds the class entries referenced by the
        dropped components, and ``serialized_classes`` is copied from
        ``self.serialized_classes``.
        """
        world_coords = self._pixel_to_world_values_all(*[0] * len(self._pixel_keep))
        dropped = [
            i for i in range(self._wcs.world_n_dim) if i not in self._world_keep
        ]
        if not dropped:
            return {}

        # Loop-invariant lookups on the wrapped WCS.
        wao_classes = self._wcs.world_axis_object_classes
        wao_components = self._wcs.world_axis_object_components

        dropped_info = defaultdict(list)
        dropped_info["world_axis_object_classes"] = dict()

        for i in dropped:
            dropped_info["value"].append(world_coords[i])
            dropped_info["world_axis_names"].append(self._wcs.world_axis_names[i])
            dropped_info["world_axis_physical_types"].append(
                self._wcs.world_axis_physical_types[i]
            )
            dropped_info["world_axis_units"].append(self._wcs.world_axis_units[i])
            dropped_info["world_axis_object_components"].append(wao_components[i])

            # Only carry over the class entry this component refers to.
            class_key = wao_components[i][0]
            if class_key in wao_classes:
                dropped_info["world_axis_object_classes"][class_key] = wao_classes[
                    class_key
                ]

        dropped_info["serialized_classes"] = self.serialized_classes
        return dict(dropped_info)

    @property
    def pixel_n_dim(self):
        return len(self._pixel_keep)

    @property
    def world_n_dim(self):
        return len(self._world_keep)

    @property
    def world_axis_physical_types(self):
        return [self._wcs.world_axis_physical_types[i] for i in self._world_keep]

    @property
    def world_axis_units(self):
        return [self._wcs.world_axis_units[i] for i in self._world_keep]

    @property
    def pixel_axis_names(self):
        return [self._wcs.pixel_axis_names[i] for i in self._pixel_keep]

    @property
    def world_axis_names(self):
        return [self._wcs.world_axis_names[i] for i in self._world_keep]

    def _pixel_to_world_values_all(self, *pixel_arrays):
        # Evaluate the underlying WCS for pixel coordinates of this (sliced)
        # WCS: integer-indexed pixel axes are filled with their fixed index and
        # kept axes are shifted by their slice start. Returns the world values
        # of *all* underlying world dimensions, including the dropped ones.
        pixel_arrays = tuple(map(np.asanyarray, pixel_arrays))
        pixel_arrays_new = []
        ipix_curr = -1
        for ipix in range(self._wcs.pixel_n_dim):
            if isinstance(self._slices_pixel[ipix], numbers.Integral):
                pixel_arrays_new.append(self._slices_pixel[ipix])
            else:
                ipix_curr += 1
                if self._slices_pixel[ipix].start is not None:
                    pixel_arrays_new.append(
                        pixel_arrays[ipix_curr] + self._slices_pixel[ipix].start
                    )
                else:
                    pixel_arrays_new.append(pixel_arrays[ipix_curr])

        pixel_arrays_new = np.broadcast_arrays(*pixel_arrays_new)
        return self._wcs.pixel_to_world_values(*pixel_arrays_new)

    def pixel_to_world_values(self, *pixel_arrays):
        world_arrays = self._pixel_to_world_values_all(*pixel_arrays)

        # A 0-dimensional array is a single (scalar) world value rather than a
        # sequence of per-dimension values, so return it unchanged.
        if isinstance(world_arrays, np.ndarray) and not world_arrays.shape:
            return world_arrays

        if self._wcs.world_n_dim > 1:
            # Select the dimensions of the original WCS we are keeping.
            world_arrays = [world_arrays[iw] for iw in self._world_keep]
            # If there is only one world dimension (after slicing) we shouldn't return a tuple.
            if self.world_n_dim == 1:
                world_arrays = world_arrays[0]

        return world_arrays

    def world_to_pixel_values(self, *world_arrays):
        # World values of the dropped dimensions, evaluated at pixel 0 of every
        # kept pixel dimension; they stand in for the missing world inputs.
        sliced_out_world_coords = self._pixel_to_world_values_all(
            *[0] * len(self._pixel_keep)
        )

        world_arrays = tuple(map(np.asanyarray, world_arrays))
        world_arrays_new = []
        iworld_curr = -1
        for iworld in range(self._wcs.world_n_dim):
            if iworld in self._world_keep:
                iworld_curr += 1
                world_arrays_new.append(world_arrays[iworld_curr])
            else:
                world_arrays_new.append(sliced_out_world_coords[iworld])

        world_arrays_new = np.broadcast_arrays(*world_arrays_new)
        pixel_arrays = self._wcs.world_to_pixel_values(*world_arrays_new)
        if self._wcs.pixel_n_dim == 1 and not isinstance(pixel_arrays, (tuple, list)):
            # A single pixel dimension may come back as a bare scalar/array
            # (the convention this wrapper itself follows below). Wrap it so it
            # can be indexed per axis: ``list()`` on it would iterate over its
            # elements, or fail outright for a 0-d array.
            pixel_arrays = (pixel_arrays,)
        pixel_arrays = list(pixel_arrays)

        for ipixel, slc in enumerate(self._slices_pixel):
            if isinstance(slc, slice) and slc.start is not None:
                # Shift into this WCS's pixel frame. Done out of place so arrays
                # returned by the wrapped WCS are never modified (or written to
                # when they are read-only views).
                pixel_arrays[ipixel] = pixel_arrays[ipixel] - slc.start

        pixel = tuple(pixel_arrays[ip] for ip in self._pixel_keep)
        # With a single remaining pixel dimension, return it without a tuple.
        if self.pixel_n_dim == 1:
            pixel = pixel[0]
        return pixel

    @property
    def world_axis_object_components(self):
        return [self._wcs.world_axis_object_components[idx] for idx in self._world_keep]

    @property
    def world_axis_object_classes(self):
        # Only keep the class entries referenced by the remaining components.
        keys_keep = {item[0] for item in self.world_axis_object_components}
        return {
            key: value
            for key, value in self._wcs.world_axis_object_classes.items()
            if key in keys_keep
        }

    @property
    def array_shape(self):
        # Let numpy work out the sliced shape by indexing a broadcast (no-copy)
        # view of the original shape. Returns None when the underlying WCS has
        # no (or an empty) array shape.
        if self._wcs.array_shape:
            return np.broadcast_to(0, self._wcs.array_shape)[
                tuple(self._slices_array)
            ].shape
        return None

    @property
    def pixel_shape(self):
        array_shape = self.array_shape
        if array_shape:
            return tuple(array_shape[::-1])
        return None

    @property
    def pixel_bounds(self):
        pixel_bounds = self._wcs.pixel_bounds
        if pixel_bounds is None:
            return None

        bounds = []
        for idx in self._pixel_keep:
            start = self._slices_pixel[idx].start
            if start is None:
                bounds.append(pixel_bounds[idx])
            else:
                # Shift the bounds into this WCS's pixel frame.
                imin, imax = pixel_bounds[idx]
                bounds.append((imin - start, imax - start))

        return tuple(bounds)

    @property
    def axis_correlation_matrix(self):
        return self._wcs.axis_correlation_matrix[self._world_keep][:, self._pixel_keep]
