Как выполнить простой поиск по сетке с помощью Apache Spark

Я попытался использовать класс GridSearch Scikit Learn для настройки гиперпараметров моего алгоритма логистической регрессии.

Однако GridSearch, даже при параллельном использовании нескольких заданий, требует буквально дней для обработки, если только вы не настраиваете только один параметр. Я думал об использовании Apache Spark для ускорения этого процесса, но у меня есть два вопроса.

  • Чтобы использовать Apache Spark, вам буквально нужно несколько машин для распределения рабочей нагрузки? Например, если у вас есть только 1 ноутбук, бессмысленно ли использовать Apache Spark?

  • Есть ли простой способ использовать GridSearch Scikit Learn в Apache Spark?

Я читал документацию, но в ней говорится о запуске параллельных рабочих процессов во всем конвейере машинного обучения, но мне это нужно только для настройки параметров.

Импорт

import datetime
%matplotlib inline

import pylab
import pandas as pd
import math
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
import matplotlib.pylab as pylab

import numpy as np
import statsmodels.api as sm
from statsmodels.formula.api import ols

from sklearn import datasets, tree, metrics, model_selection
from sklearn.preprocessing import LabelEncoder
from sklearn.neighbors import KNeighborsClassifier 
from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.metrics import classification_report, confusion_matrix, roc_curve, auc
from sklearn.linear_model import LogisticRegression, LinearRegression, Perceptron
from sklearn.feature_selection import SelectKBest, chi2, VarianceThreshold, RFE
from sklearn.svm import SVC
from sklearn.cross_validation import train_test_split
from sklearn.tree import DecisionTreeClassifier
from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier, VotingClassifier
from sklearn.naive_bayes import GaussianNB

import findspark
findspark.init()
import pyspark
sc = pyspark.SparkContext()

from datetime import datetime as dt
import scipy
import itertools

ucb_w_reindex = pd.read_csv('clean_airbnb.csv')
ucb = pd.read_csv('clean_airbnb.csv')

pylab.rcParams[ 'figure.figsize' ] = 15 , 10
plt.style.use("fivethirtyeight")

new_style = {'grid': False}
plt.rc('axes', **new_style)

Настройка гиперпараметров алгоритма

X = ucb.drop('country_destination', axis=1)
y = ucb['country_destination'].values

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size = .3, random_state=42, stratify=y)

knn = KNeighborsClassifier()

parameters = {'leaf_size': range(1, 100), 'n_neighbors': range(1, 10), 'weights': ['uniform', 'distance'], 
              'algorithm': ['kd_tree', 'ball_tree', 'brute', 'auto']}


# ======== What I want to do in Apache Spark ========= #

%%time
parameters = {'n_neighbors': range(1, 100)}
clf1 = GridSearchCV(estimator=knn, param_grid=parameters, n_jobs=5).fit(X_train, y_train)
best = clf1.best_estimator_

# ==================================================== #

person TJB    schedule 24.07.2017    source источник


Ответы (1)


Вы можете использовать библиотеку под названием spark-sklearn для запуска распределенных проверок параметров. Вы правы в том, что вам понадобится кластер машин или одна многопроцессорная машина для параллельного ускорения.

Надеюсь это поможет,

Roope — команда Microsoft MMLSpark

person Roope Astala - MSFT    schedule 24.07.2017