Прогнозирование видов ирисов с помощью логистической регрессии

У меня есть код, который сообщает вам, является ли ирис из набора данных iris vireginica или нет, в зависимости от длины лепестка и ширины лепестка. Но как сделать предсказание с совершенно новым цветком?

%matplotlib inline
import numpy as np
from sklearn import datasets
from sklearn.linear_model import LogisticRegression
import matplotlib.pyplot as plt
iris = datasets.load_iris()

from sklearn.linear_model import LogisticRegression

X = iris["data"][:, (2, 3)]  # petal length, petal width
y = (iris["target"] == 2).astype(np.int)

log_reg = LogisticRegression(C=10**10, random_state=42)
log_reg.fit(X, y)

x0, x1 = np.meshgrid(
        np.linspace(2.9, 7, 500).reshape(-1, 1),
        np.linspace(0.8, 2.7, 200).reshape(-1, 1),
    )
X_new = np.c_[x0.ravel(), x1.ravel()]

y_proba = log_reg.predict_proba(X_new)

plt.figure(figsize=(10, 4))
plt.plot(X[y==0, 0], X[y==0, 1], "bs")
plt.plot(X[y==1, 0], X[y==1, 1], "g^")

zz = y_proba[:, 1].reshape(x0.shape)
contour = plt.contour(x0, x1, zz, cmap=plt.cm.brg)


left_right = np.array([2.9, 7])
boundary = -(log_reg.coef_[0][0] * left_right + log_reg.intercept_[0]) / log_reg.coef_[0][1]

plt.clabel(contour, inline=1, fontsize=12)
plt.plot(left_right, boundary, "k--", linewidth=3)
plt.text(3.5, 1.5, "Not Iris-Virginica", fontsize=14, color="b", ha="center")
plt.text(6.5, 2.3, "Iris-Virginica", fontsize=14, color="g", ha="center")
plt.xlabel("Petal length", fontsize=14)
plt.ylabel("Petal width", fontsize=14)
plt.axis([2.9, 7, 0.8, 2.7])
plt.show()

Теперь предположим, что у меня есть новый цветок, который я измеряю:

  • длина чашелистика: 4.8
  • ширина чашелистника: 2,5
  • длина лепестка: 5,3
  • ширина лепестка: 2,4

Когда я пытаюсь выполнить следующий прогноз, я получаю сообщение об ошибке: ValueError: X имеет 1 функцию на образец; ожидает 2

log_reg.predict([[5.3], [2.4]])

Итак, мой вопрос: как мне сделать здесь предсказания о новом цветке и о том, что это за вид?


person Lars B.    schedule 20.06.2020    source источник


Ответы (1)


В документации для метода predict указано, что входной аргумент должен иметь форму (n_samples, n_features), т.е. 1x2 здесь, тогда как вход в вашем случае - 2x1. Попробуй это:

log_reg.predict([[5.3, 2.4]])
person Abhinav Goyal    schedule 20.06.2020