GraphPipeline
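GraphPipelines (tpot.GraphPipeline) work similarly to scikit-learn's Pipeline class. The difference is that a GraphPipeline is not given a list of steps. Instead, it is given a directed acyclic graph of steps built with networkx (a networkx.DiGraph). In a GraphPipeline, parent nodes take their input from their child nodes: leaf nodes receive the raw input (X, y), and the root node is the final classifier/regressor.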
Node labels can be anything, but each sklearn estimator instance must have a unique label. Every node has an attribute named "instance" that stores its scikit-learn estimator instance.
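The edge direction is easy to get backwards, so here is a minimal sketch of the convention described above. It uses only networkx and scikit-learn and builds no GraphPipeline. The node labels "scaler", "svc", and "final" are hypothetical. An edge (parent, child) means the parent consumes the child's output. Leaves therefore have no outgoing edges, and the root has no incoming edges.

```python
import networkx as nx
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC

g = nx.DiGraph()
# Each node label maps to an estimator stored under the "instance" attribute.
g.add_node("scaler", instance=StandardScaler())
g.add_node("svc", instance=SVC())
g.add_node("final", instance=LogisticRegression())
# An edge (parent, child) means: the parent takes the child's output as input.
g.add_edge("svc", "scaler")
g.add_edge("final", "svc")
g.add_edge("final", "scaler")

# The structure must be a directed acyclic graph.
is_dag = nx.is_directed_acyclic_graph(g)

# Leaves have no children, so they are fed the raw (X, y).
leaves = [n for n in g.nodes if g.out_degree(n) == 0]
# The root has no parents; it is the final classifier/regressor.
roots = [n for n in g.nodes if g.in_degree(n) == 0]

# A topological sort lists parents before children. Reversed, it gives the
# order in which data flows from the leaves up to the root.
flow = list(reversed(list(nx.topological_sort(g))))
print(leaves, roots, flow)
```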
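GraphPipeline allows classifiers and regressors to be used in the middle of a pipeline. In that case, GraphPipeline tries to use the output of predict_proba, decision_function, or predict, in that order. If cross_val_predict_cv is set, downstream models are trained on the output of sklearn.model_selection.cross_val_predict. Final predictions are still made with models trained on the full data.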
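The behavior described above can be approximated with plain scikit-learn. The helper inner_output below is hypothetical and is not TPOT's internal code. It assumes that "trying" a method means checking hasattr, which works because an SVC built without probability=True does not expose predict_proba. When a cv value is given, the helper returns out-of-fold outputs from cross_val_predict as training features and refits the estimator on all rows for later predictions.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_predict
from sklearn.svm import SVC


def inner_output(estimator, X, y, cv=None):
    """Hypothetical helper: pick a prediction method in 'auto' order and
    return (method_name, 2-D array of outputs usable as features)."""
    for method in ("predict_proba", "decision_function", "predict"):
        if hasattr(estimator, method):
            break
    if cv:
        # Out-of-fold outputs: each row is predicted by a model that never saw it.
        out = cross_val_predict(estimator, X, y, cv=cv, method=method)
        # Refit on all rows so the estimator can serve final predictions.
        estimator.fit(X, y)
    else:
        # In-sample outputs from a model fit on the same rows.
        out = getattr(estimator.fit(X, y), method)(X)
    # Binary decision_function/predict return 1-D arrays; make them columns.
    return method, np.asarray(out).reshape(len(X), -1)


X, y = make_classification(random_state=0)
lr, svc = LogisticRegression(), SVC()
m_lr, f_lr = inner_output(lr, X, y, cv=5)
m_svc, f_svc = inner_output(svc, X, y, cv=5)
# Column-stack the inner outputs as input features for a downstream model.
features = np.hstack([f_lr, f_svc])
print(m_lr, m_svc, features.shape)
```

The hasattr check keeps the selection logic independent of estimator type. The reshape lets 1-D and 2-D outputs be stacked side by side.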
Parameters
----------
graph: networkx.DiGraph
    A directed graph where the nodes are sklearn estimators and the edges are the inputs to those estimators.
cross_val_predict_cv: int, cross-validation generator or an iterable, optional
    Determines the cross-validation splitting strategy used for the inner classifiers or regressors.
method: str, optional
    The prediction method to use for the inner classifiers or regressors. If 'auto', it will try to use predict_proba, decision_function, or predict, in that order.
memory: str or object with the joblib.Memory interface, optional
    Used to cache the inputs and outputs of nodes to avoid refitting models or recomputing expensive transformations. By default, no caching is performed. If a string is given, it is the path to the caching directory.
use_label_encoder: bool, optional
    If True, a label encoder is used to encode the labels as consecutive integers starting at 0 (0 to N-1 for N classes). If False, no label encoder is used.
    Mainly useful for classifiers (e.g., XGBoost) that require integer labels in that range.
    Can also be a sklearn.preprocessing.LabelEncoder object, in which case that label encoder is used.
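This short sketch shows what label encoding does, using sklearn.preprocessing.LabelEncoder directly. The string labels are made up for illustration. An encoder object like this one is what use_label_encoder also accepts in place of True.

```python
from sklearn.preprocessing import LabelEncoder

y = ["cat", "dog", "cat", "bird"]
le = LabelEncoder()
# classes_ is the sorted set of labels; each label maps to its index in it.
y_encoded = le.fit_transform(y)
print(le.classes_)   # ['bird' 'cat' 'dog']
print(y_encoded)     # [1 2 1 0]
# The mapping is invertible, so integer predictions can be turned back into labels.
y_decoded = le.inverse_transform(y_encoded)
```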
from sklearn.svm import SVC
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
import networkx as nx
from tpot import GraphPipeline
import sklearn.metrics

X, y = make_classification(random_state=0)
X_train, X_test, y_train, y_test = train_test_split(X, y,
                                                    random_state=0)

# Each node stores its estimator in the "instance" attribute.
g = nx.DiGraph()
g.add_node("scaler", instance=StandardScaler())
g.add_node("svc", instance=SVC())
g.add_node("LogisticRegression", instance=LogisticRegression())
g.add_node("LogisticRegression2", instance=LogisticRegression())

# Edges point from a parent to the child it takes input from:
# "svc" and "LogisticRegression" both read the scaled features...
g.add_edge("svc", "scaler")
g.add_edge("LogisticRegression", "scaler")
# ...and the root, "LogisticRegression2", combines their outputs.
g.add_edge("LogisticRegression2", "LogisticRegression")
g.add_edge("LogisticRegression2", "svc")

est = GraphPipeline(g)
est.plot()
est.fit(X_train, y_train)
print("score")
print(sklearn.metrics.roc_auc_score(y_test, est.predict_proba(X_test)[:, 1]))
score
0.8974358974358974
Cross-validated predictions
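In some cases, setting cross_val_predict_cv can improve performance. Downstream models are then trained on out-of-fold predictions from the inner models, rather than on predictions the inner models make on the same data they were fit on, which tend to be overly optimistic.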
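To see why in-sample predictions from an inner model can mislead a downstream model, compare them with out-of-fold predictions from cross_val_predict. This sketch uses a DecisionTreeClassifier, which does not appear in the original example and was chosen only because it overfits visibly. It illustrates the general stacking effect and is not a measurement of GraphPipeline itself.

```python
from sklearn.datasets import make_classification
from sklearn.metrics import accuracy_score
from sklearn.model_selection import cross_val_predict
from sklearn.tree import DecisionTreeClassifier

X, y = make_classification(random_state=0)
tree = DecisionTreeClassifier(random_state=0)

# In-sample: the tree predicts the very rows it was fit on.
in_sample = tree.fit(X, y).predict(X)
# Out-of-fold: each row is predicted by a tree trained on the other 9 folds.
out_of_fold = cross_val_predict(tree, X, y, cv=10)

acc_in = accuracy_score(y, in_sample)
acc_oof = accuracy_score(y, out_of_fold)
print(acc_in, acc_oof)
```

A fully grown tree memorizes its training rows, so its in-sample predictions look perfect. A downstream model trained on them would learn to trust the tree more than it should. The out-of-fold predictions reflect how the tree behaves on unseen data.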
est = GraphPipeline(g, cross_val_predict_cv=10)
est.plot()
est.fit(X_train, y_train)
print("score")
print(sklearn.metrics.roc_auc_score(y_test, est.predict_proba(X_test)[:,1]))
score
0.9166666666666666
You can access the individual steps of a GraphPipeline through each node's label.
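Because the steps live in an ordinary networkx.DiGraph, the standard networkx node views work for lookups, and the same calls should apply to est.graph. This self-contained sketch builds a small stand-in graph instead of a fitted GraphPipeline, so it runs without tpot. The labels mirror those in the example above.

```python
import networkx as nx
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC

g = nx.DiGraph()
g.add_node("scaler", instance=StandardScaler())
g.add_node("svc", instance=SVC())
g.add_node("LogisticRegression2", instance=LogisticRegression())
g.add_edge("svc", "scaler")
g.add_edge("LogisticRegression2", "svc")

# Look up a single step by its label.
svc = g.nodes["svc"]["instance"]

# Or walk every step: label -> estimator class name.
steps = {label: type(data["instance"]).__name__
         for label, data in g.nodes(data=True)}

# The attribute holds the estimator object itself, so hyperparameters are reachable.
C = svc.get_params()["C"]

# Successors of a node are the children it reads its input from.
inputs_of_svc = list(g.successors("svc"))
print(steps, C, inputs_of_svc)
```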
svc = est.graph.nodes["svc"]["instance"]
svc
SVC()