[MISC][Bugfix] Use less CPU when message queue has been empty for some time #16226

Open · wants to merge 2 commits into main
4 changes: 4 additions & 0 deletions vllm/config.py
@@ -1741,6 +1741,9 @@ class is dynamically inherited by the worker class. This is used to inject
rank: int = 0
"""Global rank in distributed setup."""

sleep_on_idle: bool = False
"""Reduce CPU usage when sglang is idle."""

@property
def world_size_across_dp(self) -> int:
"""world_size_across_dp is TPxPPxDP, it is the size of the world
@@ -2673,6 +2676,7 @@ def create_draft_parallel_config(
ray_workers_use_nsight=target_parallel_config.
ray_workers_use_nsight,
placement_group=target_parallel_config.placement_group,
sleep_on_idle=target_parallel_config.sleep_on_idle,
)

return draft_parallel_config
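
The new field lives on `ParallelConfig`, defaults to off, and (per the `create_draft_parallel_config` hunk above) is copied into the draft-model parallel config so speculative-decoding workers behave the same way. A minimal sanity-check sketch, assuming `ParallelConfig` is constructible with all-default fields as in current vLLM:

```python
from vllm.config import ParallelConfig

# sleep_on_idle is opt-in: the default keeps today's pure busy-spin behavior.
pc = ParallelConfig()
assert pc.sleep_on_idle is False
```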
65 changes: 56 additions & 9 deletions vllm/distributed/device_communicators/shm_broadcast.py
@@ -43,6 +43,43 @@ def sched_yield():
time.sleep(0)


class SpinTimer:

def record_activity(self):
pass

def spin(self):
sched_yield()


class SpinSleepTimer(SpinTimer):
"""
In setups with long inactivity periods it is desirable to reduce
system power consumption while vLLM does nothing. Sleeping leaves more
CPU thermal headroom when a request eventually arrives, especially when
multiple GPUs are connected, as each GPU would otherwise pin one thread
at 100% CPU usage.
The simplest solution is to reduce the polling frequency when there has
been no activity for a certain period of time.
"""

def __init__(self, busy_loop_s: float = 10.0, wait_sleep_s: float = 0.1):
self.last_activity = time.monotonic()
self.busy_loop_s = busy_loop_s
self.wait_sleep_s = wait_sleep_s

def record_activity(self):
self.last_activity = time.monotonic()

def spin(self):
curr_time = time.monotonic()
if curr_time >= self.last_activity + self.busy_loop_s:
time.sleep(self.wait_sleep_s)
else:
sched_yield()


class ShmRingBuffer:

def __init__(self,
@@ -57,7 +94,7 @@ def __init__(self,
of items that can be stored in the buffer are known in advance.
In this case, we don't need to synchronize the access to
the buffer.
Buffer memory layout:
data metadata
| |
@@ -253,6 +290,7 @@ def __init__(
self.local_reader_rank = -1
# rank does not matter for remote readers
self._is_remote_reader = False
self._read_spin_timer = SpinTimer()

self.handle = Handle(
local_reader_ranks=local_reader_ranks,
@@ -269,7 +307,9 @@ def export_handle(self) -> Handle:
return self.handle

@staticmethod
def create_from_handle(handle: Handle, rank) -> "MessageQueue":
def create_from_handle(handle: Handle,
rank,
sleep_on_idle: bool = False) -> "MessageQueue":
self = MessageQueue.__new__(MessageQueue)
self.handle = handle
self._is_writer = False
@@ -291,6 +331,9 @@ def create_from_handle(handle: Handle, rank) -> "MessageQueue":
self.local_socket.connect(socket_addr)

self.remote_socket = None

self._read_spin_timer = SpinSleepTimer(
) if sleep_on_idle else SpinTimer()
else:
self.buffer = None # type: ignore
self.current_idx = -1
@@ -422,7 +465,7 @@ def acquire_read(self,
# we need to wait until it is written

# Release the processor to other threads
sched_yield()
self._read_spin_timer.spin()

# if we wait for a long time, log a message
if (time.monotonic() - start_time
@@ -453,6 +496,8 @@ def acquire_read(self,
metadata_buffer[self.local_reader_rank + 1] = 1
self.current_idx = (self.current_idx +
1) % self.buffer.max_chunks

self._read_spin_timer.record_activity()
break

def enqueue(self, obj, timeout: Optional[float] = None):
@@ -507,11 +552,12 @@ def broadcast_object(self, obj=None):
return self.dequeue()

@staticmethod
def create_from_process_group(pg: Union[ProcessGroup,
StatelessProcessGroup],
max_chunk_bytes,
max_chunks,
writer_rank=0) -> "MessageQueue":
def create_from_process_group(
pg: Union[ProcessGroup, StatelessProcessGroup],
max_chunk_bytes,
max_chunks,
writer_rank=0,
sleep_on_idle: bool = False) -> "MessageQueue":
if isinstance(pg, ProcessGroup):
group_rank = dist.get_rank(pg)
group_world_size = dist.get_world_size(pg)
@@ -552,6 +598,7 @@ def create_from_process_group(pg: Union[ProcessGroup,
handle = recv[0] # type: ignore
else:
handle = pg.broadcast_obj(None, writer_rank)
buffer_io = MessageQueue.create_from_handle(handle, group_rank)
buffer_io = MessageQueue.create_from_handle(
handle, group_rank, sleep_on_idle=sleep_on_idle)
buffer_io.wait_until_ready()
return buffer_io
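
Taken together, the reader-side change implements a two-phase polling policy: busy-spin (`sched_yield`) while messages are recent, then drop to periodic sleeps once the queue has been quiet for `busy_loop_s`. The sketch below is a condensed, standalone illustration of the logic this diff adds to `acquire_read`; a `queue.Queue` stands in for the shared-memory ring buffer, whose real reader spins on a metadata flag instead.

```python
import queue
import time


def sched_yield():
    time.sleep(0)  # as in shm_broadcast.py: yield the slice without sleeping


class SpinTimer:
    """Always busy-spin: lowest wakeup latency, one core pinned at 100%."""

    def record_activity(self):
        pass

    def spin(self):
        sched_yield()


class SpinSleepTimer(SpinTimer):
    """Busy-spin while traffic is recent; sleep once idle for busy_loop_s."""

    def __init__(self, busy_loop_s: float = 10.0, wait_sleep_s: float = 0.1):
        self.last_activity = time.monotonic()
        self.busy_loop_s = busy_loop_s
        self.wait_sleep_s = wait_sleep_s

    def record_activity(self):
        self.last_activity = time.monotonic()

    def spin(self):
        if time.monotonic() >= self.last_activity + self.busy_loop_s:
            time.sleep(self.wait_sleep_s)  # idle: ~10 polls/s, not spinning
        else:
            sched_yield()


def acquire_read_sketch(q: queue.Queue, timer: SpinTimer):
    """Stand-in for MessageQueue.acquire_read's polling loop."""
    while True:
        try:
            item = q.get_nowait()
        except queue.Empty:
            timer.spin()  # back off only after busy_loop_s of silence
            continue
        timer.record_activity()  # any message resets the busy window
        return item
```

The key property: while requests are flowing, `record_activity` keeps the timer on the `sched_yield` path, so steady-state latency is unchanged; only the first message after an idle stretch can pay extra latency, bounded by `wait_sleep_s`.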
4 changes: 4 additions & 0 deletions vllm/engine/arg_utils.py
@@ -293,6 +293,7 @@ class EngineArgs:
enable_expert_parallel: bool = ParallelConfig.enable_expert_parallel
max_parallel_loading_workers: Optional[
int] = ParallelConfig.max_parallel_loading_workers
sleep_on_idle: bool = ParallelConfig.sleep_on_idle
block_size: Optional[BlockSize] = CacheConfig.block_size
enable_prefix_caching: Optional[bool] = CacheConfig.enable_prefix_caching
prefix_caching_hash_algo: PrefixCachingHashAlgo = \
@@ -634,6 +635,8 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
**parallel_kwargs["worker_cls"])
parallel_group.add_argument("--worker-extension-cls",
**parallel_kwargs["worker_extension_cls"])
parallel_group.add_argument("--sleep-on-idle",
**parallel_kwargs["sleep_on_idle"])

# KV cache arguments
cache_kwargs = get_kwargs(CacheConfig)
@@ -1075,6 +1078,7 @@ def create_engine_config(
distributed_executor_backend=self.distributed_executor_backend,
worker_cls=self.worker_cls,
worker_extension_cls=self.worker_extension_cls,
sleep_on_idle=self.sleep_on_idle,
)

speculative_config = self.create_speculative_config(
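
With the field threaded through `EngineArgs` into `ParallelConfig`, the feature is opt-in per engine. A hedged usage sketch; the CLI spelling follows the `add_argument` call above, and the model name is just a placeholder:

```python
from vllm.engine.arg_utils import EngineArgs

# CLI equivalent (flag added above): vllm serve <model> --sleep-on-idle
args = EngineArgs(
    model="facebook/opt-125m",  # placeholder model
    sleep_on_idle=True,  # back off polling after ~10s without messages
)
config = args.create_engine_config()
assert config.parallel_config.sleep_on_idle
```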
13 changes: 9 additions & 4 deletions vllm/v1/executor/multiproc_executor.py
@@ -89,7 +89,9 @@ def _init_executor(self) -> None:

# Workers must be created before wait_for_ready to avoid
# deadlock, since worker.init_device() does a device sync.
self.workers = WorkerProc.wait_for_ready(unready_workers)
self.workers = WorkerProc.wait_for_ready(
unready_workers,
sleep_on_idle=self.parallel_config.sleep_on_idle)

# Ensure message queues are ready. Will deadlock if re-ordered
# Must be kept consistent with the WorkerProc.
@@ -348,7 +350,9 @@ def __init__(

# Initialize MessageQueue for receiving SchedulerOutput
self.rpc_broadcast_mq = MessageQueue.create_from_handle(
input_shm_handle, self.worker.rank)
input_shm_handle,
self.worker.rank,
sleep_on_idle=vllm_config.parallel_config.sleep_on_idle)

# Initializes a message queue for sending the model output
self.worker_response_mq = MessageQueue(1, 1)
@@ -389,7 +393,8 @@ def make_worker_process(

@staticmethod
def wait_for_ready(
unready_proc_handles: list[UnreadyWorkerProcHandle]
unready_proc_handles: list[UnreadyWorkerProcHandle],
sleep_on_idle: bool = False,
) -> list[WorkerProcHandle]:

e = Exception("WorkerProc initialization failed due to "
@@ -412,7 +417,7 @@ def wait_for_ready(

# Extract the message queue handle.
worker_response_mq = MessageQueue.create_from_handle(
response["handle"], 0)
response["handle"], 0, sleep_on_idle=sleep_on_idle)
ready_proc_handles[unready_proc_handle.rank] = (
WorkerProcHandle.from_unready_handle(
unready_proc_handle, worker_response_mq))
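
The executor changes above only thread the flag through to every `MessageQueue` reader (the scheduler-output broadcast queue and each worker response queue). The practical cost with the defaults is bounded: after 10 s of silence a reader polls every 0.1 s, so the first message after an idle stretch sees at most ~100 ms of extra dequeue latency. A self-contained, hypothetical harness to observe that bound (it models only the post-idle sleep path):

```python
import queue
import threading
import time


def consume(q: queue.Queue, wait_sleep_s: float = 0.1):
    # Models a reader already idle past busy_loop_s, i.e. on the sleep path.
    while True:
        try:
            return q.get_nowait()
        except queue.Empty:
            time.sleep(wait_sleep_s)


q = queue.Queue()
# Enqueue mid-sleep, stamping the enqueue time.
threading.Timer(0.05, lambda: q.put(time.monotonic())).start()
sent_at = consume(q)
print(f"wakeup latency after idle: {time.monotonic() - sent_at:.3f}s")
# Expected: at most ~wait_sleep_s (0.1s) over the enqueue time.
```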