# Source code for hwm.compat.sklearn

# -*- coding: utf-8 -*-
#   License: BSD-3-Clause
#   Author: LKouadio <etanoyau@gmail.com>
#  Taken from the gofast package: https://github.com/earthai-tech/gofast

"""
Provides compatibility utilities for different versions of
scikit-learn (sklearn). It includes functions and feature flags that
ensure smooth operation across various sklearn versions, handling
breaking changes and deprecated features. The module includes 
resampling utilities, scorer functions, and compatibility checks.

Key functionalities include:
- Resampling with sklearn's `resample`
- Validation with `check_is_fitted`
- Scorer retrieval with `get_scorer`
- Feature and compatibility flags for sklearn versions

The module exposes version flags so that callers can branch on the
installed scikit-learn release (older than 0.22.0, 0.23.0, 0.24.0 or 1.3.0).

Attributes
----------
SKLEARN_VERSION : packaging.version.Version
    The installed scikit-learn version.
SKLEARN_LT_0_22 : bool
    True if the installed scikit-learn version is less than 0.22.0.
SKLEARN_LT_0_23 : bool
    True if the installed scikit-learn version is less than 0.23.0.
SKLEARN_LT_0_24 : bool
    True if the installed scikit-learn version is less than 0.24.0.
SKLEARN_LT_1_3 : bool
    True if the installed scikit-learn version is less than 1.3.0.

Functions
---------
resample
    Resample arrays or sparse matrices in a consistent way.
get_scorer
    Get a scorer from string.
check_is_fitted
    Perform is_fitted validation for sklearn models.
"""

from packaging.version import Version, parse
import sklearn
import inspect
from sklearn.utils._param_validation import validate_params as sklearn_validate_params
from sklearn.utils._param_validation import Interval as sklearn_Interval 
from sklearn.utils._param_validation import StrOptions, HasMethods 
from sklearn.utils import resample
from sklearn.utils.validation import check_is_fitted as sklearn_check_is_fitted
from sklearn.metrics import get_scorer

# Determine the installed scikit-learn version
SKLEARN_VERSION = parse(sklearn.__version__)

# Feature and compatibility flags: each is True when the installed
# scikit-learn release is strictly older than the named version.
SKLEARN_LT_0_22 = SKLEARN_VERSION < Version("0.22.0")
SKLEARN_LT_0_23 = SKLEARN_VERSION < Version("0.23.0")
SKLEARN_LT_0_24 = SKLEARN_VERSION < Version("0.24.0")
SKLEARN_LT_1_3 = SKLEARN_VERSION < Version("1.3.0")

__all__ = [
    "Interval", 
    "resample",
    "train_test_split",
    "get_scorer",
    "get_feature_names",
    "get_feature_names_out", 
    "get_transformers_from_column_transformer",
    "check_is_fitted",
    "adjusted_mutual_info_score", 
    "get_sgd_loss_param", 
    "validate_params", 
    "StrOptions", 
    "HasMethods", 
    "SKLEARN_LT_0_22", 
    "SKLEARN_LT_0_23", 
    "SKLEARN_LT_0_24",
    "SKLEARN_LT_1_3",
]


