Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit 51b5636

Browse files
authored Jun 7, 2017
yarn: add option to pass extra packages to the conda env (#4)
* yarn: add the possibility to pass extra required packages (example with scikit-learn)
* yarn: make it possible to add packages to the conda env
1 parent de8cee8 commit 51b5636

File tree

4 files changed

+13
-10
lines changed

4 files changed

+13
-10
lines changed
 

‎examples/gridsearchcv.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@
5757

5858
clf = GridSearchCV(SVC(C=1), tuned_parameters, cv=5,
5959
scoring='%s_weighted' % score, n_jobs=4)
60-
with parallel_backend('yarn'):
60+
with parallel_backend('yarn', packages=['scikit-learn']):
6161
clf.fit(X_train, y_train)
6262

6363
print("Best parameters set found on development set:")

‎joblibhadoop/yarn/backend.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,11 @@
1111
class YarnBackend(ThreadingBackend):
1212
"""The YARN backend class."""
1313

14-
_pool = None
15-
parallel = None
14+
def __init__(self, packages=[]):
15+
"""Constructor"""
16+
self.packages = packages
17+
self._pool = None
18+
self.parallel = None
1619

1720
def effective_n_jobs(self, n_jobs):
1821
"""Return the number of effective jobs running in the backend."""
@@ -27,7 +30,7 @@ def effective_n_jobs(self, n_jobs):
2730
def configure(self, n_jobs, parallel=None, **backend_args):
2831
"""Initialize the backend."""
2932
n_jobs = self.effective_n_jobs(n_jobs)
30-
self._pool = YarnPool(processes=n_jobs)
33+
self._pool = YarnPool(processes=n_jobs, packages=self.packages)
3134
self.parallel = parallel
3235
return n_jobs
3336

‎joblibhadoop/yarn/pool.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
CONDA_ENV_INSTALL_COMMAND = 'conda install -y -q -p {} {}'
2121

2222

23-
def create_conda_env(*extra_packages):
23+
def create_conda_env(*packages):
2424
"""Create a conda environment to pass to Knit"""
2525
# Create conda environment
2626
if os.path.isfile(os.path.join(TEMP_DIR, JOBLIB_YARN_CONDA_ENV + '.zip')):
@@ -29,10 +29,10 @@ def create_conda_env(*extra_packages):
2929
os.system(CONDA_ENV_CREATE_COMMAND.format(
3030
os.path.join(TEMP_DIR, JOBLIB_YARN_CONDA_ENV),
3131
conda_environment_filename()))
32-
if len(*extra_packages):
32+
if len(packages):
3333
os.system(CONDA_ENV_INSTALL_COMMAND.format(
3434
os.path.join(TEMP_DIR, JOBLIB_YARN_CONDA_ENV),
35-
' '.join(*extra_packages)))
35+
' '.join(packages)))
3636
# Archive conda environment
3737
shutil.make_archive(os.path.join(TEMP_DIR, JOBLIB_YARN_CONDA_ENV), 'zip',
3838
root_dir=TEMP_DIR,
@@ -42,14 +42,14 @@ def create_conda_env(*extra_packages):
4242
class YarnPool(RemotePool):
4343
"""The Yarn Pool manager."""
4444

45-
def __init__(self, processes=None, port=0, authkey=None):
45+
def __init__(self, processes=None, port=0, authkey=None, packages=[]):
4646
super(YarnPool, self).__init__(processes=processes,
4747
port=port,
4848
authkey=authkey,
4949
workerscript=JOBLIB_YARN_WORKER)
5050
self.stopping = False
5151
self.knit = Knit(autodetect=True)
52-
create_conda_env([])
52+
create_conda_env(*packages)
5353
cmd = ('$PYTHON_BIN $CONDA_PREFIX/bin/{} --host {} --port {} --key {}'
5454
.format(JOBLIB_YARN_WORKER,
5555
socket.gethostname(),

‎joblibhadoop/yarn/tests/test_yarn_backend.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def test_parallel_backend_njobs():
4949
register_parallel_backend('yarn', YarnBackend)
5050

5151
# Run in parallel using Yarn backend
52-
with parallel_backend('yarn', n_jobs=5):
52+
with parallel_backend('yarn', n_jobs=5, packages=['scikit-learn']):
5353
result = Parallel(verbose=100)(
5454
delayed(sqrt)(i**2) for i in range(100))
5555

0 commit comments

Comments
 (0)
Please sign in to comment.