"""Base class for ensemble-based estimators."""
|
||||
|
||||
# Authors: The scikit-learn developers
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
from abc import ABCMeta, abstractmethod
|
||||
|
||||
import numpy as np
|
||||
from joblib import effective_n_jobs
|
||||
|
||||
from sklearn.base import (
|
||||
BaseEstimator,
|
||||
MetaEstimatorMixin,
|
||||
clone,
|
||||
is_classifier,
|
||||
is_regressor,
|
||||
)
|
||||
from sklearn.utils import Bunch, check_random_state
|
||||
from sklearn.utils._tags import get_tags
|
||||
from sklearn.utils._user_interface import _print_elapsed_time
|
||||
from sklearn.utils.metadata_routing import _routing_enabled
|
||||
from sklearn.utils.metaestimators import _BaseComposition
|
||||
|
||||
|
||||
def _fit_single_estimator(
|
||||
estimator, X, y, fit_params, message_clsname=None, message=None
|
||||
):
|
||||
"""Private function used to fit an estimator within a job."""
|
||||
# TODO(SLEP6): remove if-condition for unrouted sample_weight when metadata
|
||||
# routing can't be disabled.
|
||||
if not _routing_enabled() and "sample_weight" in fit_params:
|
||||
try:
|
||||
with _print_elapsed_time(message_clsname, message):
|
||||
estimator.fit(X, y, sample_weight=fit_params["sample_weight"])
|
||||
except TypeError as exc:
|
||||
if "unexpected keyword argument 'sample_weight'" in str(exc):
|
||||
raise TypeError(
|
||||
"Underlying estimator {} does not support sample weights.".format(
|
||||
estimator.__class__.__name__
|
||||
)
|
||||
) from exc
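            # The TypeError is unrelated to sample_weight support; re-raise it
            # unchanged so the original traceback is preserved.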
            raise
    else:
        with _print_elapsed_time(message_clsname, message):
            estimator.fit(X, y, **fit_params)
    return estimator


def _set_random_states(estimator, random_state=None):
    """Set fixed random_state parameters for an estimator.

    Finds all parameters ending with ``random_state`` and sets them to
    integers derived from ``random_state``.

    Parameters
    ----------
    estimator : estimator supporting get/set_params
        Estimator with potential randomness managed by random_state
        parameters.

    random_state : int, RandomState instance or None, default=None
        Pseudo-random number generator to control the generation of the random
        integers. Pass an int for reproducible output across multiple function
        calls.
        See :term:`Glossary <random_state>`.

    Notes
    -----
    This does not necessarily set *all* ``random_state`` attributes that
    control an estimator's randomness, only those accessible through
    ``estimator.get_params()``. ``random_state``s not controlled include
    those belonging to:

    * cross-validation splitters
    * ``scipy.stats`` rvs
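
    Examples
    --------
    A minimal sketch on a single estimator: the ``random_state`` parameter,
    ``None`` by default, is replaced by an integer drawn from the given seed
    (the exact value depends on the NumPy RNG).

    >>> from sklearn.tree import DecisionTreeClassifier
    >>> tree = DecisionTreeClassifier()
    >>> print(tree.random_state)
    None
    >>> _set_random_states(tree, random_state=0)
    >>> tree.random_state is None
    False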
    """
    random_state = check_random_state(random_state)
    to_set = {}
    for key in sorted(estimator.get_params(deep=True)):
        if key == "random_state" or key.endswith("__random_state"):
            to_set[key] = random_state.randint(np.iinfo(np.int32).max)

    if to_set:
        estimator.set_params(**to_set)


class BaseEnsemble(MetaEstimatorMixin, BaseEstimator, metaclass=ABCMeta):
    """Base class for all ensemble classes.

    Warning: This class should not be used directly. Use derived classes
    instead.

    Parameters
    ----------
    estimator : object
        The base estimator from which the ensemble is built.

    n_estimators : int, default=10
        The number of estimators in the ensemble.

    estimator_params : list of str, default=tuple()
        The list of attributes to use as parameters when instantiating a
        new base estimator. If none are given, default parameters are used.

    Attributes
    ----------
    estimator_ : estimator
        The base estimator from which the ensemble is grown.

    estimators_ : list of estimators
        The collection of fitted base estimators.
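
    Examples
    --------
    A short sketch using :class:`~sklearn.ensemble.BaggingClassifier`, one of
    the concrete subclasses: once fitted, the ensemble behaves like a sequence
    of its sub-estimators through ``__len__``, ``__getitem__`` and
    ``__iter__``.

    >>> from sklearn.datasets import make_classification
    >>> from sklearn.ensemble import BaggingClassifier
    >>> X, y = make_classification(random_state=0)
    >>> bagging = BaggingClassifier(n_estimators=3, random_state=0).fit(X, y)
    >>> len(bagging)
    3
    >>> type(bagging[0]).__name__
    'DecisionTreeClassifier'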
    """

    @abstractmethod
    def __init__(
        self,
        estimator=None,
        *,
        n_estimators=10,
        estimator_params=tuple(),
    ):
        # Set parameters
        self.estimator = estimator
        self.n_estimators = n_estimators
        self.estimator_params = estimator_params

        # Don't instantiate estimators now! Parameters of estimator might
        # still change. E.g., when grid-searching with the nested object syntax.
        # self.estimators_ needs to be filled by the derived classes in fit.

    def _validate_estimator(self, default=None):
        """Check the base estimator.

        Sets the `estimator_` attribute.
        """
        if self.estimator is not None:
            self.estimator_ = self.estimator
        else:
            self.estimator_ = default

    def _make_estimator(self, append=True, random_state=None):
        """Make and configure a copy of the `estimator_` attribute.

        Warning: This method should be used to properly instantiate new
        sub-estimators.
        """
        estimator = clone(self.estimator_)
        estimator.set_params(**{p: getattr(self, p) for p in self.estimator_params})

        if random_state is not None:
            _set_random_states(estimator, random_state)

        if append:
            self.estimators_.append(estimator)

        return estimator

    def __len__(self):
        """Return the number of estimators in the ensemble."""
        return len(self.estimators_)

    def __getitem__(self, index):
        """Return the index'th estimator in the ensemble."""
        return self.estimators_[index]

    def __iter__(self):
        """Return iterator over estimators in the ensemble."""
        return iter(self.estimators_)


def _partition_estimators(n_estimators, n_jobs):
    """Private function used to partition estimators between jobs."""
    # Compute the number of jobs
    n_jobs = min(effective_n_jobs(n_jobs), n_estimators)

    # Partition estimators between jobs
    n_estimators_per_job = np.full(n_jobs, n_estimators // n_jobs, dtype=int)
    n_estimators_per_job[: n_estimators % n_jobs] += 1
    starts = np.cumsum(n_estimators_per_job)
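    # Worked example (illustrative, assuming ``effective_n_jobs(3) == 3``):
    # with n_estimators=10 and n_jobs=3 this returns
    # (3, [4, 3, 3], [0, 4, 7, 10]), so job ``i`` fits the slice
    # ``starts[i]:starts[i + 1]`` of the ensemble.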

    return n_jobs, n_estimators_per_job.tolist(), [0] + starts.tolist()


class _BaseHeterogeneousEnsemble(
    MetaEstimatorMixin, _BaseComposition, metaclass=ABCMeta
):
    """Base class for heterogeneous ensemble of learners.

    Parameters
    ----------
    estimators : list of (str, estimator) tuples
        The estimators to combine in the ensemble. Each element of the list is
        a tuple of a string (the name of the estimator) and an estimator
        instance. An estimator can be set to `'drop'` using `set_params`.

    Attributes
    ----------
    estimators_ : list of estimators
        The elements of the `estimators` parameter, having been fitted on the
        training data. If an estimator has been set to `'drop'`, it will not
        appear in `estimators_`.
    """

    @property
    def named_estimators(self):
        """Dictionary to access any fitted sub-estimators by name.

        Returns
        -------
        :class:`~sklearn.utils.Bunch`
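
        Examples
        --------
        A small sketch with :class:`~sklearn.ensemble.VotingClassifier`, one
        of the concrete subclasses:

        >>> from sklearn.ensemble import VotingClassifier
        >>> from sklearn.linear_model import LogisticRegression
        >>> from sklearn.tree import DecisionTreeClassifier
        >>> clf = VotingClassifier(
        ...     estimators=[
        ...         ("lr", LogisticRegression()),
        ...         ("dt", DecisionTreeClassifier()),
        ...     ]
        ... )
        >>> sorted(clf.named_estimators)
        ['dt', 'lr']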
        """
        return Bunch(**dict(self.estimators))

    @abstractmethod
    def __init__(self, estimators):
        self.estimators = estimators

    def _validate_estimators(self):
        if len(self.estimators) == 0 or not all(
            isinstance(item, (tuple, list)) and isinstance(item[0], str)
            for item in self.estimators
        ):
            raise ValueError(
                "Invalid 'estimators' attribute, 'estimators' should be a "
                "non-empty list of (string, estimator) tuples."
            )
        names, estimators = zip(*self.estimators)
        # defined by MetaEstimatorMixin
        self._validate_names(names)

        has_estimator = any(est != "drop" for est in estimators)
        if not has_estimator:
            raise ValueError(
                "All estimators are dropped. At least one is required "
                "to be an estimator."
            )

        is_estimator_type = is_classifier if is_classifier(self) else is_regressor

        for est in estimators:
            if est != "drop" and not is_estimator_type(est):
                raise ValueError(
                    "The estimator {} should be a {}.".format(
                        est.__class__.__name__, is_estimator_type.__name__[3:]
                    )
                )

        return names, estimators

    def set_params(self, **params):
        """
        Set the parameters of an estimator from the ensemble.

        Valid parameter keys can be listed with `get_params()`. Note that you
        can directly set the parameters of the estimators contained in
        `estimators`.

        Parameters
        ----------
        **params : keyword arguments
            Specific parameters using e.g.
            `set_params(parameter_name=new_value)`. In addition to setting the
            parameters of the ensemble, the individual estimators listed in
            `estimators` can also be replaced, or removed by setting them to
            'drop'.

        Returns
        -------
        self : object
            Estimator instance.
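
        Examples
        --------
        A short sketch with :class:`~sklearn.ensemble.VotingRegressor`, one of
        the concrete subclasses: nested parameters use the
        ``<name>__<parameter>`` syntax, and a named estimator can be removed
        by setting it to `'drop'`.

        >>> from sklearn.ensemble import VotingRegressor
        >>> from sklearn.linear_model import LinearRegression, Ridge
        >>> vr = VotingRegressor(
        ...     estimators=[("lr", LinearRegression()), ("ridge", Ridge())]
        ... )
        >>> vr = vr.set_params(lr="drop", ridge__alpha=10.0)
        >>> vr.get_params()["ridge__alpha"]
        10.0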
        """
        super()._set_params("estimators", **params)
        return self

    def get_params(self, deep=True):
        """
        Get the parameters of an estimator from the ensemble.

        Returns the parameters given in the constructor as well as the
        estimators contained within the `estimators` parameter.

        Parameters
        ----------
        deep : bool, default=True
            If True, the individual estimators in `estimators` and their
            parameters are returned as well.

        Returns
        -------
        params : dict
            Parameter and estimator names mapped to their values or parameter
            names mapped to their values.
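
        Examples
        --------
        A short sketch with :class:`~sklearn.ensemble.VotingClassifier`, one
        of the concrete subclasses: nested keys such as ``lr__C`` only appear
        when `deep=True`.

        >>> from sklearn.ensemble import VotingClassifier
        >>> from sklearn.linear_model import LogisticRegression
        >>> clf = VotingClassifier(estimators=[("lr", LogisticRegression())])
        >>> "lr__C" in clf.get_params(deep=True)
        True
        >>> "lr__C" in clf.get_params(deep=False)
        False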
        """
        return super()._get_params("estimators", deep=deep)

    def __sklearn_tags__(self):
        tags = super().__sklearn_tags__()
        try:
            tags.input_tags.allow_nan = all(
                get_tags(est[1]).input_tags.allow_nan if est[1] != "drop" else True
                for est in self.estimators
            )
            tags.input_tags.sparse = all(
                get_tags(est[1]).input_tags.sparse if est[1] != "drop" else True
                for est in self.estimators
            )
        except Exception:
            # If `estimators` does not comply with our API (list of tuples)
            # then it will fail. In this case, we assume that `allow_nan` and
            # `sparse` are False but the parameter validation will raise an
            # error during `fit`.
            pass  # pragma: no cover
        return tags