Я пытаюсь обучить модель дерева решений, сохранить ее, а затем перезагрузить, когда она мне понадобится позже. Однако я продолжаю получать следующую ошибку:
Этот экземпляр DecisionTreeClassifier еще не установлен. Перед использованием этого метода вызовите 'fit' с соответствующими аргументами.
Вот мой код:
X_train, X_test, y_train, y_test = train_test_split(data, label, test_size=0.20, random_state=4)
names = ["Decision Tree", "Random Forest", "Neural Net"]
classifiers = [
DecisionTreeClassifier(),
RandomForestClassifier(),
MLPClassifier()
]
score = 0
for name, clf in zip(names, classifiers):
if name == "Decision Tree":
clf = DecisionTreeClassifier(random_state=0)
grid_search = GridSearchCV(clf, param_grid=param_grid_DT)
grid_search.fit(X_train, y_train_TF)
if grid_search.best_score_ > score:
score = grid_search.best_score_
best_clf = clf
elif name == "Random Forest":
clf = RandomForestClassifier(random_state=0)
grid_search = GridSearchCV(clf, param_grid_RF)
grid_search.fit(X_train, y_train_TF)
if grid_search.best_score_ > score:
score = grid_search.best_score_
best_clf = clf
elif name == "Neural Net":
clf = MLPClassifier()
clf.fit(X_train, y_train_TF)
y_pred = clf.predict(X_test)
current_score = accuracy_score(y_test_TF, y_pred)
if current_score > score:
score = current_score
best_clf = clf
pkl_filename = "pickle_model.pkl"
with open(pkl_filename, 'wb') as file:
pickle.dump(best_clf, file)
from sklearn.externals import joblib
# Save to file in the current working directory
joblib_file = "joblib_model.pkl"
joblib.dump(best_clf, joblib_file)
print("best classifier: ", best_clf, " Accuracy= ", score)
Вот как я загружаю модель и тестирую ее:
#First method
with open(pkl_filename, 'rb') as h:
loaded_model = pickle.load(h)
#Second method
joblib_model = joblib.load(joblib_file)
Как видите, я пробовал два способа сохранить его, но ни один из них не помог.
Вот как я тестировал:
print(loaded_model.predict(test))
print(joblib_model.predict(test))
Вы можете ясно видеть, что модели на самом деле подогнаны, и если я попробую с любыми другими моделями, такими как SVM или логистическая регрессия, метод будет работать нормально.
best_clf = grid_search
. ВашMLPClassifier
код в порядке. - person Scratch'N'Purr   schedule 18.07.2018