class Interval:
    """
    Compatibility wrapper around scikit-learn's `Interval` constraint that
    drops the `inclusive` keyword when the installed version does not
    accept it.

    Calling ``Interval(...)`` does not produce an instance of this class:
    ``__new__`` builds and returns a
    ``sklearn.utils._param_validation.Interval`` object directly.

    Parameters
    ----------
    *args : tuple
        Positional arguments forwarded to scikit-learn's `Interval`,
        typically the expected numeric type followed by the left and right
        bounds of the interval.

    inclusive : bool, optional
        Forwarded only if scikit-learn's ``Interval.__init__`` declares an
        ``inclusive`` parameter; otherwise it is silently discarded and has
        no effect. Use ``closed`` to control whether the bounds belong to
        the interval.

    closed : {"left", "right", "both", "neither"}
        Which bounds belong to the interval. Forwarded unchanged to
        scikit-learn. The wrapper supplies no default, so pass it explicitly
        when the installed `Interval` requires it.

    **kwargs : dict
        Any other keyword arguments, forwarded unchanged.

    Returns
    -------
    sklearn.utils._param_validation.Interval
        The constraint object created by scikit-learn.

    Raises
    ------
    Exception
        Whatever scikit-learn's `Interval` raises for invalid arguments is
        propagated unchanged; the wrapper itself raises nothing.

    Notes
    -----
    Whether ``inclusive`` is supported is determined once, by inspecting
    the signature of scikit-learn's ``Interval.__init__``, and cached on
    the class for subsequent calls.

    Examples
    --------
    >>> from numbers import Integral
    >>> from hwm.compat.sklearn import Interval
    >>> interval = Interval(Integral, 1, 10, closed="left")
    >>> interval.is_satisfied_by(5), interval.is_satisfied_by(10)
    (True, False)

    An ``inclusive`` argument is dropped when scikit-learn does not accept
    it:

    >>> interval = Interval(Integral, 1, 10, closed="left", inclusive=True)

    See Also
    --------
    sklearn.utils._param_validation.Interval : Original scikit-learn
        `Interval` class used for parameter validation.

    References
    ----------
    .. [1] Pedregosa, F. et al. (2011). "Scikit-learn: Machine Learning in 
       Python." *Journal of Machine Learning Research*, 12, 2825-2830.
    
    .. [2] Buitinck, L., Louppe, G., Blondel, M., et al. (2013). "API design 
       for machine learning software: experiences from the scikit-learn 
       project." *arXiv preprint arXiv:1309.0238*.
    """

    # Lazily computed flag: does sklearn's Interval.__init__ accept
    # ``inclusive``? ``None`` means "not inspected yet".
    _accepts_inclusive = None

    def __new__(cls, *args, **kwargs):
        """
        Create a scikit-learn `Interval`, dropping ``inclusive`` if the
        installed version does not support it.

        Parameters
        ----------
        *args : tuple
            Positional arguments for scikit-learn's `Interval`.
        **kwargs : dict
            Keyword arguments for scikit-learn's `Interval`; ``inclusive``
            is removed when unsupported.

        Returns
        -------
        sklearn.utils._param_validation.Interval
            A compatible `Interval` object.
        """
        if cls._accepts_inclusive is None:
            # The signature cannot change at runtime, so inspect it once.
            parameters = inspect.signature(sklearn_Interval.__init__).parameters
            cls._accepts_inclusive = 'inclusive' in parameters
        if not cls._accepts_inclusive:
            kwargs.pop('inclusive', None)
        return sklearn_Interval(*args, **kwargs)

def get_sgd_loss_param():
    """Get the correct argument of loss parameter for SGDClassifier 
    based on scikit-learn version.

    scikit-learn 1.1 renamed the logistic-regression loss of
    ``SGDClassifier`` from ``'log'`` to ``'log_loss'``. ``'log'`` was
    deprecated in 1.1 and removed in 1.3, and ``'log_loss'`` is not
    accepted before 1.1. This function returns ``'log_loss'`` on
    scikit-learn 1.1 and newer, which also avoids the deprecation warning
    on 1.1/1.2, and ``'log'`` on older releases.

    Returns
    -------
    str
        The appropriate loss parameter for the SGDClassifier.

    Examples
    --------
    >>> get_sgd_loss_param()  # with scikit-learn >= 1.1
    'log_loss'

    >>> # Example usage with SGDClassifier
    >>> from sklearn.linear_model import SGDClassifier
    >>> clf = SGDClassifier(loss=get_sgd_loss_param(), max_iter=1000)
    
    Notes
    -----
    This function is useful for maintaining compatibility 
    with different versions of scikit-learn, ensuring that 
    the model behaves as expected regardless of the library 
    version being used.
    """
    # 'log_loss' exists from 1.1 onwards; 'log' was removed in 1.3.
    return 'log' if SKLEARN_VERSION < Version("1.1.0") else 'log_loss'


