use std::mem::MaybeUninit;

use crate::bpf_intf;
use crate::bpf_intf::*;
use crate::bpf_skel::*;

use std::ffi::c_int;
use std::ffi::c_ulong;
use std::ffi::CStr;

use std::sync::atomic::AtomicBool;
use std::sync::atomic::Ordering;
use std::sync::Arc;
use std::sync::Once;

use anyhow::bail;
use anyhow::Context;
use anyhow::Result;

use plain::Plain;
use procfs::process::all_processes;

use libbpf_rs::libbpf_sys::bpf_object_open_opts;
use libbpf_rs::OpenObject;
use libbpf_rs::ProgramInput;

use libc::{c_char, pthread_self, pthread_setschedparam, sched_param};

#[cfg(target_env = "musl")]
use libc::timespec;

use scx_utils::compat;
use scx_utils::scx_ops_attach;
use scx_utils::scx_ops_load;
use scx_utils::scx_ops_open;
use scx_utils::uei_exited;
use scx_utils::uei_report;
use scx_utils::Topology;
use scx_utils::UserExitInfo;

use scx_rustland_core::ALLOCATOR;

// Scheduling class for sched_ext (not exported by the libc crate).
const SCHED_EXT: i32 = 7;

// Maximum task name length, matching the kernel's TASK_COMM_LEN.
const TASK_COMM_LEN: usize = 16;

// Special `cpu` value: allow the task to be dispatched on any CPU.
#[allow(dead_code)]
pub const RL_CPU_ANY: i32 = bpf_intf::RL_CPU_ANY as i32;

// Task queued for scheduling from the BPF component (see bpf_intf::queued_task_ctx).
#[derive(Debug, PartialEq, Eq, PartialOrd, Clone)]
pub struct QueuedTask {
    pub pid: i32,                      // pid that uniquely identifies the task
    pub cpu: i32,                      // CPU where the task is running
    pub nr_cpus_allowed: u64,          // number of CPUs the task can use
    pub flags: u64,                    // task's enqueue flags
    pub start_ts: u64,                 // timestamp of the last time the task started running
    pub stop_ts: u64,                  // timestamp of the last time the task released a CPU
    pub exec_runtime: u64,             // total CPU time used since the last sleep
    pub weight: u64,                   // task's static priority (weight)
    pub vtime: u64,                    // current virtual runtime
    pub enq_cnt: u64,                  // enqueue counter
    pub comm: [c_char; TASK_COMM_LEN], // task name (comm)
}

impl QueuedTask {
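    /// Return the task name (comm) as a Rust string, stopping at the first
    /// NUL byte of the raw kernel buffer.
    ///
    /// A small usage sketch (`task` is a hypothetical `QueuedTask` received
    /// from `dequeue_task()`):
    ///
    /// ```ignore
    /// if task.comm_str() == "khugepaged" {
    ///     // give kernel housekeeping threads special treatment
    /// }
    /// ```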
    #[allow(dead_code)]
    pub fn comm_str(&self) -> String {
        let c_str = unsafe { CStr::from_ptr(self.comm.as_ptr()) };

        c_str.to_string_lossy().into_owned()
    }
}

// Task to be dispatched to the BPF component (see bpf_intf::dispatched_task_ctx).
#[derive(Debug, PartialEq, Eq, PartialOrd, Clone)]
pub struct DispatchedTask {
    pub pid: i32,      // pid that uniquely identifies the task
    pub cpu: i32,      // target CPU (or RL_CPU_ANY)
    pub flags: u64,    // dispatch flags
    pub slice_ns: u64, // time slice assigned to the task (0 = default)
    pub vtime: u64,    // task's virtual runtime
    pub enq_cnt: u64,  // enqueue counter
}

impl DispatchedTask {
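    /// Create a `DispatchedTask` from a `QueuedTask`.
    ///
    /// The dispatched task inherits pid, cpu, flags and enq_cnt from the
    /// queued task; `slice_ns` and `vtime` start at 0. A typical dispatch
    /// sketch (names assumed from this module):
    ///
    /// ```ignore
    /// let mut dispatched = DispatchedTask::new(&queued);
    /// dispatched.cpu = RL_CPU_ANY;     // let the BPF side pick a CPU
    /// dispatched.slice_ns = 5_000_000; // assign a custom 5ms time slice
    /// ```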
    pub fn new(task: &QueuedTask) -> Self {
        DispatchedTask {
            pid: task.pid,
            cpu: task.cpu,
            flags: task.flags,
            slice_ns: 0, // use the default time slice
            vtime: 0,
            enq_cnt: task.enq_cnt,
        }
    }
}

// Allow dispatched_task_ctx to be accessed as raw bytes via the plain crate.
unsafe impl Plain for bpf_intf::dispatched_task_ctx {}

impl AsMut<bpf_intf::dispatched_task_ctx> for bpf_intf::dispatched_task_ctx {
    fn as_mut(&mut self) -> &mut bpf_intf::dispatched_task_ctx {
        self
    }
}

// Message received from the BPF dispatcher (see bpf_intf::queued_task_ctx).
struct EnqueuedMessage {
    inner: bpf_intf::queued_task_ctx,
}

impl EnqueuedMessage {
    fn from_bytes(bytes: &[u8]) -> Self {
        let queued_task_struct = unsafe { *(bytes.as_ptr() as *const bpf_intf::queued_task_ctx) };
        EnqueuedMessage {
            inner: queued_task_struct,
        }
    }

    fn to_queued_task(&self) -> QueuedTask {
        QueuedTask {
            pid: self.inner.pid,
            cpu: self.inner.cpu,
            nr_cpus_allowed: self.inner.nr_cpus_allowed,
            flags: self.inner.flags,
            start_ts: self.inner.start_ts,
            stop_ts: self.inner.stop_ts,
            exec_runtime: self.inner.exec_runtime,
            weight: self.inner.weight,
            vtime: self.inner.vtime,
            enq_cnt: self.inner.enq_cnt,
            comm: self.inner.comm,
        }
    }
}

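/// High-level Rust abstraction to interact with the generic sched-ext BPF
/// component.
///
/// The BPF component enqueues tasks that need scheduling into the `queued`
/// ring buffer; the user-space scheduler consumes them via `dequeue_task()`,
/// applies its scheduling policy, and pushes the results back through the
/// `dispatched` user ring buffer via `dispatch_task()`.
///
/// A minimal scheduling loop sketch (assumes `open_object` is a
/// `MaybeUninit<OpenObject>` that outlives the scheduler; all values below
/// are illustrative):
///
/// ```ignore
/// let mut sched = BpfScheduler::init(
///     &mut open_object,
///     None,       // open_opts
///     0,          // exit_dump_len
///     false,      // partial
///     false,      // debug
///     true,       // builtin_idle
///     5_000_000,  // slice_ns
///     "my_sched", // name
/// )?;
///
/// while !sched.exited() {
///     while let Ok(Some(task)) = sched.dequeue_task() {
///         let mut dispatched = DispatchedTask::new(&task);
///         dispatched.cpu = RL_CPU_ANY;
///         sched.dispatch_task(&dispatched)?;
///     }
///     // Report that no tasks are left pending and yield.
///     sched.notify_complete(0);
/// }
/// sched.shutdown_and_report()?;
/// ```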
pub struct BpfScheduler<'cb> {
    pub skel: BpfSkel<'cb>,                // BPF skeleton
    shutdown: Arc<AtomicBool>,             // shutdown flag (set by the Ctrl-C handler)
    queued: libbpf_rs::RingBuffer<'cb>,    // ring buffer of queued tasks (BPF -> user space)
    dispatched: libbpf_rs::UserRingBuffer, // user ring buffer of dispatched tasks (user space -> BPF)
    struct_ops: Option<libbpf_rs::Link>,   // link to the attached struct_ops
}

// Size of a single message exchanged through the queued ring buffer.
const BUFSIZE: usize = size_of::<queued_task_ctx>();

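// The queued ring buffer callback (see init() below) copies each received
// message into this static, pre-allocated buffer: with memory locked and
// mmap disabled, the hot path performs no run-time allocations.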
#[repr(align(8))]
struct AlignedBuffer([u8; BUFSIZE]);

static mut BUF: AlignedBuffer = AlignedBuffer([0; BUFSIZE]);

// Ensure the Ctrl-C handler is registered only once.
static SET_HANDLER: Once = Once::new();

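// Install a Ctrl-C handler that requests a clean shutdown by setting the
// shared `shutdown` flag; the scheduling loop is expected to poll it via
// BpfScheduler::exited().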
fn set_ctrlc_handler(shutdown: Arc<AtomicBool>) -> Result<(), anyhow::Error> {
    SET_HANDLER.call_once(|| {
        let shutdown_clone = shutdown.clone();
        ctrlc::set_handler(move || {
            shutdown_clone.store(true, Ordering::Relaxed);
        })
        .expect("Error setting Ctrl-C handler");
    });
    Ok(())
}

impl<'cb> BpfScheduler<'cb> {
    #[allow(clippy::too_many_arguments)]
    pub fn init(
        open_object: &'cb mut MaybeUninit<OpenObject>,
        open_opts: Option<bpf_object_open_opts>,
        exit_dump_len: u32,
        partial: bool,
        debug: bool,
        builtin_idle: bool,
        slice_ns: u64,
        name: &str,
    ) -> Result<Self> {
        let shutdown = Arc::new(AtomicBool::new(false));
        set_ctrlc_handler(shutdown.clone()).context("Error setting Ctrl-C handler")?;

        // Open the BPF program.
        let mut skel_builder = BpfSkelBuilder::default();
        skel_builder.obj_builder.debug(debug);
        let mut skel = scx_ops_open!(skel_builder, open_object, rustland, open_opts)?;

        // Callback used by the queued ring buffer: copy the raw message into
        // the static buffer, so no allocation happens on this hot path.
        fn callback(data: &[u8]) -> i32 {
            #[allow(static_mut_refs)]
            unsafe {
                BUF.0.copy_from_slice(data);
            }

            0
        }

        // Initialize the read-only configuration shared with the BPF component.
        let topo = Topology::new().unwrap();
        skel.maps.rodata_data.as_mut().unwrap().smt_enabled = topo.smt_enabled;

        skel.struct_ops.rustland_mut().flags =
            *compat::SCX_OPS_ENQ_LAST | *compat::SCX_OPS_ALLOW_QUEUED_WAKEUP;
        if partial {
            skel.struct_ops.rustland_mut().flags |= *compat::SCX_OPS_SWITCH_PARTIAL;
        }
        skel.struct_ops.rustland_mut().exit_dump_len = exit_dump_len;
        skel.maps.rodata_data.as_mut().unwrap().usersched_pid = std::process::id();
        skel.maps.rodata_data.as_mut().unwrap().khugepaged_pid = Self::khugepaged_pid();
        skel.maps.rodata_data.as_mut().unwrap().builtin_idle = builtin_idle;
        skel.maps.rodata_data.as_mut().unwrap().slice_ns = slice_ns;
        skel.maps.rodata_data.as_mut().unwrap().debug = debug;
        let _ = Self::set_scx_ops_name(&mut skel.struct_ops.rustland_mut().name, name);

        // Load the BPF program and attach the struct_ops.
        let mut skel = scx_ops_load!(skel, rustland, uei)?;

        let struct_ops = Some(scx_ops_attach!(skel, rustland)?);

        // Build the ring buffer of queued tasks (BPF -> user space).
        let maps = &skel.maps;
        let queued_ring_buffer = &maps.queued;
        let mut rbb = libbpf_rs::RingBufferBuilder::new();
        rbb.add(queued_ring_buffer, callback)
            .expect("failed to add ringbuf callback");
        let queued = rbb.build().expect("failed to build ringbuf");

        // Build the user ring buffer of dispatched tasks (user space -> BPF).
        let dispatched = libbpf_rs::UserRingBuffer::new(&maps.dispatched)
            .expect("failed to create user ringbuf");

        // Lock all the memory to prevent page faults that could stall the
        // user-space scheduler itself.
        ALLOCATOR.lock_memory();
        ALLOCATOR.disable_mmap().expect("Failed to disable mmap");

        // In partial mode tasks must opt in to sched_ext, so switch the
        // user-space scheduler itself to the SCHED_EXT class.
        if partial {
            let err = Self::use_sched_ext();
            if err < 0 {
                return Err(anyhow::Error::msg(format!(
                    "sched_setscheduler error: {err}"
                )));
            }
        }

        Ok(Self {
            skel,
            shutdown,
            queued,
            dispatched,
            struct_ops,
        })
    }

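    // Copy `src` into the fixed-size, NUL-terminated `name` field of the
    // struct_ops, appending the version suffix produced by
    // scx_utils::build_id::ops_version_suffix(), truncating if needed.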
    fn set_scx_ops_name(name_field: &mut [i8], src: &str) -> Result<()> {
        if !src.is_ascii() {
            bail!("name must be an ASCII string");
        }

        let bytes = src.as_bytes();
        let n = bytes.len().min(name_field.len().saturating_sub(1));

        name_field.fill(0);
        for i in 0..n {
            name_field[i] = bytes[i] as i8;
        }

        // Append the version suffix right after the name, always keeping
        // room for the terminating NUL byte.
        let version_suffix = ::scx_utils::build_id::ops_version_suffix(env!("CARGO_PKG_VERSION"));
        let bytes = version_suffix.as_bytes();
        let mut i = 0;
        let mut bytes_idx = 0;
        let mut found_null = false;

        while i < name_field.len() - 1 {
            found_null |= name_field[i] == 0;
            if !found_null {
                i += 1;
                continue;
            }

            if bytes_idx < bytes.len() {
                name_field[i] = bytes[bytes_idx] as i8;
                bytes_idx += 1;
            } else {
                break;
            }
            i += 1;
        }
        name_field[i] = 0;

        Ok(())
    }

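    // Return the PID of the khugepaged kernel thread, or 0 if it can't be
    // determined. Kernel threads are recognized by having no executable path.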
    fn khugepaged_pid() -> u32 {
        let procs = match all_processes() {
            Ok(p) => p,
            Err(_) => return 0,
        };

        for proc in procs {
            let proc = match proc {
                Ok(p) => p,
                Err(_) => continue,
            };

            if let Ok(stat) = proc.stat() {
                if proc.exe().is_err() && stat.comm == "khugepaged" {
                    return proc.pid() as u32;
                }
            }
        }

        0
    }

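    // Notify the BPF component that the user-space scheduler has completed
    // its scheduling cycle, updating the amount of tasks that are still
    // pending, then yield the CPU back to the kernel.
    //
    // NOTE: every scheduler should call this method at the end of its
    // scheduling cycle.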
    pub fn notify_complete(&mut self, nr_pending: u64) {
        self.skel.maps.bss_data.as_mut().unwrap().nr_scheduled = nr_pending;
        std::thread::yield_now();
    }

    // Counters shared with the BPF component (stored in the .bss map): they
    // can be read and updated by both sides.

    #[allow(dead_code)]
    pub fn nr_online_cpus_mut(&mut self) -> &mut u64 {
        &mut self.skel.maps.bss_data.as_mut().unwrap().nr_online_cpus
    }

    #[allow(dead_code)]
    pub fn nr_running_mut(&mut self) -> &mut u64 {
        &mut self.skel.maps.bss_data.as_mut().unwrap().nr_running
    }

    #[allow(dead_code)]
    pub fn nr_queued_mut(&mut self) -> &mut u64 {
        &mut self.skel.maps.bss_data.as_mut().unwrap().nr_queued
    }

    #[allow(dead_code)]
    pub fn nr_scheduled_mut(&mut self) -> &mut u64 {
        &mut self.skel.maps.bss_data.as_mut().unwrap().nr_scheduled
    }

    #[allow(dead_code)]
    pub fn nr_user_dispatches_mut(&mut self) -> &mut u64 {
        &mut self.skel.maps.bss_data.as_mut().unwrap().nr_user_dispatches
    }

    #[allow(dead_code)]
    pub fn nr_kernel_dispatches_mut(&mut self) -> &mut u64 {
        &mut self
            .skel
            .maps
            .bss_data
            .as_mut()
            .unwrap()
            .nr_kernel_dispatches
    }

    #[allow(dead_code)]
    pub fn nr_cancel_dispatches_mut(&mut self) -> &mut u64 {
        &mut self
            .skel
            .maps
            .bss_data
            .as_mut()
            .unwrap()
            .nr_cancel_dispatches
    }

    #[allow(dead_code)]
    pub fn nr_bounce_dispatches_mut(&mut self) -> &mut u64 {
        &mut self
            .skel
            .maps
            .bss_data
            .as_mut()
            .unwrap()
            .nr_bounce_dispatches
    }

    #[allow(dead_code)]
    pub fn nr_failed_dispatches_mut(&mut self) -> &mut u64 {
        &mut self
            .skel
            .maps
            .bss_data
            .as_mut()
            .unwrap()
            .nr_failed_dispatches
    }

    #[allow(dead_code)]
    pub fn nr_sched_congested_mut(&mut self) -> &mut u64 {
        &mut self.skel.maps.bss_data.as_mut().unwrap().nr_sched_congested
    }

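    // Switch the current (scheduler) thread to the SCHED_EXT scheduling
    // class, using the sched_param layout expected by the target libc.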
    fn use_sched_ext() -> i32 {
        #[cfg(target_env = "gnu")]
        let param: sched_param = sched_param { sched_priority: 0 };
        #[cfg(target_env = "musl")]
        let param: sched_param = sched_param {
            sched_priority: 0,
            sched_ss_low_priority: 0,
            sched_ss_repl_period: timespec {
                tv_sec: 0,
                tv_nsec: 0,
            },
            sched_ss_init_budget: timespec {
                tv_sec: 0,
                tv_nsec: 0,
            },
            sched_ss_max_repl: 0,
        };

        unsafe { pthread_setschedparam(pthread_self(), SCHED_EXT, &param as *const sched_param) }
    }

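    /// Pick an idle CPU for the target pid by running the rs_select_cpu BPF
    /// program. Returns the selected CPU on success, or a negative value if
    /// no idle CPU is available.
    ///
    /// A usage sketch (`sched` and `task` are hypothetical; `task` would be
    /// received from `dequeue_task()`):
    ///
    /// ```ignore
    /// let cpu = sched.select_cpu(task.pid, task.cpu, task.flags);
    /// let cpu = if cpu >= 0 { cpu } else { RL_CPU_ANY };
    /// ```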
    #[allow(dead_code)]
    pub fn select_cpu(&mut self, pid: i32, cpu: i32, flags: u64) -> i32 {
        let prog = &mut self.skel.progs.rs_select_cpu;
        let mut args = task_cpu_arg {
            pid: pid as c_int,
            cpu: cpu as c_int,
            flags: flags as c_ulong,
        };
        let input = ProgramInput {
            context_in: Some(unsafe {
                std::slice::from_raw_parts_mut(
                    &mut args as *mut _ as *mut u8,
                    std::mem::size_of_val(&args),
                )
            }),
            ..Default::default()
        };
        let out = prog.test_run(input).unwrap();

        out.return_value as i32
    }

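    /// Receive a task to be scheduled from the BPF dispatcher, consuming at
    /// most one message from the queued ring buffer.
    ///
    /// Returns Ok(Some(task)) if a task was received, Ok(None) if the queue
    /// is empty, or Err(res) on a ring buffer error.
    ///
    /// A typical consumer loop sketch (assumes a `sched: BpfScheduler`):
    ///
    /// ```ignore
    /// while let Ok(Some(task)) = sched.dequeue_task() {
    ///     let mut dispatched = DispatchedTask::new(&task);
    ///     dispatched.cpu = RL_CPU_ANY;
    ///     sched.dispatch_task(&dispatched)?;
    /// }
    /// ```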
    #[allow(static_mut_refs)]
    pub fn dequeue_task(&mut self) -> Result<Option<QueuedTask>, i32> {
        let bss_data = self.skel.maps.bss_data.as_mut().unwrap();

        match self.queued.consume_raw_n(1) {
            0 => {
                // Queue is empty.
                bss_data.nr_queued = 0;
                Ok(None)
            }
            1 => {
                // A message was copied into BUF by the ring buffer callback.
                let task = unsafe { EnqueuedMessage::from_bytes(&BUF.0).to_queued_task() };
                bss_data.nr_queued = bss_data.nr_queued.saturating_sub(1);

                Ok(Some(task))
            }
            res if res < 0 => Err(res),
            res => panic!("Unexpected return value from consume_raw_n(): {res}"),
        }
    }

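    /// Send a task to the dispatcher on the BPF side, through the dispatched
    /// user ring buffer.
    ///
    /// A dispatch sketch (`sched` and `task` are hypothetical; `task` would
    /// be received from `dequeue_task()`):
    ///
    /// ```ignore
    /// let mut dispatched = DispatchedTask::new(&task);
    /// dispatched.slice_ns = 1_000_000; // assign a custom 1ms time slice
    /// sched.dispatch_task(&dispatched)?;
    /// ```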
    pub fn dispatch_task(&mut self, task: &DispatchedTask) -> Result<(), libbpf_rs::Error> {
        // Reserve a sample in the user ring buffer.
        let mut urb_sample = self
            .dispatched
            .reserve(std::mem::size_of::<bpf_intf::dispatched_task_ctx>())?;
        let bytes = urb_sample.as_mut();
        let dispatched_task = plain::from_mut_bytes::<bpf_intf::dispatched_task_ctx>(bytes)
            .expect("failed to convert bytes");

        // Write the dispatched task into the reserved sample.
        let bpf_intf::dispatched_task_ctx {
            pid,
            cpu,
            flags,
            slice_ns,
            vtime,
            enq_cnt,
            ..
        } = dispatched_task;

        *pid = task.pid;
        *cpu = task.cpu;
        *flags = task.flags;
        *slice_ns = task.slice_ns;
        *vtime = task.vtime;
        *enq_cnt = task.enq_cnt;

        // Submit the sample, making it visible to the BPF component.
        self.dispatched
            .submit(urb_sample)
            .expect("failed to submit task");

        Ok(())
    }

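    // Return true if the scheduler must stop, either because the user
    // requested a shutdown (Ctrl-C) or because the BPF component exited.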
    pub fn exited(&mut self) -> bool {
        self.shutdown.load(Ordering::Relaxed) || uei_exited!(&self.skel, uei)
    }

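    // Detach the struct_ops and report the exit information from the BPF
    // component.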
    pub fn shutdown_and_report(&mut self) -> Result<UserExitInfo> {
        let _ = self.struct_ops.take();
        uei_report!(&self.skel, uei)
    }
}

// Disconnect the low-level BPF scheduler and unlock memory when the Rust
// object is dropped.
impl Drop for BpfScheduler<'_> {
    fn drop(&mut self) {
        if let Some(struct_ops) = self.struct_ops.take() {
            drop(struct_ops);
        }
        ALLOCATOR.unlock_memory();
    }
}