Videre
@@ -0,0 +1,635 @@
"""
Testing for export functions of decision trees (sklearn.tree.export).
"""

from io import StringIO
from re import finditer, search
from textwrap import dedent

import numpy as np
import pytest
from numpy.random import RandomState

from sklearn.base import is_classifier
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.exceptions import NotFittedError
from sklearn.tree import (
    DecisionTreeClassifier,
    DecisionTreeRegressor,
    export_graphviz,
    export_text,
    plot_tree,
)

# toy sample
X = [[-2, -1], [-1, -1], [-1, -2], [1, 1], [1, 2], [2, 1]]
y = [-1, -1, -1, 1, 1, 1]
y2 = [[-1, 1], [-1, 1], [-1, 1], [1, 2], [1, 2], [1, 3]]
w = [1, 1, 1, 0.5, 0.5, 0.5]
y_degraded = [1, 1, 1, 1, 1, 1]

def test_graphviz_toy():
    # Check correctness of export_graphviz
    clf = DecisionTreeClassifier(
        max_depth=3, min_samples_split=2, criterion="gini", random_state=2
    )
    clf.fit(X, y)

    # Test export code
    contents1 = export_graphviz(clf, out_file=None)
    contents2 = (
        "digraph Tree {\n"
        'node [shape=box, fontname="helvetica"] ;\n'
        'edge [fontname="helvetica"] ;\n'
        '0 [label="x[0] <= 0.0\\ngini = 0.5\\nsamples = 6\\n'
        'value = [3, 3]"] ;\n'
        '1 [label="gini = 0.0\\nsamples = 3\\nvalue = [3, 0]"] ;\n'
        "0 -> 1 [labeldistance=2.5, labelangle=45, "
        'headlabel="True"] ;\n'
        '2 [label="gini = 0.0\\nsamples = 3\\nvalue = [0, 3]"] ;\n'
        "0 -> 2 [labeldistance=2.5, labelangle=-45, "
        'headlabel="False"] ;\n'
        "}"
    )
    assert contents1 == contents2

    # Test with feature_names
    contents1 = export_graphviz(
        clf, feature_names=["feature0", "feature1"], out_file=None
    )
    contents2 = (
        "digraph Tree {\n"
        'node [shape=box, fontname="helvetica"] ;\n'
        'edge [fontname="helvetica"] ;\n'
        '0 [label="feature0 <= 0.0\\ngini = 0.5\\nsamples = 6\\n'
        'value = [3, 3]"] ;\n'
        '1 [label="gini = 0.0\\nsamples = 3\\nvalue = [3, 0]"] ;\n'
        "0 -> 1 [labeldistance=2.5, labelangle=45, "
        'headlabel="True"] ;\n'
        '2 [label="gini = 0.0\\nsamples = 3\\nvalue = [0, 3]"] ;\n'
        "0 -> 2 [labeldistance=2.5, labelangle=-45, "
        'headlabel="False"] ;\n'
        "}"
    )

    assert contents1 == contents2

    # Test with feature_names (escaped)
    contents1 = export_graphviz(
        clf, feature_names=['feature"0"', 'feature"1"'], out_file=None
    )
    contents2 = (
        "digraph Tree {\n"
        'node [shape=box, fontname="helvetica"] ;\n'
        'edge [fontname="helvetica"] ;\n'
        '0 [label="feature\\"0\\" <= 0.0\\n'
        "gini = 0.5\\nsamples = 6\\n"
        'value = [3, 3]"] ;\n'
        '1 [label="gini = 0.0\\nsamples = 3\\nvalue = [3, 0]"] ;\n'
        "0 -> 1 [labeldistance=2.5, labelangle=45, "
        'headlabel="True"] ;\n'
        '2 [label="gini = 0.0\\nsamples = 3\\nvalue = [0, 3]"] ;\n'
        "0 -> 2 [labeldistance=2.5, labelangle=-45, "
        'headlabel="False"] ;\n'
        "}"
    )

    assert contents1 == contents2

    # Test with class_names
    contents1 = export_graphviz(clf, class_names=["yes", "no"], out_file=None)
    contents2 = (
        "digraph Tree {\n"
        'node [shape=box, fontname="helvetica"] ;\n'
        'edge [fontname="helvetica"] ;\n'
        '0 [label="x[0] <= 0.0\\ngini = 0.5\\nsamples = 6\\n'
        'value = [3, 3]\\nclass = yes"] ;\n'
        '1 [label="gini = 0.0\\nsamples = 3\\nvalue = [3, 0]\\n'
        'class = yes"] ;\n'
        "0 -> 1 [labeldistance=2.5, labelangle=45, "
        'headlabel="True"] ;\n'
        '2 [label="gini = 0.0\\nsamples = 3\\nvalue = [0, 3]\\n'
        'class = no"] ;\n'
        "0 -> 2 [labeldistance=2.5, labelangle=-45, "
        'headlabel="False"] ;\n'
        "}"
    )

    assert contents1 == contents2

    # Test with class_names (escaped)
    contents1 = export_graphviz(clf, class_names=['"yes"', '"no"'], out_file=None)
    contents2 = (
        "digraph Tree {\n"
        'node [shape=box, fontname="helvetica"] ;\n'
        'edge [fontname="helvetica"] ;\n'
        '0 [label="x[0] <= 0.0\\ngini = 0.5\\nsamples = 6\\n'
        'value = [3, 3]\\nclass = \\"yes\\""] ;\n'
        '1 [label="gini = 0.0\\nsamples = 3\\nvalue = [3, 0]\\n'
        'class = \\"yes\\""] ;\n'
        "0 -> 1 [labeldistance=2.5, labelangle=45, "
        'headlabel="True"] ;\n'
        '2 [label="gini = 0.0\\nsamples = 3\\nvalue = [0, 3]\\n'
        'class = \\"no\\""] ;\n'
        "0 -> 2 [labeldistance=2.5, labelangle=-45, "
        'headlabel="False"] ;\n'
        "}"
    )

    assert contents1 == contents2

    # Test plot_options
    contents1 = export_graphviz(
        clf,
        filled=True,
        impurity=False,
        proportion=True,
        special_characters=True,
        rounded=True,
        out_file=None,
        fontname="sans",
    )
    contents2 = (
        "digraph Tree {\n"
        'node [shape=box, style="filled, rounded", color="black", '
        'fontname="sans"] ;\n'
        'edge [fontname="sans"] ;\n'
        "0 [label=<x<SUB>0</SUB> &le; 0.0<br/>samples = 100.0%<br/>"
        'value = [0.5, 0.5]>, fillcolor="#ffffff"] ;\n'
        "1 [label=<samples = 50.0%<br/>value = [1.0, 0.0]>, "
        'fillcolor="#e58139"] ;\n'
        "0 -> 1 [labeldistance=2.5, labelangle=45, "
        'headlabel="True"] ;\n'
        "2 [label=<samples = 50.0%<br/>value = [0.0, 1.0]>, "
        'fillcolor="#399de5"] ;\n'
        "0 -> 2 [labeldistance=2.5, labelangle=-45, "
        'headlabel="False"] ;\n'
        "}"
    )

    assert contents1 == contents2

    # Test max_depth
    contents1 = export_graphviz(clf, max_depth=0, class_names=True, out_file=None)
    contents2 = (
        "digraph Tree {\n"
        'node [shape=box, fontname="helvetica"] ;\n'
        'edge [fontname="helvetica"] ;\n'
        '0 [label="x[0] <= 0.0\\ngini = 0.5\\nsamples = 6\\n'
        'value = [3, 3]\\nclass = y[0]"] ;\n'
        '1 [label="(...)"] ;\n'
        "0 -> 1 ;\n"
        '2 [label="(...)"] ;\n'
        "0 -> 2 ;\n"
        "}"
    )

    assert contents1 == contents2

    # Test max_depth with plot_options
    contents1 = export_graphviz(
        clf, max_depth=0, filled=True, out_file=None, node_ids=True
    )
    contents2 = (
        "digraph Tree {\n"
        'node [shape=box, style="filled", color="black", '
        'fontname="helvetica"] ;\n'
        'edge [fontname="helvetica"] ;\n'
        '0 [label="node #0\\nx[0] <= 0.0\\ngini = 0.5\\n'
        'samples = 6\\nvalue = [3, 3]", fillcolor="#ffffff"] ;\n'
        '1 [label="(...)", fillcolor="#C0C0C0"] ;\n'
        "0 -> 1 ;\n"
        '2 [label="(...)", fillcolor="#C0C0C0"] ;\n'
        "0 -> 2 ;\n"
        "}"
    )

    assert contents1 == contents2

    # Test multi-output with weighted samples
    clf = DecisionTreeClassifier(
        max_depth=2, min_samples_split=2, criterion="gini", random_state=2
    )
    clf = clf.fit(X, y2, sample_weight=w)

    contents1 = export_graphviz(clf, filled=True, impurity=False, out_file=None)
    contents2 = (
        "digraph Tree {\n"
        'node [shape=box, style="filled", color="black", '
        'fontname="helvetica"] ;\n'
        'edge [fontname="helvetica"] ;\n'
        '0 [label="x[0] <= 0.0\\nsamples = 6\\n'
        "value = [[3.0, 1.5, 0.0]\\n"
        '[3.0, 1.0, 0.5]]", fillcolor="#ffffff"] ;\n'
        '1 [label="samples = 3\\nvalue = [[3, 0, 0]\\n'
        '[3, 0, 0]]", fillcolor="#e58139"] ;\n'
        "0 -> 1 [labeldistance=2.5, labelangle=45, "
        'headlabel="True"] ;\n'
        '2 [label="x[0] <= 1.5\\nsamples = 3\\n'
        "value = [[0.0, 1.5, 0.0]\\n"
        '[0.0, 1.0, 0.5]]", fillcolor="#f1bd97"] ;\n'
        "0 -> 2 [labeldistance=2.5, labelangle=-45, "
        'headlabel="False"] ;\n'
        '3 [label="samples = 2\\nvalue = [[0, 1, 0]\\n'
        '[0, 1, 0]]", fillcolor="#e58139"] ;\n'
        "2 -> 3 ;\n"
        '4 [label="samples = 1\\nvalue = [[0.0, 0.5, 0.0]\\n'
        '[0.0, 0.0, 0.5]]", fillcolor="#e58139"] ;\n'
        "2 -> 4 ;\n"
        "}"
    )

    assert contents1 == contents2

    # Test regression output with plot_options
    clf = DecisionTreeRegressor(
        max_depth=3, min_samples_split=2, criterion="squared_error", random_state=2
    )
    clf.fit(X, y)

    contents1 = export_graphviz(
        clf,
        filled=True,
        leaves_parallel=True,
        out_file=None,
        rotate=True,
        rounded=True,
        fontname="sans",
    )
    contents2 = (
        "digraph Tree {\n"
        'node [shape=box, style="filled, rounded", color="black", '
        'fontname="sans"] ;\n'
        "graph [ranksep=equally, splines=polyline] ;\n"
        'edge [fontname="sans"] ;\n'
        "rankdir=LR ;\n"
        '0 [label="x[0] <= 0.0\\nsquared_error = 1.0\\nsamples = 6\\n'
        'value = 0.0", fillcolor="#f2c09c"] ;\n'
        '1 [label="squared_error = 0.0\\nsamples = 3\\nvalue = -1.0", '
        'fillcolor="#ffffff"] ;\n'
        "0 -> 1 [labeldistance=2.5, labelangle=-45, "
        'headlabel="True"] ;\n'
        '2 [label="squared_error = 0.0\\nsamples = 3\\nvalue = 1.0", '
        'fillcolor="#e58139"] ;\n'
        "0 -> 2 [labeldistance=2.5, labelangle=45, "
        'headlabel="False"] ;\n'
        "{rank=same ; 0} ;\n"
        "{rank=same ; 1; 2} ;\n"
        "}"
    )

    assert contents1 == contents2

    # Test classifier with degraded learning set
    clf = DecisionTreeClassifier(max_depth=3)
    clf.fit(X, y_degraded)

    contents1 = export_graphviz(clf, filled=True, out_file=None)
    contents2 = (
        "digraph Tree {\n"
        'node [shape=box, style="filled", color="black", '
        'fontname="helvetica"] ;\n'
        'edge [fontname="helvetica"] ;\n'
        '0 [label="gini = 0.0\\nsamples = 6\\nvalue = 6.0", '
        'fillcolor="#ffffff"] ;\n'
        "}"
    )

    assert contents1 == contents2
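

# Illustrative sketch, not part of the test suite: the DOT strings asserted
# above are typically rendered with the third-party `graphviz` package. This
# assumes both the Python package and the Graphviz binaries are installed;
# the helper name is hypothetical.
def _render_toy_tree_sketch():
    import graphviz

    clf = DecisionTreeClassifier(max_depth=3, random_state=2).fit(X, y)
    dot_data = export_graphviz(clf, out_file=None, filled=True, rounded=True)
    graph = graphviz.Source(dot_data)  # parse the DOT source
    graph.render("toy_tree", format="png", cleanup=True)  # writes toy_tree.png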


@pytest.mark.parametrize("constructor", [list, np.array])
def test_graphviz_feature_class_names_array_support(constructor):
    # Check that export_graphviz treats feature names
    # and class names correctly and supports arrays
    clf = DecisionTreeClassifier(
        max_depth=3, min_samples_split=2, criterion="gini", random_state=2
    )
    clf.fit(X, y)

    # Test with feature_names
    contents1 = export_graphviz(
        clf, feature_names=constructor(["feature0", "feature1"]), out_file=None
    )
    contents2 = (
        "digraph Tree {\n"
        'node [shape=box, fontname="helvetica"] ;\n'
        'edge [fontname="helvetica"] ;\n'
        '0 [label="feature0 <= 0.0\\ngini = 0.5\\nsamples = 6\\n'
        'value = [3, 3]"] ;\n'
        '1 [label="gini = 0.0\\nsamples = 3\\nvalue = [3, 0]"] ;\n'
        "0 -> 1 [labeldistance=2.5, labelangle=45, "
        'headlabel="True"] ;\n'
        '2 [label="gini = 0.0\\nsamples = 3\\nvalue = [0, 3]"] ;\n'
        "0 -> 2 [labeldistance=2.5, labelangle=-45, "
        'headlabel="False"] ;\n'
        "}"
    )

    assert contents1 == contents2

    # Test with class_names
    contents1 = export_graphviz(
        clf, class_names=constructor(["yes", "no"]), out_file=None
    )
    contents2 = (
        "digraph Tree {\n"
        'node [shape=box, fontname="helvetica"] ;\n'
        'edge [fontname="helvetica"] ;\n'
        '0 [label="x[0] <= 0.0\\ngini = 0.5\\nsamples = 6\\n'
        'value = [3, 3]\\nclass = yes"] ;\n'
        '1 [label="gini = 0.0\\nsamples = 3\\nvalue = [3, 0]\\n'
        'class = yes"] ;\n'
        "0 -> 1 [labeldistance=2.5, labelangle=45, "
        'headlabel="True"] ;\n'
        '2 [label="gini = 0.0\\nsamples = 3\\nvalue = [0, 3]\\n'
        'class = no"] ;\n'
        "0 -> 2 [labeldistance=2.5, labelangle=-45, "
        'headlabel="False"] ;\n'
        "}"
    )

    assert contents1 == contents2


def test_graphviz_errors():
    # Check for errors of export_graphviz
    clf = DecisionTreeClassifier(max_depth=3, min_samples_split=2)

    # Check not-fitted decision tree error
    out = StringIO()
    with pytest.raises(NotFittedError):
        export_graphviz(clf, out)

    clf.fit(X, y)

    # Check that it errors when the length of feature_names
    # does not match the number of features
    message = "Length of feature_names, 1 does not match number of features, 2"
    with pytest.raises(ValueError, match=message):
        export_graphviz(clf, None, feature_names=["a"])

    message = "Length of feature_names, 3 does not match number of features, 2"
    with pytest.raises(ValueError, match=message):
        export_graphviz(clf, None, feature_names=["a", "b", "c"])

    # Check error when feature_names contains non-string elements
    message = "All feature names must be strings."
    with pytest.raises(ValueError, match=message):
        export_graphviz(clf, None, feature_names=["a", 1])

    # Check error when argument is not an estimator
    message = "is not an estimator instance"
    with pytest.raises(TypeError, match=message):
        export_graphviz(clf.fit(X, y).tree_)

    # Check class_names error
    out = StringIO()
    with pytest.raises(IndexError):
        export_graphviz(clf, out, class_names=[])


def test_friedman_mse_in_graphviz():
    clf = DecisionTreeRegressor(criterion="friedman_mse", random_state=0)
    clf.fit(X, y)
    dot_data = StringIO()
    export_graphviz(clf, out_file=dot_data)

    clf = GradientBoostingClassifier(n_estimators=2, random_state=0)
    clf.fit(X, y)
    for estimator in clf.estimators_:
        export_graphviz(estimator[0], out_file=dot_data)

    for finding in finditer(r"\[.*?samples.*?\]", dot_data.getvalue()):
        assert "friedman_mse" in finding.group()


def test_precision():
    rng_reg = RandomState(2)
    rng_clf = RandomState(8)
    for X, y, clf in zip(
        (rng_reg.random_sample((5, 2)), rng_clf.random_sample((1000, 4))),
        (rng_reg.random_sample((5,)), rng_clf.randint(2, size=(1000,))),
        (
            DecisionTreeRegressor(
                criterion="friedman_mse", random_state=0, max_depth=1
            ),
            DecisionTreeClassifier(max_depth=1, random_state=0),
        ),
    ):
        clf.fit(X, y)
        for precision in (4, 3):
            dot_data = export_graphviz(
                clf, out_file=None, precision=precision, proportion=True
            )

            # With the current random state, the impurity and the threshold
            # are formatted with the precision requested in export_graphviz,
            # so we check them with strict equality. The reported node value
            # keeps at most 2 decimal digits, so it is only checked with a
            # less-or-equal comparison.

            # check value
            for finding in finditer(r"value = \d+\.\d+", dot_data):
                assert len(search(r"\.\d+", finding.group()).group()) <= precision + 1

            if is_classifier(clf):
                pattern = r"gini = \d+\.\d+"
            else:
                pattern = r"friedman_mse = \d+\.\d+"

            # check impurity
            for finding in finditer(pattern, dot_data):
                assert len(search(r"\.\d+", finding.group()).group()) == precision + 1
            # check threshold
            for finding in finditer(r"<= \d+\.\d+", dot_data):
                assert len(search(r"\.\d+", finding.group()).group()) == precision + 1


def test_export_text_errors():
    clf = DecisionTreeClassifier(max_depth=2, random_state=0)
    clf.fit(X, y)
    err_msg = "feature_names must contain 2 elements, got 1"
    with pytest.raises(ValueError, match=err_msg):
        export_text(clf, feature_names=["a"])
    err_msg = (
        "When `class_names` is an array, it should contain as"
        " many items as `decision_tree.classes_`. Got 1 while"
        " the tree was fitted with 2 classes."
    )
    with pytest.raises(ValueError, match=err_msg):
        export_text(clf, class_names=["a"])


def test_export_text():
    clf = DecisionTreeClassifier(max_depth=2, random_state=0)
    clf.fit(X, y)

    expected_report = dedent(
        """
        |--- feature_1 <= 0.00
        |   |--- class: -1
        |--- feature_1 > 0.00
        |   |--- class: 1
        """
    ).lstrip()

    assert export_text(clf) == expected_report
    # testing that leaves at level 1 are not truncated
    assert export_text(clf, max_depth=0) == expected_report
    # testing that a max_depth larger than the tree depth leaves the output unchanged
    assert export_text(clf, max_depth=10) == expected_report

    expected_report = dedent(
        """
        |--- feature_1 <= 0.00
        |   |--- weights: [3.00, 0.00] class: -1
        |--- feature_1 > 0.00
        |   |--- weights: [0.00, 3.00] class: 1
        """
    ).lstrip()
    assert export_text(clf, show_weights=True) == expected_report

    expected_report = dedent(
        """
        |- feature_1 <= 0.00
        | |- class: -1
        |- feature_1 > 0.00
        | |- class: 1
        """
    ).lstrip()
    assert export_text(clf, spacing=1) == expected_report

    X_l = [[-2, -1], [-1, -1], [-1, -2], [1, 1], [1, 2], [2, 1], [-1, 1]]
    y_l = [-1, -1, -1, 1, 1, 1, 2]
    clf = DecisionTreeClassifier(max_depth=4, random_state=0)
    clf.fit(X_l, y_l)
    expected_report = dedent(
        """
        |--- feature_1 <= 0.00
        |   |--- class: -1
        |--- feature_1 > 0.00
        |   |--- truncated branch of depth 2
        """
    ).lstrip()
    assert export_text(clf, max_depth=0) == expected_report

    X_mo = [[-2, -1], [-1, -1], [-1, -2], [1, 1], [1, 2], [2, 1]]
    y_mo = [[-1, -1], [-1, -1], [-1, -1], [1, 1], [1, 1], [1, 1]]

    reg = DecisionTreeRegressor(max_depth=2, random_state=0)
    reg.fit(X_mo, y_mo)

    expected_report = dedent(
        """
        |--- feature_1 <= 0.0
        |   |--- value: [-1.0, -1.0]
        |--- feature_1 > 0.0
        |   |--- value: [1.0, 1.0]
        """
    ).lstrip()
    assert export_text(reg, decimals=1) == expected_report
    assert export_text(reg, decimals=1, show_weights=True) == expected_report

    X_single = [[-2], [-1], [-1], [1], [1], [2]]
    reg = DecisionTreeRegressor(max_depth=2, random_state=0)
    reg.fit(X_single, y_mo)

    expected_report = dedent(
        """
        |--- first <= 0.0
        |   |--- value: [-1.0, -1.0]
        |--- first > 0.0
        |   |--- value: [1.0, 1.0]
        """
    ).lstrip()
    assert export_text(reg, decimals=1, feature_names=["first"]) == expected_report
    assert (
        export_text(reg, decimals=1, show_weights=True, feature_names=["first"])
        == expected_report
    )
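

# Illustrative sketch, not part of the test suite: printing the rule report
# that the expected strings above encode, using the module-level toy data.
# The helper name is hypothetical.
def _print_text_rules_sketch():
    clf = DecisionTreeClassifier(max_depth=2, random_state=0).fit(X, y)
    # Prints "|--- feature_1 <= 0.00" style rules, one line per node.
    print(export_text(clf, show_weights=True))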


@pytest.mark.parametrize("constructor", [list, np.array])
def test_export_text_feature_class_names_array_support(constructor):
    # Check that export_text treats feature names
    # and class names correctly and supports arrays
    clf = DecisionTreeClassifier(max_depth=2, random_state=0)
    clf.fit(X, y)

    expected_report = dedent(
        """
        |--- b <= 0.00
        |   |--- class: -1
        |--- b > 0.00
        |   |--- class: 1
        """
    ).lstrip()
    assert export_text(clf, feature_names=constructor(["a", "b"])) == expected_report

    expected_report = dedent(
        """
        |--- feature_1 <= 0.00
        |   |--- class: cat
        |--- feature_1 > 0.00
        |   |--- class: dog
        """
    ).lstrip()
    assert export_text(clf, class_names=constructor(["cat", "dog"])) == expected_report


def test_plot_tree_entropy(pyplot):
    # mostly a smoke test
    # Check correctness of plot_tree for criterion = entropy
    clf = DecisionTreeClassifier(
        max_depth=3, min_samples_split=2, criterion="entropy", random_state=2
    )
    clf.fit(X, y)

    # Test export code
    feature_names = ["first feat", "sepal_width"]
    nodes = plot_tree(clf, feature_names=feature_names)
    assert len(nodes) == 5
    assert (
        nodes[0].get_text()
        == "first feat <= 0.0\nentropy = 1.0\nsamples = 6\nvalue = [3, 3]"
    )
    assert nodes[1].get_text() == "entropy = 0.0\nsamples = 3\nvalue = [3, 0]"
    assert nodes[2].get_text() == "True "
    assert nodes[3].get_text() == "entropy = 0.0\nsamples = 3\nvalue = [0, 3]"
    assert nodes[4].get_text() == " False"


@pytest.mark.parametrize("fontsize", [None, 10, 20])
def test_plot_tree_gini(pyplot, fontsize):
    # mostly a smoke test
    # Check correctness of plot_tree for criterion = gini
    clf = DecisionTreeClassifier(
        max_depth=3,
        min_samples_split=2,
        criterion="gini",
        random_state=2,
    )
    clf.fit(X, y)

    # Test export code
    feature_names = ["first feat", "sepal_width"]
    nodes = plot_tree(clf, feature_names=feature_names, fontsize=fontsize)
    assert len(nodes) == 5
    if fontsize is not None:
        assert all(node.get_fontsize() == fontsize for node in nodes)
    assert (
        nodes[0].get_text()
        == "first feat <= 0.0\ngini = 0.5\nsamples = 6\nvalue = [3, 3]"
    )
    assert nodes[1].get_text() == "gini = 0.0\nsamples = 3\nvalue = [3, 0]"
    assert nodes[2].get_text() == "True "
    assert nodes[3].get_text() == "gini = 0.0\nsamples = 3\nvalue = [0, 3]"
    assert nodes[4].get_text() == " False"
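

# Illustrative sketch, not part of the test suite: the matplotlib rendering
# that the smoke tests above exercise. Assumes matplotlib is installed (the
# tests obtain it through the `pyplot` fixture); the helper name is
# hypothetical.
def _plot_toy_tree_sketch():
    import matplotlib.pyplot as plt

    clf = DecisionTreeClassifier(max_depth=3, random_state=2).fit(X, y)
    plot_tree(clf, feature_names=["first feat", "sepal_width"], filled=True)
    plt.show()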


def test_not_fitted_tree(pyplot):
    # Check that an unfitted tree raises the correct error
    clf = DecisionTreeRegressor()
    with pytest.raises(NotFittedError):
        plot_tree(clf)
@@ -0,0 +1,51 @@
import numpy as np

from sklearn.tree._utils import PytestWeightedFenwickTree


def test_cython_weighted_fenwick_tree(global_random_seed):
    """
    Test Cython's weighted Fenwick tree implementation.
    """
    rng = np.random.default_rng(global_random_seed)

    n = 100
    indices = rng.permutation(n)
    y = rng.normal(size=n)
    w = rng.integers(0, 4, size=n)
    y_included_so_far = np.zeros_like(y)
    w_included_so_far = np.zeros_like(w)

    tree = PytestWeightedFenwickTree(n)
    tree.py_reset(n)

    for i in range(n):
        idx = indices[i]
        tree.py_add(idx, y[idx], w[idx])
        y_included_so_far[idx] = y[idx]
        w_included_so_far[idx] = w[idx]

        target = rng.uniform(0, w_included_so_far.sum())
        t_idx_low, t_idx, cw, cwy = tree.py_search(target)

        # check the aggregates are consistent with the returned idx
        assert np.isclose(cw, np.sum(w_included_so_far[:t_idx]))
        assert np.isclose(
            cwy, np.sum(w_included_so_far[:t_idx] * y_included_so_far[:t_idx])
        )

        # check that the cumulative weight is strictly below or exactly at
        # the target, depending on whether t_idx_low and t_idx coincide
        if t_idx_low == t_idx:
            assert cw < target
        else:
            assert cw == target

        # check that if we add the next non-null weight, we are above the target:
        next_weights = w_included_so_far[t_idx:][w_included_so_far[t_idx:] > 0]
        if next_weights.size > 0:
            assert cw + next_weights[0] > target
        # and not below the target for `t_idx_low`:
        next_weights = w_included_so_far[t_idx_low:][w_included_so_far[t_idx_low:] > 0]
        if next_weights.size > 0:
            assert cw + next_weights[0] >= target
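

# Illustrative sketch, not part of the test suite: a pure-Python analogue of
# the weighted Fenwick (binary indexed) tree exercised above, with hypothetical
# names. `add` folds one weighted observation into O(log n) prefix aggregates;
# `prefix_sums(k)` returns (sum of w[:k], sum of w[:k] * y[:k]), the quantities
# the test compares against brute-force numpy sums.
class _PyWeightedFenwickTree:
    def __init__(self, n):
        self.n = n
        self._w = [0.0] * (n + 1)   # cumulative weights per Fenwick node
        self._wy = [0.0] * (n + 1)  # cumulative weighted targets per node

    def add(self, idx, y_value, weight):
        i = idx + 1  # Fenwick trees are 1-indexed internally
        while i <= self.n:
            self._w[i] += weight
            self._wy[i] += weight * y_value
            i += i & (-i)  # move to the next node whose range covers idx

    def prefix_sums(self, k):
        # Aggregates over the first k positions, i.e. indices [0, k).
        cw = cwy = 0.0
        i = k
        while i > 0:
            cw += self._w[i]
            cwy += self._wy[i]
            i -= i & (-i)  # strip the lowest set bit
        return cw, cwy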
@@ -0,0 +1,512 @@
import numpy as np
import pytest

from sklearn.datasets import make_classification, make_regression
from sklearn.ensemble import (
    ExtraTreesClassifier,
    ExtraTreesRegressor,
    RandomForestClassifier,
    RandomForestRegressor,
)
from sklearn.tree import (
    DecisionTreeClassifier,
    DecisionTreeRegressor,
    ExtraTreeClassifier,
    ExtraTreeRegressor,
)
from sklearn.utils._testing import assert_allclose
from sklearn.utils.fixes import CSC_CONTAINERS

TREE_CLASSIFIER_CLASSES = [DecisionTreeClassifier, ExtraTreeClassifier]
TREE_REGRESSOR_CLASSES = [DecisionTreeRegressor, ExtraTreeRegressor]
TREE_BASED_CLASSIFIER_CLASSES = TREE_CLASSIFIER_CLASSES + [
    RandomForestClassifier,
    ExtraTreesClassifier,
]
TREE_BASED_REGRESSOR_CLASSES = TREE_REGRESSOR_CLASSES + [
    RandomForestRegressor,
    ExtraTreesRegressor,
]


@pytest.mark.parametrize("TreeClassifier", TREE_BASED_CLASSIFIER_CLASSES)
@pytest.mark.parametrize("depth_first_builder", (True, False))
@pytest.mark.parametrize("sparse_splitter", (True, False))
@pytest.mark.parametrize("csc_container", CSC_CONTAINERS)
def test_monotonic_constraints_classifications(
    TreeClassifier,
    depth_first_builder,
    sparse_splitter,
    global_random_seed,
    csc_container,
):
    n_samples = 1000
    n_samples_train = 900
    X, y = make_classification(
        n_samples=n_samples,
        n_classes=2,
        n_features=5,
        n_informative=5,
        n_redundant=0,
        random_state=global_random_seed,
    )
    X_train, y_train = X[:n_samples_train], y[:n_samples_train]
    X_test, _ = X[n_samples_train:], y[n_samples_train:]

    X_test_0incr, X_test_0decr = np.copy(X_test), np.copy(X_test)
    X_test_1incr, X_test_1decr = np.copy(X_test), np.copy(X_test)
    X_test_0incr[:, 0] += 10
    X_test_0decr[:, 0] -= 10
    X_test_1incr[:, 1] += 10
    X_test_1decr[:, 1] -= 10
    monotonic_cst = np.zeros(X.shape[1])
    monotonic_cst[0] = 1
    monotonic_cst[1] = -1

    if depth_first_builder:
        est = TreeClassifier(max_depth=None, monotonic_cst=monotonic_cst)
    else:
        est = TreeClassifier(
            max_depth=None,
            monotonic_cst=monotonic_cst,
            max_leaf_nodes=n_samples_train,
        )
    if hasattr(est, "random_state"):
        est.set_params(**{"random_state": global_random_seed})
    if hasattr(est, "n_estimators"):
        est.set_params(**{"n_estimators": 5})
    if sparse_splitter:
        X_train = csc_container(X_train)
    est.fit(X_train, y_train)
    proba_test = est.predict_proba(X_test)

    assert np.logical_and(proba_test >= 0.0, proba_test <= 1.0).all(), (
        "Probability should always be in [0, 1] range."
    )
    assert_allclose(proba_test.sum(axis=1), 1.0)

    # The monotonic increase constraint applies to the positive class probability
    assert np.all(est.predict_proba(X_test_0incr)[:, 1] >= proba_test[:, 1])
    assert np.all(est.predict_proba(X_test_0decr)[:, 1] <= proba_test[:, 1])

    # The monotonic decrease constraint applies to the positive class probability
    assert np.all(est.predict_proba(X_test_1incr)[:, 1] <= proba_test[:, 1])
    assert np.all(est.predict_proba(X_test_1decr)[:, 1] >= proba_test[:, 1])


@pytest.mark.parametrize("TreeRegressor", TREE_BASED_REGRESSOR_CLASSES)
@pytest.mark.parametrize("depth_first_builder", (True, False))
@pytest.mark.parametrize("sparse_splitter", (True, False))
@pytest.mark.parametrize("criterion", ("absolute_error", "squared_error"))
@pytest.mark.parametrize("csc_container", CSC_CONTAINERS)
def test_monotonic_constraints_regressions(
    TreeRegressor,
    depth_first_builder,
    sparse_splitter,
    criterion,
    global_random_seed,
    csc_container,
):
    n_samples = 1000
    n_samples_train = 900
    # Build a regression task using 5 informative features
    X, y = make_regression(
        n_samples=n_samples,
        n_features=5,
        n_informative=5,
        random_state=global_random_seed,
    )
    train = np.arange(n_samples_train)
    test = np.arange(n_samples_train, n_samples)
    X_train = X[train]
    y_train = y[train]
    X_test = np.copy(X[test])
    X_test_incr = np.copy(X_test)
    X_test_decr = np.copy(X_test)
    X_test_incr[:, 0] += 10
    X_test_decr[:, 1] += 10
    monotonic_cst = np.zeros(X.shape[1])
    monotonic_cst[0] = 1
    monotonic_cst[1] = -1

    if depth_first_builder:
        est = TreeRegressor(
            max_depth=None,
            monotonic_cst=monotonic_cst,
            criterion=criterion,
        )
    else:
        est = TreeRegressor(
            max_depth=8,
            monotonic_cst=monotonic_cst,
            criterion=criterion,
            max_leaf_nodes=n_samples_train,
        )
    if hasattr(est, "random_state"):
        est.set_params(random_state=global_random_seed)
    if hasattr(est, "n_estimators"):
        est.set_params(**{"n_estimators": 5})
    if sparse_splitter:
        X_train = csc_container(X_train)
    est.fit(X_train, y_train)
    y = est.predict(X_test)
    # Monotonic increase constraint
    y_incr = est.predict(X_test_incr)
    # y_incr should always be greater than or equal to y
    assert np.all(y_incr >= y)

    # Monotonic decrease constraint
    y_decr = est.predict(X_test_decr)
    # y_decr should always be less than or equal to y
    assert np.all(y_decr <= y)
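

# Illustrative sketch, not part of the test suite: the monotonic_cst API that
# the tests above exercise. With monotonic_cst=[1, -1], predictions may only
# increase with feature 0 and only decrease with feature 1, all else fixed.
# The helper name is hypothetical.
def _monotonic_cst_sketch():
    rng = np.random.RandomState(0)
    X = rng.rand(100, 2)
    y = X[:, 0] - X[:, 1] + 0.1 * rng.rand(100)
    reg = DecisionTreeRegressor(monotonic_cst=[1, -1], random_state=0).fit(X, y)

    X_shift = X.copy()
    X_shift[:, 0] += 10  # raising feature 0 can only raise predictions
    assert np.all(reg.predict(X_shift) >= reg.predict(X))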


@pytest.mark.parametrize("TreeClassifier", TREE_BASED_CLASSIFIER_CLASSES)
def test_multiclass_raises(TreeClassifier):
    X, y = make_classification(
        n_samples=100, n_features=5, n_classes=3, n_informative=3, random_state=0
    )
    y[0] = 0
    monotonic_cst = np.zeros(X.shape[1])
    monotonic_cst[0] = -1
    monotonic_cst[1] = 1
    est = TreeClassifier(max_depth=None, monotonic_cst=monotonic_cst, random_state=0)

    msg = "Monotonicity constraints are not supported with multiclass classification"
    with pytest.raises(ValueError, match=msg):
        est.fit(X, y)


@pytest.mark.parametrize("TreeClassifier", TREE_BASED_CLASSIFIER_CLASSES)
def test_multiple_output_raises(TreeClassifier):
    X = [[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]]
    y = [[1, 0, 1, 0, 1], [1, 0, 1, 0, 1]]

    est = TreeClassifier(
        max_depth=None, monotonic_cst=np.array([-1, 1]), random_state=0
    )
    msg = "Monotonicity constraints are not supported with multiple output"
    with pytest.raises(ValueError, match=msg):
        est.fit(X, y)


@pytest.mark.parametrize(
    "Tree",
    [
        DecisionTreeClassifier,
        DecisionTreeRegressor,
        ExtraTreeClassifier,
        ExtraTreeRegressor,
    ],
)
def test_missing_values_raises(Tree):
    X, y = make_classification(
        n_samples=100, n_features=5, n_classes=2, n_informative=3, random_state=0
    )
    X[0, 0] = np.nan
    monotonic_cst = np.zeros(X.shape[1])
    monotonic_cst[0] = 1
    est = Tree(max_depth=None, monotonic_cst=monotonic_cst, random_state=0)

    msg = "Input X contains NaN"
    with pytest.raises(ValueError, match=msg):
        est.fit(X, y)


@pytest.mark.parametrize("TreeClassifier", TREE_BASED_CLASSIFIER_CLASSES)
def test_bad_monotonic_cst_raises(TreeClassifier):
    X = [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]]
    y = [1, 0, 1, 0, 1]

    msg = "monotonic_cst has shape 3 but the input data X has 2 features."
    est = TreeClassifier(
        max_depth=None, monotonic_cst=np.array([-1, 1, 0]), random_state=0
    )
    with pytest.raises(ValueError, match=msg):
        est.fit(X, y)

    msg = "monotonic_cst must be None or an array-like of -1, 0 or 1."
    est = TreeClassifier(
        max_depth=None, monotonic_cst=np.array([-2, 2]), random_state=0
    )
    with pytest.raises(ValueError, match=msg):
        est.fit(X, y)

    est = TreeClassifier(
        max_depth=None, monotonic_cst=np.array([-1, 0.8]), random_state=0
    )
    with pytest.raises(ValueError, match=msg + "(.*)0.8]"):
        est.fit(X, y)


def assert_1d_reg_tree_children_monotonic_bounded(tree_, monotonic_sign):
    values = tree_.value
    for i in range(tree_.node_count):
        if tree_.children_left[i] > i and tree_.children_right[i] > i:
            # Check monotonicity on children
            i_left = tree_.children_left[i]
            i_right = tree_.children_right[i]
            if monotonic_sign == 1:
                assert values[i_left] <= values[i_right]
            elif monotonic_sign == -1:
                assert values[i_left] >= values[i_right]
            val_middle = (values[i_left] + values[i_right]) / 2
            # Check bounds on grand-children, filtering out leaf nodes
            if tree_.feature[i_left] >= 0:
                i_left_right = tree_.children_right[i_left]
                if monotonic_sign == 1:
                    assert values[i_left_right] <= val_middle
                elif monotonic_sign == -1:
                    assert values[i_left_right] >= val_middle
            if tree_.feature[i_right] >= 0:
                i_right_left = tree_.children_left[i_right]
                if monotonic_sign == 1:
                    assert val_middle <= values[i_right_left]
                elif monotonic_sign == -1:
                    assert val_middle >= values[i_right_left]


def test_assert_1d_reg_tree_children_monotonic_bounded():
    X = np.linspace(-1, 1, 7).reshape(-1, 1)
    y = np.sin(2 * np.pi * X.ravel())

    reg = DecisionTreeRegressor(max_depth=None, random_state=0).fit(X, y)

    with pytest.raises(AssertionError):
        assert_1d_reg_tree_children_monotonic_bounded(reg.tree_, 1)

    with pytest.raises(AssertionError):
        assert_1d_reg_tree_children_monotonic_bounded(reg.tree_, -1)


def assert_1d_reg_monotonic(clf, monotonic_sign, min_x, max_x, n_steps):
    X_grid = np.linspace(min_x, max_x, n_steps).reshape(-1, 1)
    y_pred_grid = clf.predict(X_grid)
    if monotonic_sign == 1:
        assert (np.diff(y_pred_grid) >= 0.0).all()
    elif monotonic_sign == -1:
        assert (np.diff(y_pred_grid) <= 0.0).all()


@pytest.mark.parametrize("TreeRegressor", TREE_REGRESSOR_CLASSES)
def test_1d_opposite_monotonicity_cst_data(TreeRegressor):
    # Check that positive monotonic data with a negative monotonic constraint
    # yield constant predictions, equal to the average of the target values
    X = np.linspace(-2, 2, 10).reshape(-1, 1)
    y = X.ravel()
    clf = TreeRegressor(monotonic_cst=[-1])
    clf.fit(X, y)
    assert clf.tree_.node_count == 1
    assert clf.tree_.value[0] == 0.0

    # Swap monotonicity
    clf = TreeRegressor(monotonic_cst=[1])
    clf.fit(X, -y)
    assert clf.tree_.node_count == 1
    assert clf.tree_.value[0] == 0.0


@pytest.mark.parametrize("TreeRegressor", TREE_REGRESSOR_CLASSES)
@pytest.mark.parametrize("monotonic_sign", (-1, 1))
@pytest.mark.parametrize("depth_first_builder", (True, False))
@pytest.mark.parametrize("criterion", ("absolute_error", "squared_error"))
def test_1d_tree_nodes_values(
    TreeRegressor, monotonic_sign, depth_first_builder, criterion, global_random_seed
):
    # Adaptation of test_nodes_values in test_monotonic_constraints.py
    # from sklearn.ensemble._hist_gradient_boosting.
    # Build a single tree with only one feature, and make sure the node
    # values respect the monotonicity constraints.

    # Considering the following tree with a monotonic +1 constraint, we
    # should have:
    #
    #       root
    #      /    \
    #     a      b
    #    / \    / \
    #   c   d  e   f
    #
    # a <= root <= b
    # c <= d <= (a + b) / 2 <= e <= f
    rng = np.random.RandomState(global_random_seed)
    n_samples = 1000
    n_features = 1
    X = rng.rand(n_samples, n_features)
    y = rng.rand(n_samples)

    if depth_first_builder:
        # No max_leaf_nodes, default depth first tree builder
        clf = TreeRegressor(
            monotonic_cst=[monotonic_sign],
            criterion=criterion,
            random_state=global_random_seed,
        )
    else:
        # max_leaf_nodes triggers best first tree builder
        clf = TreeRegressor(
            monotonic_cst=[monotonic_sign],
            max_leaf_nodes=n_samples,
            criterion=criterion,
            random_state=global_random_seed,
        )
    clf.fit(X, y)

    assert_1d_reg_tree_children_monotonic_bounded(clf.tree_, monotonic_sign)
    assert_1d_reg_monotonic(clf, monotonic_sign, np.min(X), np.max(X), 100)


def assert_nd_reg_tree_children_monotonic_bounded(tree_, monotonic_cst):
    upper_bound = np.full(tree_.node_count, np.inf)
    lower_bound = np.full(tree_.node_count, -np.inf)
    for i in range(tree_.node_count):
        feature = tree_.feature[i]
        node_value = tree_.value[i][0][0]  # unpack value from nx1x1 array

        # While building the tree, the computed middle value is slightly
        # different from the average of the sibling values, because
        # sum_right / weighted_n_right is slightly different from the value
        # of the right sibling. This can cause a discrepancy up to numerical
        # noise when clipping, which is resolved here by comparing with some
        # loss of precision.
        assert np.float32(node_value) <= np.float32(upper_bound[i])
        assert np.float32(node_value) >= np.float32(lower_bound[i])

        if feature < 0:
            # Leaf: nothing to do
            continue

        # Split node: check and update bounds for the children.
        i_left = tree_.children_left[i]
        i_right = tree_.children_right[i]
        # unpack values from nx1x1 arrays
        middle_value = (tree_.value[i_left][0][0] + tree_.value[i_right][0][0]) / 2

        if monotonic_cst[feature] == 0:
            # Feature without monotonicity constraint: propagate bounds
            # down the tree to both children.
            # Otherwise, with 2 features and a monotonic increase constraint
            # (encoded by +1) on feature 0, the following tree can be accepted,
            # although it does not respect the monotonic increase constraint:
            #
            #                    X[0] <= 0
            #                   value = 100
            #                  /           \
            #          X[0] <= -1         X[1] <= 0
            #          value = 50         value = 150
            #          /        \         /        \
            #       leaf       leaf     leaf       leaf
            #    value = 25  value = 75  value = 50  value = 250
            lower_bound[i_left] = lower_bound[i]
            upper_bound[i_left] = upper_bound[i]
            lower_bound[i_right] = lower_bound[i]
            upper_bound[i_right] = upper_bound[i]

        elif monotonic_cst[feature] == 1:
            # Feature with constraint: check monotonicity
            assert tree_.value[i_left] <= tree_.value[i_right]

            # Propagate bounds down the tree to both children.
            lower_bound[i_left] = lower_bound[i]
            upper_bound[i_left] = middle_value
            lower_bound[i_right] = middle_value
            upper_bound[i_right] = upper_bound[i]

        elif monotonic_cst[feature] == -1:
            # Feature with constraint: check monotonicity
            assert tree_.value[i_left] >= tree_.value[i_right]

            # Update and propagate bounds down the tree to both children.
            lower_bound[i_left] = middle_value
            upper_bound[i_left] = upper_bound[i]
            lower_bound[i_right] = lower_bound[i]
            upper_bound[i_right] = middle_value

        else:  # pragma: no cover
            raise ValueError(f"monotonic_cst[{feature}]={monotonic_cst[feature]}")


def test_assert_nd_reg_tree_children_monotonic_bounded():
    # Check that assert_nd_reg_tree_children_monotonic_bounded can detect
    # non-monotonic tree predictions.
    X = np.linspace(0, 2 * np.pi, 30).reshape(-1, 1)
    y = np.sin(X).ravel()
    reg = DecisionTreeRegressor(max_depth=None, random_state=0).fit(X, y)

    with pytest.raises(AssertionError):
        assert_nd_reg_tree_children_monotonic_bounded(reg.tree_, [1])

    with pytest.raises(AssertionError):
        assert_nd_reg_tree_children_monotonic_bounded(reg.tree_, [-1])

    assert_nd_reg_tree_children_monotonic_bounded(reg.tree_, [0])

    # Check that assert_nd_reg_tree_children_monotonic_bounded raises
    # when the data (and therefore the model) is naturally monotonic in the
    # opposite direction.
    X = np.linspace(-5, 5, 5).reshape(-1, 1)
    y = X.ravel() ** 3
    reg = DecisionTreeRegressor(max_depth=None, random_state=0).fit(X, y)

    with pytest.raises(AssertionError):
        assert_nd_reg_tree_children_monotonic_bounded(reg.tree_, [-1])

    # For completeness, check that the converse holds when swapping the sign.
    reg = DecisionTreeRegressor(max_depth=None, random_state=0).fit(X, -y)

    with pytest.raises(AssertionError):
        assert_nd_reg_tree_children_monotonic_bounded(reg.tree_, [1])


@pytest.mark.parametrize("TreeRegressor", TREE_REGRESSOR_CLASSES)
@pytest.mark.parametrize("monotonic_sign", (-1, 1))
@pytest.mark.parametrize("depth_first_builder", (True, False))
@pytest.mark.parametrize("criterion", ("absolute_error", "squared_error"))
def test_nd_tree_nodes_values(
    TreeRegressor, monotonic_sign, depth_first_builder, criterion, global_random_seed
):
    # Build a tree with several features, and make sure the node
    # values respect the monotonicity constraints.

    # Considering the following tree with a monotonic increase constraint on X[0],
    # we should have:
    #
    #            root
    #          X[0]<=t
    #         /       \
    #        a         b
    #    X[0]<=u   X[1]<=v
    #     /   \     /   \
    #    c     d   e     f
    #
    # i)   a <= root <= b
    # ii)  c <= a <= d <= (a + b) / 2
    # iii) (a + b) / 2 <= min(e, f)
    # For iii), we check that each node value is within the proper lower and
    # upper bounds.
    rng = np.random.RandomState(global_random_seed)
    n_samples = 1000
    n_features = 2
    monotonic_cst = [monotonic_sign, 0]
    X = rng.rand(n_samples, n_features)
    y = rng.rand(n_samples)

    if depth_first_builder:
        # No max_leaf_nodes, default depth first tree builder
        clf = TreeRegressor(
            monotonic_cst=monotonic_cst,
            criterion=criterion,
            random_state=global_random_seed,
        )
    else:
        # max_leaf_nodes triggers best first tree builder
        clf = TreeRegressor(
            monotonic_cst=monotonic_cst,
            max_leaf_nodes=n_samples,
            criterion=criterion,
            random_state=global_random_seed,
        )
    clf.fit(X, y)
    assert_nd_reg_tree_children_monotonic_bounded(clf.tree_, monotonic_cst)
@@ -0,0 +1,49 @@
import numpy as np
import pytest

from sklearn.tree._reingold_tilford import Tree, buchheim

simple_tree = Tree("", 0, Tree("", 1), Tree("", 2))

bigger_tree = Tree(
    "",
    0,
    Tree(
        "",
        1,
        Tree("", 3),
        Tree("", 4, Tree("", 7), Tree("", 8)),
    ),
    Tree("", 2, Tree("", 5), Tree("", 6)),
)


@pytest.mark.parametrize("tree, n_nodes", [(simple_tree, 3), (bigger_tree, 9)])
def test_buchheim(tree, n_nodes):
    def walk_tree(draw_tree):
        res = [(draw_tree.x, draw_tree.y)]
        for child in draw_tree.children:
            # parents sit higher than their children:
            assert child.y == draw_tree.y + 1
            res.extend(walk_tree(child))
        if len(draw_tree.children):
            # these trees are always binary
            # parents are centered above their children
            assert (
                draw_tree.x == (draw_tree.children[0].x + draw_tree.children[1].x) / 2
            )
        return res

    layout = buchheim(tree)
    coordinates = walk_tree(layout)
    assert len(coordinates) == n_nodes
    # test that x values are unique per depth / level
    # we could also do this more quickly using defaultdicts
    depth = 0
    while True:
        x_at_this_depth = [node[0] for node in coordinates if node[1] == depth]
        if not x_at_this_depth:
            # reached all leaves
            break
        assert len(np.unique(x_at_this_depth)) == len(x_at_this_depth)
        depth += 1
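

# Illustrative sketch, not part of the test suite: inspecting the layout that
# buchheim produces, using the same DrawTree attributes (.x, .y, .children)
# that test_buchheim walks. The helper name is hypothetical.
def _print_layout_sketch():
    layout = buchheim(simple_tree)
    stack = [layout]
    while stack:
        node = stack.pop()
        print(f"x={node.x}, y={node.y}")
        stack.extend(node.children)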
File diff suppressed because it is too large