def validate_params(params, *args, prefer_skip_nested_validation=True, **kwargs):
    """
    Compatibility wrapper for scikit-learn's `validate_params` decorator
    factory.

    Some scikit-learn releases expect a ``prefer_skip_nested_validation``
    keyword argument, while older ones do not accept it. This wrapper
    inspects the signature of the installed `validate_params` and forwards
    the keyword only when it is declared; otherwise it is silently dropped.

    Parameters
    ----------
    params : dict
        Mapping ``{parameter_name: list_of_constraints}`` describing the
        accepted values of each parameter of the decorated function.
        Constraints may be plain types (e.g. ``str``, ``int``) or
        scikit-learn constraint objects such as ``Interval``,
        ``StrOptions`` or ``HasMethods``.

        For example, if `params` is:

        .. code-block:: python

            params = {
                'step_name': [str],
                'n_trials': [int]
            }

        then `step_name` must be a `str` and `n_trials` must be an `int`.

    *args : tuple
        Additional positional arguments forwarded to scikit-learn's
        `validate_params`.

    prefer_skip_nested_validation : bool, default=True
        Forwarded to scikit-learn when its `validate_params` declares this
        parameter, and ignored otherwise. In scikit-learn, ``True`` skips the
        parameter validation of estimators and functions *called from
        inside* the decorated function. The decorated function's own
        arguments are still validated. ``False`` keeps that inner
        validation enabled.

    **kwargs : dict
        Additional keyword arguments forwarded to scikit-learn's
        `validate_params`. ``prefer_skip_nested_validation`` is captured by
        the dedicated parameter above and therefore never appears here.

    Returns
    -------
    callable
        The decorator produced by scikit-learn's `validate_params`. Apply it
        to a function to validate that function's arguments on every call.

    Notes
    -----
    When an argument of the decorated function violates its constraints,
    scikit-learn raises ``InvalidParameterError``, a subclass of both
    ``ValueError`` and ``TypeError``.

    Examples
    --------
    >>> from hwm.compat.sklearn import validate_params
    >>> @validate_params({
    ...     'step_name': [str],
    ...     'param_grid': [dict],
    ...     'n_trials': [int],
    ...     'eval_metric': [str]
    ... }, prefer_skip_nested_validation=False)
    ... def tune_hyperparameters(step_name, param_grid, n_trials, eval_metric):
    ...     print(f"Hyperparameters tuned for step: {step_name}")
    ... 
    >>> tune_hyperparameters(
    ...     step_name='TrainModel', 
    ...     param_grid={'learning_rate': [0.01, 0.1]}, 
    ...     n_trials=5, 
    ...     eval_metric='accuracy'
    ... )
    Hyperparameters tuned for step: TrainModel

    See Also
    --------
    sklearn.utils._param_validation.validate_params : Original scikit-learn
        decorator for parameter validation.

    References
    ----------
    .. [1] Pedregosa, F. et al. (2011). "Scikit-learn: Machine Learning in Python."
       *Journal of Machine Learning Research*, 12, 2825-2830.

    .. [2] Buitinck, L., Louppe, G., Blondel, M., et al. (2013). "API design for 
       machine learning software: experiences from the scikit-learn project."
       *arXiv preprint arXiv:1309.0238*.
    """
    # Older scikit-learn releases reject this keyword, and newer ones expect it.
    # Decide based on the installed signature.
    sig = inspect.signature(sklearn_validate_params)
    if 'prefer_skip_nested_validation' in sig.parameters:
        kwargs['prefer_skip_nested_validation'] = prefer_skip_nested_validation

    return sklearn_validate_params(params, *args, **kwargs)


