import math
from chromadb.test.property.strategies import NormalizedRecordSet, RecordSet
from typing import Callable, Optional, Tuple, Union, List, TypeVar, cast, Dict
from typing_extensions import Literal
import numpy as np
import numpy.typing as npt
from chromadb.api import types
from chromadb.api.models.Collection import Collection
from hypothesis import note
from hypothesis.errors import InvalidArgument

T = TypeVar("T")


def wrap(value: Union[T, List[T]]) -> List[T]:
    """Return value as a list.

    A list is handed back unchanged (the very same object, not a copy); any
    other non-None value is placed in a new single-element list.

    Raises InvalidArgument when value is None.
    """
    if value is None:
        raise InvalidArgument("value cannot be None")
    return value if isinstance(value, list) else [value]


def wrap_all(record_set: RecordSet) -> NormalizedRecordSet:
    """Normalize a record set so that every populated field is a list.

    ids are always wrapped; documents and metadatas are wrapped unless they
    are None, in which case they stay None. embeddings are normalized as
    follows:
      - None stays None
      - an empty list, or a list whose items are all lists, is kept as is
      - a non-empty flat list of ints/floats is treated as one embedding and
        wrapped in an outer list
      - anything else raises InvalidArgument
    """
    raw_embeddings = record_set["embeddings"]

    normalized_embeddings: Optional[types.Embeddings]
    if raw_embeddings is None:
        normalized_embeddings = None
    elif not isinstance(raw_embeddings, list):
        raise InvalidArgument(
            "embeddings must be a list of lists, a list of numbers, or None"
        )
    elif len(raw_embeddings) == 0 or all(
        isinstance(item, list) for item in raw_embeddings
    ):
        # Already a (possibly empty) list of embeddings
        normalized_embeddings = cast(types.Embeddings, raw_embeddings)
    elif all(isinstance(item, (int, float)) for item in raw_embeddings):
        # A single embedding given as a flat list of numbers
        normalized_embeddings = cast(types.Embeddings, [raw_embeddings])
    else:
        raise InvalidArgument("an embedding must be a list of floats or ints")

    ids = wrap(record_set["ids"])
    documents = record_set["documents"]
    metadatas = record_set["metadatas"]

    return {
        "ids": ids,
        "documents": None if documents is None else wrap(documents),
        "metadatas": None if metadatas is None else wrap(metadatas),
        "embeddings": normalized_embeddings,
    }


def count(collection: Collection, record_set: RecordSet) -> None:
    """Assert that the collection's count equals the number of ids in the record set"""
    actual_count = collection.count()
    expected_count = len(wrap_all(record_set)["ids"])
    assert actual_count == expected_count


def _field_matches(
    collection: Collection,
    normalized_record_set: NormalizedRecordSet,
    field_name: Union[Literal["documents"], Literal["metadatas"]],
) -> None:
    """
    Assert that the stored values of one field equal the expected values
    field_name: one of [documents, metadatas]
    """
    expected_ids = normalized_record_set["ids"]
    result = collection.get(ids=expected_ids, include=[field_name])

    # When metadatas/documents are included the result holds one entry per
    # record ([None, None..] if nothing was stored), never a bare None.
    actual_field = result[field_name]
    assert actual_field is not None

    # The results are not guaranteed to come back in input order (the
    # test_out_of_order_ids test in test_add.py fails otherwise), so rearrange
    # the returned values to follow the order of the input ids.
    input_position = {record_id: pos for pos, record_id in enumerate(expected_ids)}
    returned_ids = result["ids"]
    result_order = sorted(
        range(len(actual_field)),
        key=lambda result_pos: input_position[returned_ids[result_pos]],
    )
    field_values = [actual_field[result_pos] for result_pos in result_order]

    expected_field = normalized_record_set[field_name]
    if expected_field is None:
        # A RecordSet is user input and may omit the field altogether, whereas
        # the API returns a list with one None per entry
        expected_field = [None] * len(expected_ids)  # type: ignore
    assert field_values == expected_field


def ids_match(collection: Collection, record_set: RecordSet) -> None:
    """Assert that the ids returned by the collection equal the expected ids"""
    expected_ids = wrap_all(record_set)["ids"]
    returned_ids = collection.get(ids=expected_ids, include=[])["ids"]
    # The test_out_of_order_ids test in test_add.py fails without this step:
    # put the returned ids into the order of the input ids before comparing.
    input_position = {record_id: pos for pos, record_id in enumerate(expected_ids)}
    reordered_ids = sorted(returned_ids, key=input_position.__getitem__)
    assert reordered_ids == expected_ids


