Compatibility Utilities

This section documents the compatibility utilities in the hwm.compat module. These utilities keep code working across versions of scikit-learn (sklearn) by absorbing breaking changes and deprecated features. The module includes resampling utilities, scorer functions, and version checks.

Interval

class Interval(*args, inclusive=None, closed='left', **kwargs)

Compatibility wrapper for scikit-learn’s Interval class to handle versions that do not include the inclusive argument.

Parameters

*args : tuple

Positional arguments passed to the Interval class, typically the expected data types and the range boundaries for the validation interval.

inclusive : bool, optional

Specifies whether the interval includes its bounds. Only supported in scikit-learn versions that accept the inclusive parameter. If True, the interval includes the bounds. Default is None for older versions where this argument is not available.

closed : str, optional

Defines how the interval is closed. Can be “left”, “right”, “both”, or “neither”. This argument is accepted by both older and newer scikit-learn versions. Default is “left” (includes the left bound, but excludes the right bound).

**kwargs : dict

Additional keyword arguments passed to the Interval class for compatibility, including any additional arguments required by the current scikit-learn version.

Returns

Interval

A compatible Interval object based on the scikit-learn version, with or without the inclusive argument.

Raises

ValueError

If an unsupported version of scikit-learn is used or the parameters are not valid for the given version.

Notes

This class provides a compatibility layer for creating Interval objects in different versions of scikit-learn. The inclusive argument was introduced in newer versions, so this class removes it if not supported in older versions.

If you are using scikit-learn versions that support the inclusive argument (e.g., version 1.2 or later), it will be included in the call to Interval. Otherwise, the argument will be excluded.
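One way a wrapper like this can decide whether to forward inclusive is to inspect the signature of the class it wraps. The sketch below illustrates that idea only; it is not the hwm.compat implementation. LegacyInterval, ModernInterval, and make_interval are hypothetical stand-ins, used so the example runs without scikit-learn installed.

```python
import inspect


class LegacyInterval:
    """Hypothetical stand-in for an Interval class without ``inclusive``."""

    def __init__(self, type_, left, right, *, closed="left"):
        self.type_, self.left, self.right, self.closed = type_, left, right, closed


class ModernInterval(LegacyInterval):
    """Hypothetical stand-in for an Interval class that accepts ``inclusive``."""

    def __init__(self, type_, left, right, *, closed="left", inclusive=None):
        super().__init__(type_, left, right, closed=closed)
        self.inclusive = inclusive


def make_interval(target, *args, inclusive=None, closed="left", **kwargs):
    """Build ``target``, forwarding ``inclusive`` only if ``target`` accepts it."""
    accepted = inspect.signature(target).parameters
    if inclusive is not None and "inclusive" in accepted:
        kwargs["inclusive"] = inclusive
    return target(*args, closed=closed, **kwargs)


# The same call works against both stand-ins; the legacy one never sees ``inclusive``.
old = make_interval(LegacyInterval, int, 1, 10, inclusive=True)
new = make_interval(ModernInterval, int, 1, 10, inclusive=True)
print(hasattr(old, "inclusive"), new.inclusive)
```

Checking the signature rather than a version number keeps the wrapper working even when the exact release that introduced an argument is uncertain.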

Examples

In newer scikit-learn versions (e.g., >=1.2), you can include the inclusive parameter:

from numbers import Integral
from hwm.compat import Interval

# Create Interval with inclusive=True
interval = Interval(Integral, 1, 10, closed="left", inclusive=True)
print(interval)
# Output: Interval(Integral, 1, 10, closed='left')

In older versions of scikit-learn that don’t support inclusive, you can omit the argument; if it is passed anyway, it is removed automatically:

from numbers import Integral
from hwm.compat import Interval

# Create Interval without inclusive
interval = Interval(Integral, 1, 10, closed="left")
print(interval)
# Output: Interval(Integral, 1, 10, closed='left')

See Also

sklearn.utils._param_validation.Interval : Original scikit-learn Interval class used for parameter validation.

get_sgd_loss_param

get_sgd_loss_param()

Get the correct argument of loss parameter for SGDClassifier based on scikit-learn version.

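This function determines which loss value to pass to SGDClassifier depending on the installed version of scikit-learn. It returns ‘log_loss’ for versions 0.24 and newer and ‘log’ for older versions. Note that scikit-learn’s own release notes introduce ‘log_loss’ in version 1.1 (deprecating ‘log’), so on releases from 0.24 through 1.0 the returned value may not be accepted; check it against your installed version.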

Returns

str

The appropriate loss parameter for the SGDClassifier.

Examples

The following examples demonstrate how to use the get_sgd_loss_param function to obtain the correct loss parameter for SGDClassifier.

Basic Example:

from sklearn.datasets import make_classification
from sklearn.linear_model import SGDClassifier

from hwm.compat import get_sgd_loss_param

# Get the appropriate loss parameter
loss_param = get_sgd_loss_param()
print(loss_param)
# Output: log_loss  (if using scikit-learn 0.24 or newer)

# Small synthetic dataset so the example is self-contained
X_train, y_train = make_classification(n_samples=100, random_state=0)

# Example usage with SGDClassifier
clf = SGDClassifier(loss=loss_param, max_iter=1000)
clf.fit(X_train, y_train)

Notes

This function is useful for maintaining compatibility with different versions of scikit-learn, ensuring that the model behaves as expected regardless of the library version being used.
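The version switch described for this function can be sketched without importing scikit-learn. The helpers below are hypothetical (parse_release and pick_sgd_loss are not part of hwm.compat), and the 0.24 threshold is taken from the description above; real code would more likely compare sklearn.__version__ using a proper version parser.

```python
def parse_release(version):
    """Turn '1.3.2' or '0.24.dev0' into a tuple of leading integers, e.g. (1, 3, 2)."""
    parts = []
    for piece in version.split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:
            # Stop at the first non-numeric component such as 'dev0'.
            break
        parts.append(int(digits))
    return tuple(parts)


def pick_sgd_loss(version, threshold="0.24"):
    """Return 'log_loss' at or above ``threshold``, otherwise 'log'."""
    return "log_loss" if parse_release(version) >= parse_release(threshold) else "log"


for installed in ("0.23.2", "0.24.0", "1.3.2"):
    print(installed, "->", pick_sgd_loss(installed))
```

Keeping the threshold as a parameter makes it easy to adjust if the cut-off needs to track a different scikit-learn release.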

See Also

sklearn.linear_model.SGDClassifier : Linear classifier with SGD training.

validate_params

validate_params(params, *args, prefer_skip_nested_validation=True, **kwargs)

Compatibility wrapper for scikit-learn’s validate_params function to handle versions that require the prefer_skip_nested_validation argument, with a default value that can be overridden by the user.

Parameters

params : dict

A dictionary that defines the validation rules for the parameters. Each key in the dictionary should represent the name of a parameter that requires validation, and its associated value should be a list of expected types (e.g., [int, str]). The function will validate that the parameters passed to the decorated function match the specified types.

For example, if params is:

params = {
    'step_name': [str],
    'n_trials': [int]
}

Then, the step_name parameter must be of type str, and n_trials must be of type int.

prefer_skip_nested_validation : bool, optional

If True (the default), validation is skipped for validated functions and estimators that are called from inside the decorated function, so arguments already checked at the public entry point are not checked again on every inner call. This avoids redundant work; it does not change how the decorated function’s own arguments are validated.

Set to False to keep validation enabled for those inner calls.

*args : tuple

Additional positional arguments to pass to validate_params.

**kwargs : dict

Additional keyword arguments to pass to validate_params, such as any options required by the installed scikit-learn version. prefer_skip_nested_validation is a named parameter of this wrapper and is handled separately.

Returns

function

A decorator built from scikit-learn’s validate_params, with prefer_skip_nested_validation handled as the installed version requires. Apply it to a function to enforce the declared parameter types each time the function is called.

Notes

The validate_params function provides a robust way to enforce type and structure validation on function arguments, especially in the context of machine learning workflows. By ensuring that parameters adhere to a predefined structure, the function helps prevent runtime errors due to unexpected types or invalid argument configurations.

When prefer_skip_nested_validation is True, parameter validation is skipped for validated functions and estimators called within the decorated function; only the decorated function’s own arguments are checked. When it is False, those inner calls validate their parameters as well.

The validation process can be represented mathematically as:

\[\begin{split}V(p_i) = \begin{cases} 1, & \text{if} \, \text{type}(p_i) \in T(p_i) \\ 0, & \text{otherwise} \end{cases}\end{split}\]

where \(V(p_i)\) is the validation function for parameter \(p_i\), and \(T(p_i)\) represents the set of expected types for \(p_i\). The function returns 1 if the parameter matches the expected type, otherwise 0.
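The rule above can be illustrated with a small pure-Python decorator. This is a minimal sketch of the idea, not scikit-learn’s or hwm.compat’s implementation: check_types is a hypothetical name, and it uses isinstance rather than an exact type comparison, so subclasses of an expected type are accepted as well.

```python
import functools
import inspect


def check_types(rules):
    """Decorator: every argument named in ``rules`` must match one of its types."""

    def decorator(func):
        signature = inspect.signature(func)

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # Map positional and keyword arguments onto parameter names.
            bound = signature.bind(*args, **kwargs)
            bound.apply_defaults()
            for name, expected in rules.items():
                value = bound.arguments[name]
                # V(p_i) = 1 if the value matches an expected type, else 0.
                valid = 1 if isinstance(value, tuple(expected)) else 0
                if not valid:
                    names = ", ".join(t.__name__ for t in expected)
                    raise ValueError(
                        f"Parameter {name!r} must be one of: {names}; "
                        f"got {type(value).__name__}."
                    )
            return func(*args, **kwargs)

        return wrapper

    return decorator


@check_types({"step_name": [str], "n_trials": [int]})
def initialize_step(step_name, n_trials=3):
    return f"{step_name}: {n_trials} trials"


print(initialize_step("Init", n_trials=5))
```

Calling initialize_step("Init", n_trials="five") with this sketch raises ValueError, because V evaluates to 0 for n_trials.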

Examples

The following examples demonstrate how to use the validate_params function to enforce parameter validation in machine learning pipelines.

Basic Example:

Ensuring that parameters match expected types using the validate_params decorator.

from hwm.compat import validate_params

@validate_params({
    'step_name': [str],
    'param_grid': [dict],
    'n_trials': [int],
    'eval_metric': [str]
}, prefer_skip_nested_validation=False)
def tune_hyperparameters(step_name, param_grid, n_trials, eval_metric):
    print(f"Hyperparameters tuned for step: {step_name}")

# Correct usage
tune_hyperparameters(
    step_name='TrainModel',
    param_grid={'learning_rate': [0.01, 0.1]},
    n_trials=5,
    eval_metric='accuracy'
)
# Output: Hyperparameters tuned for step: TrainModel

Incorrect Usage:

Attempting to pass parameters with incorrect types will raise a validation error.

from hwm.compat import validate_params

@validate_params({
    'step_name': [str],
    'n_trials': [int]
})
def initialize_step(step_name, n_trials):
    pass

# Incorrect usage: n_trials should be int
initialize_step(step_name='Init', n_trials='five')
# Raises a validation error (scikit-learn's InvalidParameterError, a ValueError
# subclass) reporting that 'n_trials' must be an instance of int.

See Also

sklearn.utils._param_validation.validate_params : Original scikit-learn function for parameter validation. Refer to the scikit-learn documentation for more detailed information.