def get_column_transformer_feature_names(column_transformer, input_features=None):
    """
    Get feature names from a ColumnTransformer.

    Parameters:
    - column_transformer : ColumnTransformer
        A fitted ColumnTransformer (any object exposing ``transformers_``).
    - input_features : list of str, optional
        Names of the columns seen by the ColumnTransformer. If omitted,
        positional indices ``0 .. n_features - 1`` are used. ``n_features``
        is read from ``n_features_in_``, falling back to the private
        ``_n_features`` attribute.

    Returns:
    - feature_names : list
        Output feature names of every non-dropped transformer, in the order
        the transformers appear in ``transformers_``.

    Notes:
    - A column selection may be a single name or index, a list or array of
      names or indices, a boolean mask, or a slice.
    - ``'drop'`` entries are skipped. ``'passthrough'`` entries contribute
      the names of their selected input columns.
    - Estimators exposing ``get_feature_names_out`` (or the legacy
      ``get_feature_names``) are asked for their names without arguments.
      Fitted sub-transformers are never modified.
    - Estimators exposing neither method get generic ``"<name>__<i>"``
      names, one per selected input column. This assumes the estimator
      keeps the number of columns unchanged.
    """
    if input_features is None:
        n_features = getattr(column_transformer, 'n_features_in_', None)
        if n_features is None:
            n_features = column_transformer._n_features
        input_features = list(range(n_features))
    else:
        input_features = list(input_features)

    def _selected_columns(column):
        # Resolve one column selection into the matching input feature names.
        if isinstance(column, slice):
            return input_features[column]
        if isinstance(column, str) or not hasattr(column, '__iter__'):
            column = [column]
        selection = list(column)
        dtype_kind = getattr(getattr(column, 'dtype', None), 'kind', None)
        is_mask = dtype_kind == 'b' or (
            bool(selection) and all(isinstance(c, bool) for c in selection))
        if is_mask:
            return [name for name, keep in zip(input_features, selection) if keep]
        # String entries already are column names; others are positions.
        return [c if isinstance(c, str) else input_features[c] for c in selection]

    output_features = []
    for transformer_name, transformer, column in column_transformer.transformers_:
        if isinstance(transformer, str):
            if transformer == 'drop':
                continue
            if transformer == 'passthrough':
                output_features.extend(_selected_columns(column))
                continue

        if hasattr(transformer, 'get_feature_names_out'):
            transformer_features = transformer.get_feature_names_out()
        elif hasattr(transformer, 'get_feature_names'):
            # Legacy scikit-learn API.
            transformer_features = transformer.get_feature_names()
        else:
            # No name API and no data to probe the output width: assume one
            # output column per selected input column.
            transformer_features = [
                f"{transformer_name}__{i}"
                for i in range(len(_selected_columns(column)))]

        output_features.extend(transformer_features)

    return output_features