def metadatas_match(collection: Collection, record_set: RecordSet) -> None:
    """Assert that the stored metadatas equal the metadatas of the record set"""
    _field_matches(collection, wrap_all(record_set), "metadatas")


def documents_match(collection: Collection, record_set: RecordSet) -> None:
    """Assert that the stored documents equal the documents of the record set"""
    _field_matches(collection, wrap_all(record_set), "documents")


def no_duplicates(collection: Collection) -> None:
    """Assert that no id appears more than once in the collection"""
    all_ids = collection.get()["ids"]
    unique_ids = set(all_ids)
    assert len(unique_ids) == len(all_ids)


# The reference distances below follow the definitions used by hnswlib.
# NORM_EPS guards the cosine distance against division by zero; its value is
# the same as the one in hnswlib's python bindings:
# https://github.com/nmslib/hnswlib/blob/359b2ba87358224963986f709e593d799064ace6/python_bindings/bindings.cpp#L238
NORM_EPS = 1e-30


def _l2_distance(x: npt.ArrayLike, y: npt.ArrayLike) -> float:
    """Squared Euclidean distance between x and y"""
    return np.linalg.norm(x - y) ** 2  # type: ignore


def _cosine_distance(x: npt.ArrayLike, y: npt.ArrayLike) -> float:
    """One minus the cosine similarity of x and y (norms padded with NORM_EPS)"""
    padded_norm_x = np.linalg.norm(x) + NORM_EPS
    padded_norm_y = np.linalg.norm(y) + NORM_EPS
    return 1 - np.dot(x, y) / (padded_norm_x * padded_norm_y)  # type: ignore


def _ip_distance(x: npt.ArrayLike, y: npt.ArrayLike) -> float:
    """One minus the inner product of x and y"""
    return 1 - np.dot(x, y)  # type: ignore


# Maps the name of a distance space to its reference implementation
distance_functions: Dict[str, Callable[[npt.ArrayLike, npt.ArrayLike], float]] = {
    "l2": _l2_distance,
    "cosine": _cosine_distance,
    "ip": _ip_distance,
}


def _exact_distances(
    query: types.Embeddings,
    targets: types.Embeddings,
    distance_fn: Callable[[npt.ArrayLike, npt.ArrayLike], float] = distance_functions[
        "l2"
    ],
) -> Tuple[List[List[int]], List[List[float]]]:
    """Compute the distance from every query to every target by brute force.

    Returns a pair (indices, distances). distances[i][j] is the distance
    between query i and target j; indices[i] lists the target positions of
    row i ordered by ascending distance.
    """
    query_matrix = np.array(query)
    target_matrix = np.array(targets)

    def distances_to_targets(single_query: npt.ArrayLike) -> npt.ArrayLike:
        # distance_fn is invoked as distance_fn(target_row, single_query)
        return np.apply_along_axis(distance_fn, 1, target_matrix, single_query)

    # One row of distances per query
    distance_matrix = np.apply_along_axis(distances_to_targets, 1, query_matrix)

    # Rank the targets of every row by distance
    ranked_indices = np.argsort(distance_matrix)
    return ranked_indices.tolist(), distance_matrix.tolist()


def is_metadata_valid(normalized_record_set: NormalizedRecordSet) -> bool:
    """Return False if any metadata entry is empty, True otherwise.

    A record set without metadatas (None) counts as valid.
    """
    metadatas = normalized_record_set["metadatas"]
    if metadatas is None:
        return True
    entry_sizes = [len(metadata) for metadata in metadatas]
    return 0 not in entry_sizes


