Commit 3a00196

Add SLURMRunner from jacobtomlinson/dask-hpc-runners (#659)
* Add SLURMRunner from jacobtomlinson/dask-hpc-runners
* Remove unused code
* Add base class test
* Fix typing error in older Python versions
* Fix another typing error in older Python versions
* If SLURM not installed report 0 cores
* Remove timeout mark
* Ensure Slurm workers have dev code to run runner script
* Add extra nodes
* Revert "Add extra nodes" (this reverts commit c630c11)
* Set scheduler file path correctly
* Make debug output less noisy
* Refactor docs and add runners section
* Add redirects
* Fix links
1 parent eb69b58 commit 3a00196

23 files changed (+701, -43 lines)

ci/slurm.sh

Lines changed: 3 additions & 1 deletion
@@ -28,7 +28,9 @@ function show_network_interfaces {
 }
 
 function jobqueue_install {
-    docker exec slurmctld conda run -n dask-jobqueue /bin/bash -c "cd /dask-jobqueue; pip install -e ."
+    for c in slurmctld c1 c2; do
+        docker exec $c conda run -n dask-jobqueue /bin/bash -c "cd /dask-jobqueue; pip install -e ."
+    done
 }
 
 function jobqueue_script {

ci/slurm/docker-compose.yml

Lines changed: 2 additions & 0 deletions
@@ -69,6 +69,7 @@ services:
       - slurm_jobdir:/data
       - var_log_slurm:/var/log/slurm
       - shared_space:/shared_space
+      - ../..:/dask-jobqueue
     expose:
       - "6818"
     depends_on:
@@ -91,6 +92,7 @@ services:
       - slurm_jobdir:/data
       - var_log_slurm:/var/log/slurm
       - shared_space:/shared_space
+      - ../..:/dask-jobqueue
     expose:
       - "6818"
     depends_on:

dask_jobqueue/runner.py

Lines changed: 239 additions & 0 deletions
@@ -0,0 +1,239 @@
import asyncio
import sys
import os
import signal
from contextlib import suppress
from enum import Enum
from typing import Dict, Optional
import warnings
from tornado.ioloop import IOLoop

from distributed.core import CommClosedError, Status, rpc
from distributed.scheduler import Scheduler
from distributed.utils import LoopRunner, import_term, SyncMethodMixin
from distributed.worker import Worker


# Close gracefully when receiving a SIGINT
signal.signal(signal.SIGINT, lambda *_: sys.exit())


class Role(Enum):
    """
    The various roles a process can assume.
    """

    worker = "worker"
    scheduler = "scheduler"
    client = "client"


class BaseRunner(SyncMethodMixin):
    """Superclass for runner objects.

    This class contains common functionality for Dask cluster runner classes.

    To implement this class, you must provide

    1. A ``get_role`` method which returns a role from the ``Role`` enum.
    2. Optionally, a ``set_scheduler_address`` method for the scheduler process to communicate its address.
    3. A ``get_scheduler_address`` method for all other processes to receive the scheduler address.
    4. Optionally, a ``get_worker_name`` to provide a platform-specific name to the workers.
    5. Optionally, a ``before_scheduler_start`` to perform any actions before the scheduler is created.
    6. Optionally, a ``before_worker_start`` to perform any actions before the worker is created.
    7. Optionally, a ``before_client_start`` to perform any actions before the client code continues.
    8. Optionally, an ``on_scheduler_start`` to perform anything on the scheduler once it has started.
    9. Optionally, an ``on_worker_start`` to perform anything on the worker once it has started.

    In return, you get the following:

    A context manager and object which can be used within a script that is run in parallel to decide which processes
    run the scheduler, workers and client code.

    """

    __loop: Optional[IOLoop] = None

    def __init__(
        self,
        scheduler: bool = True,
        scheduler_options: Optional[Dict] = None,
        worker_class: Optional[str] = None,
        worker_options: Optional[Dict] = None,
        client: bool = True,
        asynchronous: bool = False,
        loop: Optional[asyncio.BaseEventLoop] = None,
    ):
        self.status = Status.created
        self.scheduler = scheduler
        self.scheduler_address = None
        self.scheduler_comm = None
        self.client = client
        if self.client and not self.scheduler:
            raise RuntimeError("Cannot run client code without a scheduler.")
        self.scheduler_options = (
            scheduler_options if scheduler_options is not None else {}
        )
        self.worker_class = (
            Worker if worker_class is None else import_term(worker_class)
        )
        self.worker_options = worker_options if worker_options is not None else {}
        self.role = None
        self.__asynchronous = asynchronous
        self._loop_runner = LoopRunner(loop=loop, asynchronous=asynchronous)

        if not self.__asynchronous:
            self._loop_runner.start()
            self.sync(self._start)

    async def get_role(self) -> str:
        raise NotImplementedError()

    async def set_scheduler_address(self, scheduler: Scheduler) -> None:
        # Optional hook; the default is a no-op.
        return None

    async def get_scheduler_address(self) -> str:
        raise NotImplementedError()

    async def get_worker_name(self) -> str:
        return None

    async def before_scheduler_start(self) -> None:
        return None

    async def before_worker_start(self) -> None:
        return None

    async def before_client_start(self) -> None:
        return None

    async def on_scheduler_start(self, scheduler: Scheduler) -> None:
        return None

    async def on_worker_start(self, worker: Worker) -> None:
        return None

    @property
    def loop(self) -> Optional[IOLoop]:
        loop = self.__loop
        if loop is None:
            # If the loop is not running when this is called, the LoopRunner.loop
            # property will raise a DeprecationWarning
            # However subsequent calls might occur - eg atexit, where a stopped
            # loop is still acceptable - so we cache access to the loop.
            self.__loop = loop = self._loop_runner.loop
        return loop

    @loop.setter
    def loop(self, value: IOLoop) -> None:
        warnings.warn(
            "setting the loop property is deprecated", DeprecationWarning, stacklevel=2
        )
        if value is None:
            raise ValueError("expected an IOLoop, got None")
        self.__loop = value

    def __await__(self):
        async def _await():
            if self.status != Status.running:
                await self._start()
            return self

        return _await().__await__()

    async def __aenter__(self):
        await self
        return self

    async def __aexit__(self, *args):
        await self._close()

    def __enter__(self):
        return self.sync(self.__aenter__)

    def __exit__(self, typ, value, traceback):
        return self.sync(self.__aexit__)

    def __del__(self):
        with suppress(AttributeError, RuntimeError):  # during closing
            self.loop.add_callback(self.close)

    async def _start(self) -> None:
        self.role = await self.get_role()
        if self.role == Role.scheduler:
            await self.start_scheduler()
            os.kill(
                os.getpid(), signal.SIGINT
            )  # Shutdown with a signal to give the event loop time to close
        elif self.role == Role.worker:
            await self.start_worker()
            os.kill(
                os.getpid(), signal.SIGINT
            )  # Shutdown with a signal to give the event loop time to close
        elif self.role == Role.client:
            self.scheduler_address = await self.get_scheduler_address()
            if self.scheduler_address:
                self.scheduler_comm = rpc(self.scheduler_address)
            await self.before_client_start()
        self.status = Status.running

    async def start_scheduler(self) -> None:
        await self.before_scheduler_start()
        async with Scheduler(**self.scheduler_options) as scheduler:
            await self.set_scheduler_address(scheduler)
            await self.on_scheduler_start(scheduler)
            await scheduler.finished()

    async def start_worker(self) -> None:
        if (
            "scheduler_file" not in self.worker_options
            and "scheduler_ip" not in self.worker_options
        ):
            self.worker_options["scheduler_ip"] = await self.get_scheduler_address()
        worker_name = await self.get_worker_name()
        await self.before_worker_start()
        async with self.worker_class(name=worker_name, **self.worker_options) as worker:
            await self.on_worker_start(worker)
            await worker.finished()

    async def _close(self) -> None:
        if self.status == Status.running:
            if self.scheduler_comm:
                with suppress(CommClosedError):
                    await self.scheduler_comm.terminate()
            self.status = Status.closed

    def close(self) -> None:
        return self.sync(self._close)


class AsyncCommWorld:
    def __init__(self):
        self.roles = {"scheduler": None, "client": None}
        self.role_lock = asyncio.Lock()
        self.scheduler_address = None


class AsyncRunner(BaseRunner):
    def __init__(self, commworld: AsyncCommWorld, *args, **kwargs):
        self.commworld = commworld
        super().__init__(*args, **kwargs)

    async def get_role(self) -> str:
        async with self.commworld.role_lock:
            if self.commworld.roles["scheduler"] is None and self.scheduler:
                self.commworld.roles["scheduler"] = self
                return Role.scheduler
            elif self.commworld.roles["client"] is None and self.client:
                self.commworld.roles["client"] = self
                return Role.client
            else:
                return Role.worker

    async def set_scheduler_address(self, scheduler: Scheduler) -> None:
        self.commworld.scheduler_address = scheduler.address

    async def get_scheduler_address(self) -> str:
        while self.commworld.scheduler_address is None:
            await asyncio.sleep(0.1)
        return self.commworld.scheduler_address
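
As an aside, the subclassing contract described in the ``BaseRunner`` docstring can be illustrated with a minimal sketch. The example below is not part of this commit: the ``EnvRunner`` class, the ``RUNNER_RANK`` environment variable, and the shared address-file path are assumptions made purely for illustration, and only the hooks most subclasses override are implemented.

# Hypothetical sketch of a BaseRunner subclass (not part of this commit).
# Assumes every process is launched with a RUNNER_RANK environment variable
# and that all processes can see a shared filesystem path.
import asyncio
import os
from pathlib import Path

from dask_jobqueue.runner import BaseRunner, Role

ADDRESS_FILE = Path("/tmp/dask-scheduler-address.txt")  # assumed shared location


class EnvRunner(BaseRunner):
    """Toy runner: rank 0 runs the scheduler, rank 1 the client, the rest are workers."""

    async def get_role(self) -> str:
        rank = int(os.environ.get("RUNNER_RANK", "0"))
        if rank == 0 and self.scheduler:
            return Role.scheduler
        if rank == 1 and self.client:
            return Role.client
        return Role.worker

    async def set_scheduler_address(self, scheduler) -> None:
        # The scheduler process publishes its address somewhere every process can read.
        ADDRESS_FILE.write_text(scheduler.address)

    async def get_scheduler_address(self) -> str:
        # Workers and the client poll until the scheduler has written its address.
        while not ADDRESS_FILE.exists():
            await asyncio.sleep(0.1)
        return ADDRESS_FILE.read_text().strip()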

dask_jobqueue/slurm.py

Lines changed: 79 additions & 1 deletion
@@ -1,10 +1,16 @@
 import logging
 import math
 import warnings
+import asyncio
+import json
+import os
+from pathlib import Path
 
 import dask
+from dask.distributed import Scheduler
 
 from .core import Job, JobQueueCluster, job_parameters, cluster_parameters
+from .runner import Role, BaseRunner
 
 logger = logging.getLogger(__name__)
 
@@ -26,7 +32,7 @@ def __init__(
         job_cpu=None,
         job_mem=None,
         config_name=None,
-        **base_class_kwargs
+        **base_class_kwargs,
     ):
         super().__init__(
             scheduler=scheduler, name=name, config_name=config_name, **base_class_kwargs
@@ -177,3 +183,75 @@ class SLURMCluster(JobQueueCluster):
         job=job_parameters, cluster=cluster_parameters
     )
     job_cls = SLURMJob
+
+
+class WorldTooSmallException(RuntimeError):
+    """Not enough Slurm tasks to start all required processes."""
+
+
+class SLURMRunner(BaseRunner):
+    def __init__(self, *args, scheduler_file="scheduler-{job_id}.json", **kwargs):
+        try:
+            self.proc_id = int(os.environ["SLURM_PROCID"])
+            self.world_size = self.n_workers = int(os.environ["SLURM_NTASKS"])
+            self.job_id = int(os.environ["SLURM_JOB_ID"])
+        except KeyError as e:
+            raise RuntimeError(
+                "SLURM_PROCID, SLURM_NTASKS, and SLURM_JOB_ID must be present in the environment."
+            ) from e
+        if not scheduler_file:
+            scheduler_file = kwargs.get("scheduler_options", {}).get("scheduler_file")
+
+        if not scheduler_file:
+            raise RuntimeError(
+                "scheduler_file must be specified in either the "
+                "scheduler_options or as keyword argument to SlurmRunner."
+            )
+
+        # Encourage filename uniqueness by inserting the job ID
+        scheduler_file = scheduler_file.format(job_id=self.job_id)
+        scheduler_file = Path(scheduler_file)
+
+        if isinstance(kwargs.get("scheduler_options"), dict):
+            kwargs["scheduler_options"]["scheduler_file"] = scheduler_file
+        else:
+            kwargs["scheduler_options"] = {"scheduler_file": scheduler_file}
+        if isinstance(kwargs.get("worker_options"), dict):
+            kwargs["worker_options"]["scheduler_file"] = scheduler_file
+        else:
+            kwargs["worker_options"] = {"scheduler_file": scheduler_file}
+
+        self.scheduler_file = scheduler_file
+
+        super().__init__(*args, **kwargs)
+
+    async def get_role(self) -> str:
+        if self.scheduler and self.client and self.world_size < 3:
+            raise WorldTooSmallException(
+                f"Not enough Slurm tasks to start cluster, found {self.world_size}, "
+                "needs at least 3, one each for the scheduler, client and a worker."
+            )
+        elif self.scheduler and self.world_size < 2:
+            raise WorldTooSmallException(
+                f"Not enough Slurm tasks to start cluster, found {self.world_size}, "
+                "needs at least 2, one each for the scheduler and a worker."
+            )
+        self.n_workers -= int(self.scheduler) + int(self.client)
+        if self.proc_id == 0 and self.scheduler:
+            return Role.scheduler
+        elif self.proc_id == 1 and self.client:
+            return Role.client
+        else:
+            return Role.worker
+
+    async def set_scheduler_address(self, scheduler: Scheduler) -> None:
+        return
+
+    async def get_scheduler_address(self) -> str:
+        while not self.scheduler_file or not self.scheduler_file.exists():
+            await asyncio.sleep(0.2)
+        cfg = json.loads(self.scheduler_file.read_text())
+        return cfg["address"]
+
+    async def get_worker_name(self) -> str:
+        return self.proc_id
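
One detail worth spelling out, since the diff only shows the reading side: ``SLURMRunner`` leans on distributed's standard ``scheduler_file`` mechanism, where the scheduler dumps its contact information as JSON and every other process polls for that file. Below is a small stand-alone illustration, with a made-up path and a file reduced to the single key the runner actually reads.

# Illustration only: the path and job ID are made up; a real scheduler file
# contains more keys, but the runner above relies only on "address".
import json
import tempfile
from pathlib import Path

scheduler_file = Path(tempfile.gettempdir()) / "scheduler-12345.json"

# What the scheduler side produces (written by distributed when scheduler_file= is set):
scheduler_file.write_text(json.dumps({"address": "tcp://10.0.0.1:8786"}))

# What SLURMRunner.get_scheduler_address does once the file exists (minus the polling loop):
address = json.loads(scheduler_file.read_text())["address"]
print(address)  # tcp://10.0.0.1:8786
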
Lines changed: 8 additions & 0 deletions
@@ -0,0 +1,8 @@
from dask.distributed import Client
from dask_jobqueue.slurm import SLURMRunner

with SLURMRunner(scheduler_file="/shared_space/{job_id}.json") as runner:
    with Client(runner) as client:
        assert client.submit(lambda x: x + 1, 10).result() == 11
        assert client.submit(lambda x: x + 1, 20, workers=2).result() == 21
        print("Test passed")
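
For context on how a script like the one above is meant to run (not shown in this diff): every SLURM task executes the same file, and the runner assigns roles by rank, so the job needs at least three tasks, one for the scheduler, one for the client code, and at least one worker. The sketch below is a hedged stand-in for the usual sbatch/srun job script, written in Python to match the other examples; the task count and script name are assumptions.

# Hypothetical launch wrapper (not part of this commit). Assumes SLURM's srun
# is available and that the demo script above is saved as runner_script.py.
import subprocess

subprocess.run(
    ["srun", "--ntasks=4", "python", "runner_script.py"],  # 1 scheduler + 1 client + 2 workers
    check=True,
)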