def get_column_transformer_feature_names2(column_transformer, input_features=None):
    """
    Get feature names from a ColumnTransformer.

    Unlike `get_column_transformer_feature_names`, default positional input
    names are only derived when actually needed, i.e. for ``'passthrough'``
    entries and for estimators without a feature-name API.

    Parameters:
    - column_transformer : ColumnTransformer
        A fitted ColumnTransformer (any object exposing ``transformers_``).
    - input_features : list of str, optional
        Names of the columns seen by the ColumnTransformer. When omitted and
        names are needed, positional indices ``0 .. n_features - 1`` are
        used. ``n_features`` is read from ``n_features_in_``, falling back to
        ``_n_features``.

    Returns:
    - feature_names : list
        Output feature names of every non-dropped transformer, in order.

    Notes:
    - A column selection may be a single name or index, a list or array of
      names or indices, a boolean mask, or a slice.
    - Fitted sub-transformers are never modified.
    - Estimators without ``get_feature_names_out``/``get_feature_names``
      get generic ``"<name>__<i>"`` names, one per selected input column.
      This assumes the estimator keeps the number of columns unchanged.
    """
    def _input_names():
        # Caller-supplied names, or positional indices as a fallback.
        if input_features is not None:
            return list(input_features)
        n_features = getattr(column_transformer, 'n_features_in_', None)
        if n_features is None:
            n_features = column_transformer._n_features
        return list(range(n_features))

    def _selected_columns(column):
        # Resolve one column selection into the matching input feature names.
        names = _input_names()
        if isinstance(column, slice):
            return names[column]
        if isinstance(column, str) or not hasattr(column, '__iter__'):
            column = [column]
        selection = list(column)
        dtype_kind = getattr(getattr(column, 'dtype', None), 'kind', None)
        is_mask = dtype_kind == 'b' or (
            bool(selection) and all(isinstance(c, bool) for c in selection))
        if is_mask:
            return [name for name, keep in zip(names, selection) if keep]
        return [c if isinstance(c, str) else names[c] for c in selection]

    output_features = []
    for transformer_name, transformer, column in column_transformer.transformers_:
        if isinstance(transformer, str):
            if transformer == 'drop':
                continue
            if transformer == 'passthrough':
                output_features.extend(_selected_columns(column))
                continue

        if hasattr(transformer, 'get_feature_names_out'):
            transformer_features = transformer.get_feature_names_out()
        elif hasattr(transformer, 'get_feature_names'):
            # Legacy scikit-learn API.
            transformer_features = transformer.get_feature_names()
        else:
            # No name API and no data to probe the output width: assume one
            # output column per selected input column.
            transformer_features = [
                f"{transformer_name}__{i}"
                for i in range(len(_selected_columns(column)))]

        output_features.extend(transformer_features)

    return output_features

def get_feature_names(estimator, *args, **kwargs):
    """
    Return the feature names exposed by ``estimator``.

    The current ``get_feature_names_out`` method is preferred. The legacy
    ``get_feature_names`` method (older scikit-learn) is used only when the
    former is absent.

    Parameters:
    - estimator : estimator object
        Scikit-learn estimator from which to get feature names.
    - *args : Additional positional arguments forwarded to the chosen method.
    - **kwargs : Additional keyword arguments forwarded to the chosen method.

    Returns:
    - feature_names : list or array-like
        Whatever the chosen method returns.

    Raises:
    - AttributeError
        If the estimator provides neither method.
    """
    # Ordered by preference: current API first, then the legacy one.
    for method_name in ('get_feature_names_out', 'get_feature_names'):
        if hasattr(estimator, method_name):
            getter = getattr(estimator, method_name)
            return getter(*args, **kwargs)
    raise AttributeError(
        "The estimator does not have a method to get feature names.")

def get_feature_names_out(estimator, *args, **kwargs):
    """
    Alias of `get_feature_names` kept for API symmetry with scikit-learn's
    ``get_feature_names_out``.

    Parameters:
    - estimator : estimator object
        Scikit-learn estimator from which to get feature names.
    - *args : Additional positional arguments forwarded to the estimator's
      feature-name method.
    - **kwargs : Additional keyword arguments forwarded to the estimator's
      feature-name method.

    Returns:
    - feature_names_out : list or array-like
        The names returned by `get_feature_names` for this estimator.
    """
    # Single source of truth: the lookup order and error message live in
    # get_feature_names.
    names = get_feature_names(estimator, *args, **kwargs)
    return names

def get_transformers_from_column_transformer(ct):
    """
    Return the fitted ``transformers_`` list of a ColumnTransformer.

    Parameters:
    - ct : ColumnTransformer
        A fitted ColumnTransformer instance.

    Returns:
    - transformers : list of tuples
        The ``(name, transformer, column(s))`` tuples stored on ``ct``.

    Raises:
    - AttributeError
        If ``ct`` has no ``transformers_`` attribute, typically because it
        has not been fitted.
    """
    missing = object()
    transformers = getattr(ct, 'transformers_', missing)
    if transformers is missing:
        raise AttributeError(
            "The ColumnTransformer instance does not have a 'transformers_' attribute.")
    return transformers

