import itertools

import numpy as np
import pytest
from numpy.testing import assert_array_almost_equal
from sklearn.neighbors._ball_tree import BallTree
from sklearn.utils import check_random_state
from sklearn.utils.validation import check_array
from sklearn.utils._testing import _convert_container

# One seeded stream supplies every random metric parameter below. The order in
# which values are drawn from it must therefore not change.
rng = np.random.RandomState(10)

DIMENSION = 3

# Symmetric positive semi-definite matrix (A @ A.T) for the mahalanobis metric.
V_mahalanobis = rng.rand(DIMENSION, DIMENSION)
V_mahalanobis = np.dot(V_mahalanobis, V_mahalanobis.T)

# Continuous-valued metrics, mapped to the keyword arguments each one needs.
# Keyword arguments are evaluated left to right, so "seuclidean" draws from
# ``rng`` before "wminkowski" does.
# NOTE(review): "wminkowski" (and "kulsinski" below) have been deprecated in
# recent SciPy / scikit-learn releases. Confirm that the pinned versions still
# support them.
METRICS = dict(
    euclidean={},
    manhattan={},
    minkowski=dict(p=3),
    chebyshev={},
    seuclidean=dict(V=rng.random_sample(DIMENSION)),
    wminkowski=dict(p=3, w=rng.random_sample(DIMENSION)),
    mahalanobis=dict(V=V_mahalanobis),
)

# Metrics exercised on small non-negative integer-valued data.
DISCRETE_METRICS = ["hamming", "canberra", "braycurtis"]

# Metrics exercised on 0/1-valued data.
BOOLEAN_METRICS = [
    "matching",
    "jaccard",
    "dice",
    "kulsinski",
    "rogerstanimoto",
    "russellrao",
    "sokalmichener",
    "sokalsneath",
]


def brute_force_neighbors(X, Y, k, metric, **kwargs):
    """Find the ``k`` nearest rows of ``X`` for every row of ``Y`` exhaustively.

    Parameters
    ----------
    X : array-like
        Reference points, validated with ``check_array``.
    Y : array-like
        Query points, validated with ``check_array``.
    k : int
        Number of neighbors to keep per query row.
    metric : str
        Name passed to ``DistanceMetric.get_metric``.
    **kwargs
        Extra metric parameters forwarded to ``DistanceMetric.get_metric``.

    Returns
    -------
    dist : ndarray
        For each row of ``Y``, the ``k`` smallest distances in ascending order.
    ind : ndarray
        The matching row indices into ``X``. ``np.argsort`` is called with
        its default (non-stable) sort, so the order among tied distances is
        unspecified.
    """
    from sklearn.metrics import DistanceMetric

    X = check_array(X)
    Y = check_array(Y)

    dist_metric = DistanceMetric.get_metric(metric, **kwargs)
    # Rows correspond to queries (Y) and columns to reference points (X).
    pairwise_dist = dist_metric.pairwise(Y, X)

    neighbor_ind = np.argsort(pairwise_dist, axis=1)[:, :k]
    # Gather each query row's distances at its selected neighbor columns.
    neighbor_dist = np.take_along_axis(pairwise_dist, neighbor_ind, axis=1)
    return neighbor_dist, neighbor_ind


@pytest.mark.parametrize("metric", itertools.chain(BOOLEAN_METRICS, DISCRETE_METRICS))
@pytest.mark.parametrize("array_type", ["list", "array"])
def test_ball_tree_query_metrics(metric, array_type):
    """BallTree.query distances agree with brute force on boolean/discrete data."""
    rng = check_random_state(0)

    # Boolean metrics get 0/1 features. Discrete metrics get integers in [0, 4].
    if metric in BOOLEAN_METRICS:
        scale = 1
    elif metric in DISCRETE_METRICS:
        scale = 4

    def _sample(n_samples):
        # Draw a (n_samples, 10) block from the shared generator and round it.
        return (scale * rng.random_sample((n_samples, 10))).round(0)

    # X must be drawn before Y so the random stream is consumed in the same
    # order as before.
    X_raw = _sample(40)
    Y_raw = _sample(10)
    X = _convert_container(X_raw, array_type)
    Y = _convert_container(Y_raw, array_type)

    n_neighbors = 5

    tree = BallTree(X, leaf_size=1, metric=metric)
    tree_dist, _ = tree.query(Y, n_neighbors)
    brute_dist, _ = brute_force_neighbors(X, Y, n_neighbors, metric)
    # Only distances are compared. With integer-valued data, equal distances
    # are likely, which would make the neighbor index order ambiguous.
    assert_array_almost_equal(tree_dist, brute_dist)


def test_query_haversine():
    """A haversine BallTree returns the same neighbors as brute force."""
    random_state = check_random_state(0)
    points = 2 * np.pi * random_state.random_sample((40, 2))
    n_neighbors = 5

    tree = BallTree(points, leaf_size=1, metric="haversine")
    tree_dist, tree_ind = tree.query(points, k=n_neighbors)
    brute_dist, brute_ind = brute_force_neighbors(
        points, points, k=n_neighbors, metric="haversine"
    )

    # Check distances first, then indices.
    for got, expected in ((tree_dist, brute_dist), (tree_ind, brute_ind)):
        assert_array_almost_equal(got, expected)


def test_array_object_type():
    """BallTree rejects a ragged array stored with object dtype."""
    ragged_rows = [(1, 2, 3), (2, 5), (5, 5, 1, 2)]
    X = np.array(ragged_rows, dtype=object)

    expected_msg = "setting an array element with a sequence"
    with pytest.raises(ValueError, match=expected_msg):
        BallTree(X)


def test_bad_pyfunc_metric():
    """Invalid user-supplied metric callables make BallTree raise TypeError."""

    def wrong_returned_value(x, y):
        return "1"

    def one_arg_func(x):
        return 1.0  # pragma: no cover

    X = np.ones((5, 2))

    # Each bad callable is paired with the message pattern its TypeError must match.
    cases = [
        (
            wrong_returned_value,
            "Custom distance function must accept two vectors and return a float.",
        ),
        (one_arg_func, "takes 1 positional argument but 2 were given"),
    ]
    for bad_metric, msg in cases:
        with pytest.raises(TypeError, match=msg):
            BallTree(X, metric=bad_metric)
