# Licensed under a 3-clause BSD style license - see PYFITS.rst

import platform

import numpy as np
import pytest
from numpy.testing import assert_array_equal

from astropy.io import fits

from .conftest import FitsTestCase


class TestUintFunctions(FitsTestCase):
    """Round-trip tests for pseudo-unsigned integer data in FITS files.

    Unsigned data is written as signed integers offset by
    ``BZERO = 2**(bits - 1)`` and is expected to come back with an unsigned
    dtype when files are opened with ``uint=True``.
    """

    @classmethod
    def setup_class(cls):
        # Lookup tables keyed by numpy type-code suffix ("u2", "u4", "u8"):
        # the unsigned dtype, the signed dtype of the same width, and the
        # binary-table column format code used for that width.
        cls.utypes = ("u2", "u4", "u8")
        cls.utype_map = {"u2": np.uint16, "u4": np.uint32, "u8": np.uint64}
        cls.itype_map = {"u2": np.int16, "u4": np.int32, "u8": np.int64}
        cls.format_map = {"u2": "I", "u4": "J", "u8": "K"}

    # The 64-bit compressed case ("u8", True) is disabled because the cfitsio
    # library doesn't handle it.
    @pytest.mark.parametrize(
        ("utype", "compressed"),
        [("u2", False), ("u4", False), ("u8", False), ("u2", True), ("u4", True)],
    )
    def test_uint(self, utype, compressed):
        """Write scaled signed image data and read it back as unsigned.

        The values -3..3 are scaled with ``BZERO = 2**(bits - 1)``. When read
        back with ``uint=True``, the negative values must appear as
        ``2**bits - 3 .. 2**bits - 1`` with an unsigned dtype. The HDU list is
        then re-written and re-read, and the data must be unchanged.
        """
        bits = 8 * int(utype[1])
        if bits == 64 and platform.architecture()[0] != "64bit":
            pytest.skip("64-bit unsigned integers require a 64-bit platform")

        data = np.array([-3, -2, -1, 0, 1, 2, 3], dtype=np.int64)
        if compressed:
            hdu = fits.CompImageHDU(data)
            # A compressed image is stored in an extension after an empty
            # primary HDU.
            hdu_number = 1
        else:
            hdu = fits.PrimaryHDU(data)
            hdu_number = 0

        hdu.scale(f"int{bits:d}", "", bzero=2 ** (bits - 1))

        hdu.writeto(self.temp("tempfile.fits"), overwrite=True)

        with fits.open(self.temp("tempfile.fits"), uint=True) as hdul:
            assert hdul[hdu_number].data.dtype == self.utype_map[utype]
            assert (
                hdul[hdu_number].data
                == np.array(
                    [(2**bits) - 3, (2**bits) - 2, (2**bits) - 1, 0, 1, 2, 3],
                    dtype=self.utype_map[utype],
                )
            ).all()
            hdul.writeto(self.temp("tempfile1.fits"))
            # Re-read the rewritten file with the same uint setting used
            # above, so the two reads are directly comparable.
            with fits.open(self.temp("tempfile1.fits"), uint=True) as hdul1:
                d1 = hdul[hdu_number].data
                d2 = hdul1[hdu_number].data
                assert (d1 == d2).all()
                if not compressed:
                    # TODO: Enable these lines for compressed data if
                    # CompImageHDUs ever grow .section support
                    sec = hdul[hdu_number].section[:1]
                    assert sec.dtype.name == f"uint{bits}"
                    assert (sec == d1[:1]).all()

    @pytest.mark.parametrize("utype", ("u2", "u4", "u8"))
    def test_uint_columns(self, utype):
        """Test basic functionality of tables with columns containing
        pseudo-unsigned integers.  See
        https://github.com/astropy/astropy/pull/906
        """
        bits = 8 * int(utype[1])
        if bits == 64 and platform.architecture()[0] != "64bit":
            pytest.skip("64-bit unsigned integers require a 64-bit platform")

        bzero = self.utype_map[utype](2 ** (bits - 1))
        one = self.utype_map[utype](1)
        # u[k] = 2**k - 1 for k = 0..bits. The last entry comes from unsigned
        # wraparound of 2**bits, which gives the type's maximum value.
        u0 = np.arange(bits + 1, dtype=self.utype_map[utype])
        u = 2**u0 - one
        if bits == 64:
            # Set the two largest entries explicitly: u[63] = 2**63 - 1 and
            # u[64] = 2**64 - 1.
            # NOTE(review): presumably this guards against imprecise 64-bit
            # power arithmetic; confirm before removing.
            u[63] = bzero - one
            u[64] = u[63] + u[63] + one
        # Expected raw (stored) values: u offset by -BZERO and reinterpreted
        # as the signed integer type of the same width.
        uu = (u - bzero).view(self.itype_map[utype])

        # Construct a table from explicit column
        col = fits.Column(
            name=utype, array=u, format=self.format_map[utype], bzero=bzero
        )

        table = fits.BinTableHDU.from_columns([col])
        assert (table.data[utype] == u).all()
        # This used to be table.data.base, but now after adding a table to
        # a BinTableHDU it gets stored as a view of the original table,
        # even if the original was already a FITS_rec.  So now we need
        # table.data.base.base
        assert (table.data.base.base[utype] == uu).all()
        hdu0 = fits.PrimaryHDU()
        hdulist = fits.HDUList([hdu0, table])

        hdulist.writeto(self.temp("tempfile.fits"), overwrite=True)

        # Test write of unsigned int
        del hdulist
        with fits.open(self.temp("tempfile.fits"), uint=True) as hdulist2:
            hdudata = hdulist2[1].data
            assert (hdudata[utype] == u).all()
            assert hdudata[utype].dtype == self.utype_map[utype]
            assert (hdudata.base[utype] == uu).all()

        # Construct recarray then write out that.
        v = u.view(dtype=[(utype, self.utype_map[utype])])

        fits.writeto(self.temp("tempfile2.fits"), v, overwrite=True)

        with fits.open(self.temp("tempfile2.fits"), uint=True) as hdulist3:
            hdudata3 = hdulist3[1].data
            assert (hdudata3.base[utype] == table.data.base.base[utype]).all()
            assert (hdudata3[utype] == table.data[utype]).all()
            assert (hdudata3[utype] == u).all()

    def test_uint_slice(self):
        """
        Fix for https://github.com/astropy/astropy/issues/5490
        if data is sliced first, make sure the data is still converted as uint
        """
        # create_data:
        dataref = np.arange(2**16, dtype=np.uint16)
        tbhdu = fits.BinTableHDU.from_columns(
            [
                fits.Column(
                    name="a", format="I", array=np.arange(2**16, dtype=np.int16)
                ),
                fits.Column(
                    name="b", format="I", bscale=1, bzero=2**15, array=dataref
                ),
            ]
        )
        tbhdu.writeto(self.temp("test_scaled_slicing.fits"))

        # First pass: column "b" is accessed on the full table before slicing.
        with fits.open(self.temp("test_scaled_slicing.fits")) as hdulist:
            data = hdulist[1].data
        assert_array_equal(data["b"], dataref)
        sel = data["a"] >= 0
        assert_array_equal(data[sel]["b"], dataref[sel])
        assert data[sel]["b"].dtype == dataref[sel].dtype

        # Second pass: re-open so that column "b" is first accessed only
        # after the table has been sliced, which is the case from the issue.
        with fits.open(self.temp("test_scaled_slicing.fits")) as hdulist:
            data = hdulist[1].data
        assert_array_equal(data[sel]["b"], dataref[sel])
        assert data[sel]["b"].dtype == dataref[sel].dtype