def check_is_fitted(estimator, attributes=None, msg=None, all_or_any=all):
    """
    Compatibility wrapper for scikit-learn's ``check_is_fitted`` function.

    Parameters:
    - estimator : estimator instance
        The estimator to check.
    - attributes : str or list of str, optional
        The attributes to check for.
    - msg : str, optional
        The message to display on failure.
    - all_or_any : callable, optional
        ``all`` or ``any``; whether all or any of the given attributes must
        be present.

    Returns:
    - None
        scikit-learn raises ``NotFittedError`` when the check fails.
    """
    # ``msg`` and ``all_or_any`` are keyword-only in recent scikit-learn
    # releases; passing them positionally raises TypeError there.
    return sklearn_check_is_fitted(
        estimator, attributes, msg=msg, all_or_any=all_or_any)
def adjusted_mutual_info_score(
        labels_true, labels_pred, average_method='arithmetic'):
    """
    Compatibility function for adjusted_mutual_info_score with the
    average_method parameter.

    Parameters:
    - labels_true : array-like of shape (n_samples,)
        Ground truth class labels.
    - labels_pred : array-like of shape (n_samples,)
        Cluster labels to evaluate.
    - average_method : str, default='arithmetic'
        How the normalizer is computed. It is forwarded to scikit-learn
        whenever the installed version accepts it (0.20 and later). Older
        releases do not have this parameter, so it is ignored there and
        scikit-learn's built-in normalization is used, which may differ
        from 'arithmetic'.

    Returns:
    - ami : float
        Adjusted Mutual Information Score.
    """
    from sklearn.metrics import adjusted_mutual_info_score as ami_score

    # ``average_method`` exists since scikit-learn 0.20, but its default only
    # became 'arithmetic' in 0.22. Forwarding it on 0.20/0.21 honours the
    # requested normalization instead of the old default.
    if SKLEARN_VERSION < Version("0.20.0"):
        return ami_score(labels_true, labels_pred)
    return ami_score(labels_true, labels_pred, average_method=average_method)


def fetch_openml(*args, **kwargs):
    """
    Compatibility function for fetch_openml to ensure consistent return type.

    On scikit-learn 0.24 and later, ``as_frame=True`` is used unless the
    caller passes ``as_frame`` explicitly. On older releases all arguments
    are forwarded unchanged.

    Parameters:
    - args, kwargs: Arguments and keyword arguments for
      sklearn.datasets.fetch_openml.

    Returns:
    - data : Bunch
        Dictionary-like object with all the data and metadata.
    """
    from sklearn.datasets import fetch_openml as _fetch_openml

    if 'as_frame' not in kwargs and not SKLEARN_LT_0_24:
        kwargs['as_frame'] = True
    return _fetch_openml(*args, **kwargs)


def plot_confusion_matrix(estimator, X, y_true, *args, **kwargs):
    """
    Compatibility function for plot_confusion_matrix across scikit-learn
    versions.

    Uses ``ConfusionMatrixDisplay.from_estimator`` when available
    (scikit-learn >= 1.0). ``plot_confusion_matrix`` itself was removed in
    1.2. Otherwise falls back to the legacy ``plot_confusion_matrix``.

    Parameters:
    - estimator : estimator instance
        Fitted classifier.
    - X : array-like of shape (n_samples, n_features)
        Input values.
    - y_true : array-like of shape (n_samples,)
        True labels for X.
    - *args, **kwargs :
        Extra options forwarded to the scikit-learn plotting function
        (e.g. ``normalize``, ``display_labels``, ``cmap``, ``ax``).

    Returns:
    - display : ConfusionMatrixDisplay
        Object that stores the confusion matrix display.

    Raises:
    - NotImplementedError
        If neither plotting API is available in the installed scikit-learn.
    """
    try:
        from sklearn.metrics import ConfusionMatrixDisplay
    except ImportError:
        ConfusionMatrixDisplay = None

    if ConfusionMatrixDisplay is not None and hasattr(
            ConfusionMatrixDisplay, 'from_estimator'):
        return ConfusionMatrixDisplay.from_estimator(
            estimator, X, y_true, *args, **kwargs)

    try:
        from sklearn.metrics import plot_confusion_matrix as _plot_confusion_matrix
    except ImportError:
        # Very old scikit-learn without any confusion-matrix plotting helper.
        raise NotImplementedError(
            "plot_confusion_matrix not available in your sklearn version.")
    return _plot_confusion_matrix(estimator, X, y_true, *args, **kwargs)


