use std::io::{self, ErrorKind};
use std::mem::MaybeUninit;

use crate::bpf_intf;
use crate::bpf_intf::*;
use crate::bpf_skel::*;

use std::ffi::c_int;
use std::ffi::c_ulong;

use std::collections::HashMap;
use std::sync::atomic::AtomicBool;
use std::sync::atomic::Ordering;
use std::sync::Arc;
use std::sync::Once;

use anyhow::Context;
use anyhow::Result;
use anyhow::bail;

use plain::Plain;
use procfs::process::all_processes;

use libbpf_rs::OpenObject;
use libbpf_rs::ProgramInput;
use libbpf_rs::libbpf_sys::bpf_object_open_opts;

use libc::{pthread_self, pthread_setschedparam, sched_param};

#[cfg(target_env = "musl")]
use libc::timespec;

use scx_utils::compat;
use scx_utils::scx_ops_attach;
use scx_utils::scx_ops_load;
use scx_utils::scx_ops_open;
use scx_utils::uei_exited;
use scx_utils::uei_report;
use scx_utils::Topology;
use scx_utils::UserExitInfo;

use scx_rustland_core::ALLOCATOR;

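// Scheduling policy number used by the kernel for sched_ext tasks (SCHED_EXT).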
const SCHED_EXT: i32 = 7;

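// Special CPU value: let the BPF component dispatch the task on the first CPU available,
// instead of a specific one.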
#[allow(dead_code)]
pub const RL_CPU_ANY: i32 = bpf_intf::RL_CPU_ANY as i32;

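/// Task queued for scheduling by the BPF component (see bpf_intf::queued_task_ctx).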
#[derive(Debug, PartialEq, Eq, PartialOrd, Clone)]
pub struct QueuedTask {
    pub pid: i32,
    pub cpu: i32,
    pub nr_cpus_allowed: u64,
    pub flags: u64,
    pub start_ts: u64,
    pub stop_ts: u64,
    pub exec_runtime: u64,
    pub weight: u64,
    pub vtime: u64,
}

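/// Task to be dispatched to the BPF component (see bpf_intf::dispatched_task_ctx).
///
/// Illustrative sketch of how a user-space scheduler would typically build one (the policy
/// shown is an assumption, not part of this API):
///
/// ```ignore
/// let mut dispatched = DispatchedTask::new(&queued);
/// dispatched.cpu = RL_CPU_ANY;   // let the BPF side pick the target CPU
/// dispatched.vtime = queued.vtime;
/// sched.dispatch_task(&dispatched)?;
/// ```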
#[derive(Debug, PartialEq, Eq, PartialOrd, Clone)]
pub struct DispatchedTask {
    pub pid: i32,
    pub cpu: i32,
    pub flags: u64,
    pub slice_ns: u64,
    pub vtime: u64,
}

impl DispatchedTask {
    pub fn new(task: &QueuedTask) -> Self {
        DispatchedTask {
            pid: task.pid,
            cpu: task.cpu,
            flags: task.flags,
            slice_ns: 0,
            vtime: 0,
        }
    }
}

unsafe impl Plain for bpf_intf::dispatched_task_ctx {}

impl AsMut<bpf_intf::dispatched_task_ctx> for bpf_intf::dispatched_task_ctx {
    fn as_mut(&mut self) -> &mut bpf_intf::dispatched_task_ctx {
        self
    }
}

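// Message received from the "queued" BPF ring buffer, wrapping a raw
// bpf_intf::queued_task_ctx.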
struct EnqueuedMessage {
    inner: bpf_intf::queued_task_ctx,
}

impl EnqueuedMessage {
    fn from_bytes(bytes: &[u8]) -> Self {
        let queued_task_struct = unsafe { *(bytes.as_ptr() as *const bpf_intf::queued_task_ctx) };
        EnqueuedMessage {
            inner: queued_task_struct,
        }
    }

    fn to_queued_task(&self) -> QueuedTask {
        QueuedTask {
            pid: self.inner.pid,
            cpu: self.inner.cpu,
            nr_cpus_allowed: self.inner.nr_cpus_allowed,
            flags: self.inner.flags,
            start_ts: self.inner.start_ts,
            stop_ts: self.inner.stop_ts,
            exec_runtime: self.inner.exec_runtime,
            weight: self.inner.weight,
            vtime: self.inner.vtime,
        }
    }
}

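/// User-space side of the scheduler: it owns the BPF skeleton, the ring buffers used to
/// exchange tasks with the BPF component and the attached struct_ops link.
///
/// Illustrative usage sketch (the loop structure below is an assumption, not something
/// prescribed by this API):
///
/// ```ignore
/// let mut sched = BpfScheduler::init(&mut open_object, None, 0, false, false, false, "rustland")?;
/// while !sched.exited() {
///     while let Ok(Some(task)) = sched.dequeue_task() {
///         sched.dispatch_task(&DispatchedTask::new(&task))?;
///     }
///     sched.notify_complete(0);
/// }
/// sched.shutdown_and_report()?;
/// ```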
pub struct BpfScheduler<'cb> {
    pub skel: BpfSkel<'cb>,
    shutdown: Arc<AtomicBool>,
    queued: libbpf_rs::RingBuffer<'cb>,
    dispatched: libbpf_rs::UserRingBuffer,
    struct_ops: Option<libbpf_rs::Link>,
}

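// Single-slot buffer (aligned to 8 bytes), filled by the ring buffer callback with one raw
// queued task at a time.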
const BUFSIZE: usize = std::mem::size_of::<QueuedTask>();

#[repr(align(8))]
struct AlignedBuffer([u8; BUFSIZE]);

static mut BUF: AlignedBuffer = AlignedBuffer([0; BUFSIZE]);

static SET_HANDLER: Once = Once::new();

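// Install a Ctrl-C handler (at most once) that simply raises the shutdown flag.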
fn set_ctrlc_handler(shutdown: Arc<AtomicBool>) -> Result<(), anyhow::Error> {
    SET_HANDLER.call_once(|| {
        let shutdown_clone = shutdown.clone();
        ctrlc::set_handler(move || {
            shutdown_clone.store(true, Ordering::Relaxed);
        })
        .expect("Error setting Ctrl-C handler");
    });
    Ok(())
}

impl<'cb> BpfScheduler<'cb> {
    pub fn init(
        open_object: &'cb mut MaybeUninit<OpenObject>,
        open_opts: Option<bpf_object_open_opts>,
        exit_dump_len: u32,
        partial: bool,
        debug: bool,
        builtin_idle: bool,
        name: &str,
    ) -> Result<Self> {
        let shutdown = Arc::new(AtomicBool::new(false));
        set_ctrlc_handler(shutdown.clone()).context("Error setting Ctrl-C handler")?;

        // Open the BPF program.
        let mut skel_builder = BpfSkelBuilder::default();
        skel_builder.obj_builder.debug(debug);
        let mut skel = scx_ops_open!(skel_builder, open_object, rustland, open_opts)?;

        // Ring buffer callback: copy exactly one raw queued task into the static buffer.
        fn callback(data: &[u8]) -> i32 {
            #[allow(static_mut_refs)]
            unsafe {
                BUF.0.copy_from_slice(data);
            }

            0
        }

        // Initialize the scheduler's parameters from the host topology.
        let topo = Topology::new().unwrap();
        skel.maps.rodata_data.as_mut().unwrap().smt_enabled = topo.smt_enabled;

        skel.struct_ops.rustland_mut().flags = *compat::SCX_OPS_ENQ_LAST
            | *compat::SCX_OPS_ENQ_MIGRATION_DISABLED
            | *compat::SCX_OPS_ALLOW_QUEUED_WAKEUP;
        if partial {
            skel.struct_ops.rustland_mut().flags |= *compat::SCX_OPS_SWITCH_PARTIAL;
        }
        skel.struct_ops.rustland_mut().exit_dump_len = exit_dump_len;
        skel.maps.rodata_data.as_mut().unwrap().usersched_pid = std::process::id();
        skel.maps.rodata_data.as_mut().unwrap().khugepaged_pid = Self::khugepaged_pid();
        skel.maps.rodata_data.as_mut().unwrap().builtin_idle = builtin_idle;
        skel.maps.rodata_data.as_mut().unwrap().debug = debug;
        let _ = Self::set_scx_ops_name(&mut skel.struct_ops.rustland_mut().name, name);

        // Load the BPF program, register the cache domains and attach the scheduler.
        let mut skel = scx_ops_load!(skel, rustland, uei)?;

        Self::init_l2_cache_domains(&mut skel, &topo)?;
        Self::init_l3_cache_domains(&mut skel, &topo)?;

        let struct_ops = Some(scx_ops_attach!(skel, rustland)?);

        // Build the ring buffer of queued tasks.
        let maps = &skel.maps;
        let queued_ring_buffer = &maps.queued;
        let mut rbb = libbpf_rs::RingBufferBuilder::new();
        rbb.add(queued_ring_buffer, callback)
            .expect("failed to add ringbuf callback");
        let queued = rbb.build().expect("failed to build ringbuf");

        // Build the user ring buffer of dispatched tasks.
        let dispatched = libbpf_rs::UserRingBuffer::new(&maps.dispatched)
            .expect("failed to create user ringbuf");

        // Lock all memory to prevent page faults in the user-space scheduler.
        ALLOCATOR.lock_memory();
        ALLOCATOR.disable_mmap().expect("Failed to disable mmap");

        // With SCX_OPS_SWITCH_PARTIAL only tasks that explicitly set the SCHED_EXT policy
        // are managed, so switch the user-space scheduler itself to SCHED_EXT.
        if partial {
            let err = Self::use_sched_ext();
            if err < 0 {
                return Err(anyhow::Error::msg(format!(
                    "sched_setscheduler error: {}",
                    err
                )));
            }
        }

        Ok(Self {
            skel,
            shutdown,
            queued,
            dispatched,
            struct_ops,
        })
    }

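    // Write `src` into the struct_ops name field (ASCII only), appending the scx_utils build
    // version suffix and keeping the result NUL-terminated within the field size.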
    fn set_scx_ops_name(name_field: &mut [i8], src: &str) -> Result<()> {
        if !src.is_ascii() {
            bail!("name must be an ASCII string");
        }

        let bytes = src.as_bytes();
        let n = bytes.len().min(name_field.len().saturating_sub(1));

        name_field.fill(0);
        for i in 0..n {
            name_field[i] = bytes[i] as i8;
        }

        let version_suffix = ::scx_utils::build_id::ops_version_suffix(env!("CARGO_PKG_VERSION"));
        let bytes = version_suffix.as_bytes();
        let mut i = 0;
        let mut bytes_idx = 0;
        let mut found_null = false;

        while i < name_field.len() - 1 {
            found_null |= name_field[i] == 0;
            if !found_null {
                i += 1;
                continue;
            }

            if bytes_idx < bytes.len() {
                name_field[i] = bytes[bytes_idx] as i8;
                bytes_idx += 1;
            } else {
                break;
            }
            i += 1;
        }
        name_field[i] = 0;

        Ok(())
    }

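    // Return the PID of the "khugepaged" kernel thread, or 0 if it cannot be found.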
    fn khugepaged_pid() -> u32 {
        let procs = match all_processes() {
            Ok(p) => p,
            Err(_) => return 0,
        };

        for proc in procs {
            let proc = match proc {
                Ok(p) => p,
                Err(_) => continue,
            };

            if let Ok(stat) = proc.stat() {
                if proc.exe().is_err() && stat.comm == "khugepaged" {
                    return proc.pid() as u32;
                }
            }
        }

        0
    }

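    // Notify the BPF component that `sibling_cpu` shares the level-`lvl` cache with `cpu`,
    // by running the enable_sibling_cpu BPF program via test_run.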
    fn enable_sibling_cpu(
        skel: &mut BpfSkel<'_>,
        lvl: usize,
        cpu: usize,
        sibling_cpu: usize,
    ) -> Result<(), u32> {
        let prog = &mut skel.progs.enable_sibling_cpu;
        let mut args = domain_arg {
            lvl_id: lvl as c_int,
            cpu_id: cpu as c_int,
            sibling_cpu_id: sibling_cpu as c_int,
        };
        let input = ProgramInput {
            context_in: Some(unsafe {
                std::slice::from_raw_parts_mut(
                    &mut args as *mut _ as *mut u8,
                    std::mem::size_of_val(&args),
                )
            }),
            ..Default::default()
        };
        let out = prog.test_run(input).unwrap();
        if out.return_value != 0 {
            return Err(out.return_value);
        }

        Ok(())
    }

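    // Group CPUs by their cache ID at the given cache level and register every
    // (cpu, sibling) pair of each group with the BPF component.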
    fn init_cache_domains<SiblingCpuFn>(
        skel: &mut BpfSkel<'_>,
        topo: &Topology,
        cache_lvl: usize,
        enable_sibling_cpu_fn: &SiblingCpuFn,
    ) -> Result<(), std::io::Error>
    where
        SiblingCpuFn: Fn(&mut BpfSkel<'_>, usize, usize, usize) -> Result<(), u32>,
    {
        let mut cache_id_map: HashMap<usize, Vec<usize>> = HashMap::new();
        for core in topo.all_cores.values() {
            for (cpu_id, cpu) in &core.cpus {
                let cache_id = match cache_lvl {
                    2 => cpu.l2_id,
                    3 => cpu.l3_id,
                    _ => panic!("invalid cache level {}", cache_lvl),
                };
                cache_id_map
                    .entry(cache_id)
                    .or_insert_with(Vec::new)
                    .push(*cpu_id);
            }
        }

        for (_cache_id, cpus) in cache_id_map {
            for cpu in &cpus {
                for sibling_cpu in &cpus {
                    enable_sibling_cpu_fn(skel, cache_lvl, *cpu, *sibling_cpu).map_err(|e| {
                        io::Error::new(
                            ErrorKind::Other,
                            format!(
                                "enable_sibling_cpu_fn failed for cpu {} sibling {}: err {}",
                                cpu, sibling_cpu, e
                            ),
                        )
                    })?;
                }
            }
        }

        Ok(())
    }

    fn init_l2_cache_domains(
        skel: &mut BpfSkel<'_>,
        topo: &Topology,
    ) -> Result<(), std::io::Error> {
        Self::init_cache_domains(skel, topo, 2, &|skel, lvl, cpu, sibling_cpu| {
            Self::enable_sibling_cpu(skel, lvl, cpu, sibling_cpu)
        })
    }

    fn init_l3_cache_domains(
        skel: &mut BpfSkel<'_>,
        topo: &Topology,
    ) -> Result<(), std::io::Error> {
        Self::init_cache_domains(skel, topo, 3, &|skel, lvl, cpu, sibling_cpu| {
            Self::enable_sibling_cpu(skel, lvl, cpu, sibling_cpu)
        })
    }

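    // Notify the BPF component that the user-space scheduler has completed its scheduling
    // cycle, passing the number of tasks still pending, then yield the CPU.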
    pub fn notify_complete(&mut self, nr_pending: u64) {
        self.skel.maps.bss_data.as_mut().unwrap().nr_scheduled = nr_pending;
        std::thread::yield_now();
    }

    // Counters shared with the BPF component (mutable accessors).
    #[allow(dead_code)]
    pub fn nr_online_cpus_mut(&mut self) -> &mut u64 {
        &mut self.skel.maps.bss_data.as_mut().unwrap().nr_online_cpus
    }

    #[allow(dead_code)]
    pub fn nr_running_mut(&mut self) -> &mut u64 {
        &mut self.skel.maps.bss_data.as_mut().unwrap().nr_running
    }

    #[allow(dead_code)]
    pub fn nr_queued_mut(&mut self) -> &mut u64 {
        &mut self.skel.maps.bss_data.as_mut().unwrap().nr_queued
    }

    #[allow(dead_code)]
    pub fn nr_scheduled_mut(&mut self) -> &mut u64 {
        &mut self.skel.maps.bss_data.as_mut().unwrap().nr_scheduled
    }

    #[allow(dead_code)]
    pub fn nr_user_dispatches_mut(&mut self) -> &mut u64 {
        &mut self.skel.maps.bss_data.as_mut().unwrap().nr_user_dispatches
    }

    #[allow(dead_code)]
    pub fn nr_kernel_dispatches_mut(&mut self) -> &mut u64 {
        &mut self.skel.maps.bss_data.as_mut().unwrap().nr_kernel_dispatches
    }

    #[allow(dead_code)]
    pub fn nr_cancel_dispatches_mut(&mut self) -> &mut u64 {
        &mut self.skel.maps.bss_data.as_mut().unwrap().nr_cancel_dispatches
    }

    #[allow(dead_code)]
    pub fn nr_bounce_dispatches_mut(&mut self) -> &mut u64 {
        &mut self.skel.maps.bss_data.as_mut().unwrap().nr_bounce_dispatches
    }

    #[allow(dead_code)]
    pub fn nr_failed_dispatches_mut(&mut self) -> &mut u64 {
        &mut self.skel.maps.bss_data.as_mut().unwrap().nr_failed_dispatches
    }

    #[allow(dead_code)]
    pub fn nr_sched_congested_mut(&mut self) -> &mut u64 {
        &mut self.skel.maps.bss_data.as_mut().unwrap().nr_sched_congested
    }

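    // Switch the calling thread to the SCHED_EXT scheduling policy.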
    fn use_sched_ext() -> i32 {
        #[cfg(target_env = "gnu")]
        let param: sched_param = sched_param { sched_priority: 0 };
        #[cfg(target_env = "musl")]
        let param: sched_param = sched_param {
            sched_priority: 0,
            sched_ss_low_priority: 0,
            sched_ss_repl_period: timespec {
                tv_sec: 0,
                tv_nsec: 0,
            },
            sched_ss_init_budget: timespec {
                tv_sec: 0,
                tv_nsec: 0,
            },
            sched_ss_max_repl: 0,
        };

        unsafe { pthread_setschedparam(pthread_self(), SCHED_EXT, &param as *const sched_param) }
    }

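    // Ask the BPF component to pick an idle CPU for the task, by running the rs_select_cpu
    // BPF program via test_run; returns the CPU selected by the BPF component (a negative
    // value indicates that no CPU was selected).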
    pub fn select_cpu(&mut self, pid: i32, cpu: i32, flags: u64) -> i32 {
        let prog = &mut self.skel.progs.rs_select_cpu;
        let mut args = task_cpu_arg {
            pid: pid as c_int,
            cpu: cpu as c_int,
            flags: flags as c_ulong,
        };
        let input = ProgramInput {
            context_in: Some(unsafe {
                std::slice::from_raw_parts_mut(
                    &mut args as *mut _ as *mut u8,
                    std::mem::size_of_val(&args),
                )
            }),
            ..Default::default()
        };
        let out = prog.test_run(input).unwrap();

        out.return_value as i32
    }

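    // Consume one task (if any) from the "queued" BPF ring buffer and keep the shared
    // nr_queued counter in sync.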
    #[allow(static_mut_refs)]
    pub fn dequeue_task(&mut self) -> Result<Option<QueuedTask>, i32> {
        match self.queued.consume_raw_n(1) {
            0 => {
                // Nothing to consume.
                self.skel.maps.bss_data.as_mut().unwrap().nr_queued = 0;
                Ok(None)
            }
            1 => {
                // A task has been copied into BUF by the ring buffer callback.
                let task = unsafe { EnqueuedMessage::from_bytes(&BUF.0).to_queued_task() };
                self.skel.maps.bss_data.as_mut().unwrap().nr_queued = self
                    .skel
                    .maps
                    .bss_data
                    .as_ref()
                    .unwrap()
                    .nr_queued
                    .saturating_sub(1);

                Ok(Some(task))
            }
            res if res < 0 => Err(res),
            res => panic!(
                "Unexpected return value from libbpf-rs::consume_raw_n(): {}",
                res
            ),
        }
    }

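    // Send a task to the BPF component for dispatching: reserve an entry in the "dispatched"
    // user ring buffer, fill it from `task` and submit it.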
    pub fn dispatch_task(&mut self, task: &DispatchedTask) -> Result<(), libbpf_rs::Error> {
        let mut urb_sample = self
            .dispatched
            .reserve(std::mem::size_of::<bpf_intf::dispatched_task_ctx>())?;
        let bytes = urb_sample.as_mut();
        let dispatched_task = plain::from_mut_bytes::<bpf_intf::dispatched_task_ctx>(bytes)
            .expect("failed to convert bytes");

        let bpf_intf::dispatched_task_ctx {
            pid,
            cpu,
            flags,
            slice_ns,
            vtime,
            ..
        } = &mut dispatched_task.as_mut();

        *pid = task.pid;
        *cpu = task.cpu;
        *flags = task.flags;
        *slice_ns = task.slice_ns;
        *vtime = task.vtime;

        self.dispatched
            .submit(urb_sample)
            .expect("failed to submit task");

        Ok(())
    }

    // Return true if the scheduler has been asked to stop (Ctrl-C) or if the BPF component
    // has exited.
    pub fn exited(&mut self) -> bool {
        self.shutdown.load(Ordering::Relaxed) || uei_exited!(&self.skel, uei)
    }

    // Detach the scheduler and report the BPF exit information.
    pub fn shutdown_and_report(&mut self) -> Result<UserExitInfo> {
        let _ = self.struct_ops.take();
        uei_report!(&self.skel, uei)
    }
}

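// Detach the scheduler (by dropping the struct_ops link) and unlock the allocator memory
// when the BpfScheduler instance goes out of scope.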
impl Drop for BpfScheduler<'_> {
    fn drop(&mut self) {
        if let Some(struct_ops) = self.struct_ops.take() {
            drop(struct_ops);
        }
        ALLOCATOR.unlock_memory();
    }
}