"""Base Config object for use with Molecule Graph Construction."""
# Graphein
# Author: Arian Jamasb <arian@jamasb.io>
# License: MIT
# Project Website: https://github.com/a-r-j/graphein
# Code Repository: https://github.com/a-r-j/graphein
from __future__ import annotations
from functools import partial
from pathlib import Path
from typing import Any, Callable, List, Optional, Union
from deepdiff import DeepDiff
from pydantic import BaseModel
from typing_extensions import Literal
from graphein.molecule.edges.atomic import add_atom_bonds
from graphein.molecule.edges.distance import (
    add_distance_threshold,
    add_fully_connected_edges,
    add_k_nn_edges,
)
from graphein.molecule.features.nodes.atom_type import atom_type_one_hot
from graphein.utils.config import PartialMatchOperator, PathMatchOperator
# Element symbols accepted as node atom types. The order of the symbols is
# the order exposed by ``GraphAtoms.__args__``, so keep it stable.
GraphAtoms = Literal[
    "C", "H", "O", "N", "F", "P", "S", "Cl", "Br", "I", "B"
]
"""Element symbols allowed as atom types for nodes in a molecule graph."""
class MoleculeGraphConfig(BaseModel):
    """
    Config Object for Molecule Structure Graph Construction.

    :param verbose: Specifies verbosity of graph creation process.
        Defaults to ``False``.
    :type verbose: bool
    :param add_hs: Specifies whether hydrogens should be added to the graph.
        Defaults to ``False``.
    :type add_hs: bool
    :param edge_construction_functions: List of functions that take an
        ``nx.Graph`` and return an ``nx.Graph`` with desired edges added.
        Entries may be given as callables or as strings. Prepared edge
        constructions can be found in :mod:`graphein.molecule.edges`.
        Defaults to ``[add_fully_connected_edges, add_k_nn_edges,
        add_distance_threshold, add_atom_bonds]``.
    :type edge_construction_functions: List[Union[Callable, str]]
    :param node_metadata_functions: List of functions (or strings) used to
        add node-level features and metadata to the graph. Defaults to
        ``[atom_type_one_hot]``.
    :type node_metadata_functions: List[Union[Callable, str]], optional
    :param edge_metadata_functions: List of functions (or strings) used to
        add edge-level features and metadata to the graph. Defaults to
        ``None``.
    :type edge_metadata_functions: List[Union[Callable, str]], optional
    :param graph_metadata_functions: List of functions that take an
        ``nx.Graph`` and return an ``nx.Graph`` with added graph-level
        features and metadata. Defaults to ``None``.
    :type graph_metadata_functions: List[Callable], optional
    """

    verbose: bool = False
    add_hs: bool = False

    # Graph construction functions.
    # NOTE: the list defaults below rely on pydantic copying field defaults
    # for each instance, so they are not shared between config objects.
    edge_construction_functions: List[Union[Callable, str]] = [
        add_fully_connected_edges,
        add_k_nn_edges,
        add_distance_threshold,
        add_atom_bonds,
    ]
    node_metadata_functions: Optional[List[Union[Callable, str]]] = [
        atom_type_one_hot
    ]
    edge_metadata_functions: Optional[List[Union[Callable, str]]] = None
    graph_metadata_functions: Optional[List[Callable]] = None

    def __eq__(self, other: Any) -> bool:
        """Overwrites the BaseModel ``__eq__`` function in order to check more
        specific cases (like partial functions).

        Two :class:`MoleculeGraphConfig` instances are compared with
        :class:`deepdiff.DeepDiff`, using custom operators for
        ``functools.partial`` and ``pathlib.Path`` values, and are considered
        equal when no differences are reported. Any other object is compared
        against ``self.dict()``.

        :param other: Object to compare against.
        :type other: Any
        :return: ``True`` if the objects are considered equal.
        :rtype: bool
        """
        if isinstance(other, MoleculeGraphConfig):
            # ``functools.partial`` objects compare by identity under ``==``,
            # so equivalent configs built with separate partials would differ
            # without the custom operators below.
            diff = DeepDiff(
                self,
                other,
                custom_operators=[
                    PartialMatchOperator(types=[partial]),
                    PathMatchOperator(types=[Path]),
                ],
            )
            # An empty diff means no differences were found.
            return diff == {}
        return self.dict() == other