def train_test_split(*args, **kwargs):
    """
    Compatibility wrapper for train_test_split to ensure consistent behavior.

    ``shuffle=True`` is passed explicitly unless the caller provides
    ``shuffle``.

    Parameters:
    - args, kwargs: Arguments and keyword arguments for
      sklearn.model_selection.train_test_split.

    Returns:
    - splitting : list
        The train/test splits of the inputs, as returned by scikit-learn.
    """
    from sklearn.model_selection import train_test_split as _train_test_split

    if 'shuffle' not in kwargs:
        kwargs['shuffle'] = True
    return _train_test_split(*args, **kwargs)


def get_transformer_feature_names(transformer, input_features=None):
    """
    Compatibility function to get feature names from transformers like
    OneHotEncoder in scikit-learn, taking into account changes in method
    names across versions.

    Parameters:
    - transformer : sklearn transformer instance
        The transformer instance from which to get feature names.
    - input_features : list of str, optional
        List of input feature names to the transformer.

    Returns:
    - feature_names : list of str
        List of feature names generated by the transformer.

    Raises:
    - AttributeError
        If the transformer exposes neither ``get_feature_names_out`` nor
        ``get_feature_names``.
    """
    if hasattr(transformer, 'get_feature_names_out'):
        # Preferred API in newer scikit-learn versions.
        return transformer.get_feature_names_out(input_features)
    elif hasattr(transformer, 'get_feature_names'):
        # Legacy API of older scikit-learn versions.
        if input_features is not None:
            return transformer.get_feature_names(input_features)
        else:
            return transformer.get_feature_names()
    else:
        raise AttributeError(
            f"{transformer.__class__.__name__} does not support feature name extraction.")


def get_pipeline_feature_names(pipeline, input_features=None):
    """
    Compatibility function to extract feature names from a pipeline by
    threading names through each step.

    Parameters:
    - pipeline : sklearn Pipeline instance
        The pipeline instance from which to extract feature names.
    - input_features : list of str, optional
        Input feature names of the pipeline. If omitted, the first
        name-aware step is called with ``None`` so it can use the names
        recorded during fit, or generate default ones.

    Returns:
    - feature_names : list of str
        Feature names produced by the last name-aware step. Returns an empty
        list when no step produces names and no input names were given.

    Notes:
    - Steps exposing neither ``get_feature_names_out`` nor
      ``get_feature_names`` (nor ``categories_``) are skipped and leave the
      current names unchanged.
    """
    import numpy as np

    # Passing ``None`` instead of an empty array: scikit-learn rejects
    # input_features whose length differs from the fitted feature count.
    current_features = (
        None if input_features is None else np.asarray(input_features))

    for _, transformer in pipeline.steps:
        if hasattr(transformer, 'get_feature_names_out'):
            current_features = transformer.get_feature_names_out(current_features)
        elif hasattr(transformer, 'get_feature_names'):
            # Legacy API taking the current feature names.
            current_features = transformer.get_feature_names(current_features)
        elif hasattr(transformer, 'categories_'):
            # Handle OneHotEncoder-like steps without a name API.
            current_features = np.concatenate(transformer.categories_)
        # Any other step keeps the current names.

    if current_features is None:
        return []
    # Ensure output is a list of strings.
    return [str(name) for name in current_features]


__all__.extend([
    "fetch_openml",
    "plot_confusion_matrix",
    "get_transformer_feature_names",
    "get_pipeline_feature_names"
])