def ann_accuracy(
    collection: Collection,
    record_set: RecordSet,
    n_results: int = 1,
    min_recall: float = 0.99,
    embedding_function: Optional[types.EmbeddingFunction] = None,
) -> None:
    """Validate that the API performs nearest_neighbor searches correctly

    Every record of record_set is issued as a query against the collection and
    the returned neighbors are compared with an exact, brute-force ranking
    computed over the same embeddings.

    collection: the collection under test. Its metadata must not be None; an
        "hnsw:space" entry of "cosine" or "ip" selects the matching reference
        distance, otherwise l2 is used.
    record_set: the records the collection is expected to contain.
    n_results: number of neighbors requested per query.
    min_recall: lowest acceptable recall.
    embedding_function: used to embed the documents when record_set carries no
        embeddings; it (and the documents) must be present in that case.

    Raises AssertionError when a returned distance, embedding, document or
    metadata differs from the expected value, when recall falls below
    min_recall, or when the returned distances are not sorted ascending.
    """
    normalized_record_set = wrap_all(record_set)

    if len(normalized_record_set["ids"]) == 0:
        return  # nothing to test here

    embeddings: Optional[types.Embeddings] = normalized_record_set["embeddings"]
    have_embeddings = embeddings is not None and len(embeddings) > 0
    if not have_embeddings:
        assert embedding_function is not None
        assert normalized_record_set["documents"] is not None
        assert isinstance(normalized_record_set["documents"], list)
        # Compute the embeddings for the documents
        embeddings = embedding_function(normalized_record_set["documents"])

    # l2 is the default distance function
    distance_function = distance_functions["l2"]
    accuracy_threshold = 1e-6
    assert collection.metadata is not None
    assert embeddings is not None
    if "hnsw:space" in collection.metadata:
        space = collection.metadata["hnsw:space"]
        # TODO: ip and cosine are numerically unstable in HNSW.
        # The higher the dimensionality, the more noise is introduced, since each float element
        # of the vector has noise added, which is then subsequently included in all normalization calculations.
        # This means that higher dimensions will have more noise, and thus more error.
        assert all(isinstance(e, list) for e in embeddings)
        dim = len(embeddings[0])
        # Loosen the tolerance by one order of magnitude per decimal digit of dim
        accuracy_threshold = accuracy_threshold * math.pow(10, int(math.log10(dim)))

        if space == "cosine":
            distance_function = distance_functions["cosine"]

        if space == "ip":
            distance_function = distance_functions["ip"]

    # Perform exact distance computation
    indices, distances = _exact_distances(
        embeddings, embeddings, distance_fn=distance_function
    )

    # Query by embeddings when the record set has them, by documents otherwise.
    # Exactly one of the two is passed: an empty (but non-None) embeddings list
    # must not be sent along with the query texts.
    query_results = collection.query(
        query_embeddings=normalized_record_set["embeddings"]
        if have_embeddings
        else None,
        query_texts=normalized_record_set["documents"] if not have_embeddings else None,
        n_results=n_results,
        include=["embeddings", "documents", "metadatas", "distances"],
    )

    assert query_results["distances"] is not None
    assert query_results["documents"] is not None
    assert query_results["metadatas"] is not None
    assert query_results["embeddings"] is not None

    # Dict of ids to indices
    id_to_index = {id: i for i, id in enumerate(normalized_record_set["ids"])}
    # NOTE(review): missing is accumulated over up to n_results neighbors per
    # query while recall below divides by the number of records - confirm this
    # is the intended metric for n_results > 1.
    missing = 0
    for i, (indices_i, distances_i) in enumerate(zip(indices, distances)):
        expected_ids = np.array(normalized_record_set["ids"])[indices_i[:n_results]]
        missing += len(set(expected_ids) - set(query_results["ids"][i]))

        # For each id in the query results, find the index in the embeddings set
        # and assert that the embeddings are the same
        for j, id in enumerate(query_results["ids"][i]):
            # This may be because the true nth nearest neighbor didn't get returned by the ANN query
            unexpected_id = id not in expected_ids
            index = id_to_index[id]

            correct_distance = np.allclose(
                distances_i[index],
                query_results["distances"][i][j],
                atol=accuracy_threshold,
            )
            if unexpected_id:
                # If the ID is unexpected, but the distance is correct, then we
                # have a duplicate in the data. In this case, we should not reduce recall.
                if correct_distance:
                    missing -= 1
                else:
                    # A genuinely wrong neighbor: it already counts as missing,
                    # so skip the payload checks for it
                    continue
            else:
                assert correct_distance

            assert np.allclose(embeddings[index], query_results["embeddings"][i][j])
            if normalized_record_set["documents"] is not None:
                assert (
                    normalized_record_set["documents"][index]
                    == query_results["documents"][i][j]
                )
            if normalized_record_set["metadatas"] is not None:
                assert (
                    normalized_record_set["metadatas"][index]
                    == query_results["metadatas"][i][j]
                )

    size = len(normalized_record_set["ids"])
    recall = (size - missing) / size

    try:
        note(
            f"recall: {recall}, missing {missing} out of {size}, accuracy threshold {accuracy_threshold}"
        )
    except InvalidArgument:
        pass  # it's ok if we're running outside hypothesis

    assert recall >= min_recall

    # Ensure that the query results are sorted by distance
    for distance_result in query_results["distances"]:
        assert np.allclose(np.sort(distance_result), distance_result)
