1use std::io::{self, ErrorKind};
7use std::mem::MaybeUninit;
8
9use crate::bpf_intf;
10use crate::bpf_intf::*;
11use crate::bpf_skel::*;
12
13use std::ffi::c_int;
14use std::ffi::c_ulong;
15
16use std::collections::HashMap;
17use std::sync::atomic::AtomicBool;
18use std::sync::atomic::Ordering;
19use std::sync::Arc;
20use std::sync::Once;
21
22use anyhow::Context;
23use anyhow::Result;
24use anyhow::bail;
25
26use plain::Plain;
27use procfs::process::all_processes;
28
29use libbpf_rs::OpenObject;
30use libbpf_rs::ProgramInput;
31use libbpf_rs::libbpf_sys::bpf_object_open_opts;
32
33use libc::{pthread_self, pthread_setschedparam, sched_param};
34
35#[cfg(target_env = "musl")]
36use libc::timespec;
37
38use scx_utils::compat;
39use scx_utils::scx_ops_attach;
40use scx_utils::scx_ops_load;
41use scx_utils::scx_ops_open;
42use scx_utils::uei_exited;
43use scx_utils::uei_report;
44use scx_utils::Topology;
45use scx_utils::UserExitInfo;
46
47use scx_rustland_core::ALLOCATOR;
48
// Scheduling policy number handed to `pthread_setschedparam()` by
// `BpfScheduler::use_sched_ext()`; 7 is the value the Linux UAPI headers assign to SCHED_EXT.
const SCHED_EXT: i32 = 7;
51
/// `bpf_intf::RL_CPU_ANY` converted to `i32`, the type used by the `cpu` fields of
/// [`QueuedTask`] and [`DispatchedTask`].
// NOTE(review): presumably a sentinel meaning "no specific CPU" — confirm against the BPF side.
#[allow(dead_code)]
pub const RL_CPU_ANY: i32 = bpf_intf::RL_CPU_ANY as i32;
57
/// Userspace view of a task handed over by the BPF side.
///
/// Instances are produced by `EnqueuedMessage::to_queued_task()`, which copies every field from
/// the same-named field of `bpf_intf::queued_task_ctx`.
#[derive(Debug, PartialEq, Eq, PartialOrd, Clone)]
pub struct QueuedTask {
    /// Identifier of the task.
    pub pid: i32,
    /// CPU reported for the task (presumably the one it last ran on — confirm with the BPF side).
    pub cpu: i32,
    /// Copied from `queued_task_ctx::nr_cpus_allowed`.
    pub nr_cpus_allowed: u64,
    /// Copied from `queued_task_ctx::flags`; forwarded unchanged by `DispatchedTask::new()`.
    pub flags: u64,
    /// Copied from `queued_task_ctx::start_ts`.
    pub start_ts: u64,
    /// Copied from `queued_task_ctx::stop_ts`.
    pub stop_ts: u64,
    /// Copied from `queued_task_ctx::exec_runtime`.
    pub exec_runtime: u64,
    /// Copied from `queued_task_ctx::weight`.
    pub weight: u64,
    /// Copied from `queued_task_ctx::vtime`.
    pub vtime: u64,
}
90
91#[derive(Debug, PartialEq, Eq, PartialOrd, Clone)]
93pub struct DispatchedTask {
94 pub pid: i32, pub cpu: i32, pub flags: u64, pub slice_ns: u64, pub vtime: u64, }
100
101impl DispatchedTask {
102 pub fn new(task: &QueuedTask) -> Self {
107 DispatchedTask {
108 pid: task.pid,
109 cpu: task.cpu,
110 flags: task.flags,
111 slice_ns: 0, vtime: 0,
113 }
114 }
115}
116
// SAFETY: `Plain` requires that every byte pattern is a valid value of the type.
// NOTE(review): this assumes `dispatched_task_ctx` is a `#[repr(C)]` struct made only of
// integer fields — confirm against the generated `bpf_intf` bindings.
unsafe impl Plain for bpf_intf::dispatched_task_ctx {}
119
120impl AsMut<bpf_intf::dispatched_task_ctx> for bpf_intf::dispatched_task_ctx {
121 fn as_mut(&mut self) -> &mut bpf_intf::dispatched_task_ctx {
122 self
123 }
124}
125
126struct EnqueuedMessage {
130 inner: bpf_intf::queued_task_ctx,
131}
132
133impl EnqueuedMessage {
134 fn from_bytes(bytes: &[u8]) -> Self {
135 let queued_task_struct = unsafe { *(bytes.as_ptr() as *const bpf_intf::queued_task_ctx) };
136 EnqueuedMessage {
137 inner: queued_task_struct,
138 }
139 }
140
141 fn to_queued_task(&self) -> QueuedTask {
142 QueuedTask {
143 pid: self.inner.pid,
144 cpu: self.inner.cpu,
145 nr_cpus_allowed: self.inner.nr_cpus_allowed,
146 flags: self.inner.flags,
147 start_ts: self.inner.start_ts,
148 stop_ts: self.inner.stop_ts,
149 exec_runtime: self.inner.exec_runtime,
150 weight: self.inner.weight,
151 vtime: self.inner.vtime,
152 }
153 }
154}
155
/// Userspace handle on the `rustland` BPF scheduler: owns the BPF skeleton, the two ring
/// buffers used to exchange tasks with it and the struct_ops link that keeps it attached.
pub struct BpfScheduler<'cb> {
    /// BPF skeleton, opened and loaded by `init()`.
    pub skel: BpfSkel<'cb>,
    /// Flag set to `true` by the Ctrl-C handler; polled by `exited()`.
    shutdown: Arc<AtomicBool>,
    /// Ring buffer built over `maps.queued`; consumed by `dequeue_task()`.
    queued: libbpf_rs::RingBuffer<'cb>,
    /// User ring buffer built over `maps.dispatched`; filled by `dispatch_task()`.
    dispatched: libbpf_rs::UserRingBuffer,
    /// Link returned by `scx_ops_attach!`; taken (and dropped) on shutdown or drop.
    struct_ops: Option<libbpf_rs::Link>,
}
163
164const BUFSIZE: usize = std::mem::size_of::<QueuedTask>();
169
170#[repr(align(8))]
171struct AlignedBuffer([u8; BUFSIZE]);
172
173static mut BUF: AlignedBuffer = AlignedBuffer([0; BUFSIZE]);
174
175static SET_HANDLER: Once = Once::new();
176
177fn set_ctrlc_handler(shutdown: Arc<AtomicBool>) -> Result<(), anyhow::Error> {
178 SET_HANDLER.call_once(|| {
179 let shutdown_clone = shutdown.clone();
180 ctrlc::set_handler(move || {
181 shutdown_clone.store(true, Ordering::Relaxed);
182 })
183 .expect("Error setting Ctrl-C handler");
184 });
185 Ok(())
186}
187
188impl<'cb> BpfScheduler<'cb> {
189 pub fn init(
190 open_object: &'cb mut MaybeUninit<OpenObject>,
191 open_opts: Option<bpf_object_open_opts>,
192 exit_dump_len: u32,
193 partial: bool,
194 debug: bool,
195 builtin_idle: bool,
196 name: &str,
197 ) -> Result<Self> {
198 let shutdown = Arc::new(AtomicBool::new(false));
199 set_ctrlc_handler(shutdown.clone()).context("Error setting Ctrl-C handler")?;
200
201 let mut skel_builder = BpfSkelBuilder::default();
203 skel_builder.obj_builder.debug(debug);
204 let mut skel = scx_ops_open!(skel_builder, open_object, rustland, open_opts)?;
205
206 fn callback(data: &[u8]) -> i32 {
217 #[allow(static_mut_refs)]
218 unsafe {
219 BUF.0.copy_from_slice(data);
226 }
227
228 0
230 }
231
232 let topo = Topology::new().unwrap();
234 skel.maps.rodata_data.as_mut().unwrap().smt_enabled = topo.smt_enabled;
235
236 skel.struct_ops.rustland_mut().flags = *compat::SCX_OPS_ENQ_LAST
238 | *compat::SCX_OPS_ENQ_MIGRATION_DISABLED
239 | *compat::SCX_OPS_ALLOW_QUEUED_WAKEUP;
240 if partial {
241 skel.struct_ops.rustland_mut().flags |= *compat::SCX_OPS_SWITCH_PARTIAL;
242 }
243 skel.struct_ops.rustland_mut().exit_dump_len = exit_dump_len;
244 skel.maps.rodata_data.as_mut().unwrap().usersched_pid = std::process::id();
245 skel.maps.rodata_data.as_mut().unwrap().khugepaged_pid = Self::khugepaged_pid();
246 skel.maps.rodata_data.as_mut().unwrap().builtin_idle = builtin_idle;
247 skel.maps.rodata_data.as_mut().unwrap().debug = debug;
248 let _ = Self::set_scx_ops_name(&mut skel.struct_ops.rustland_mut().name, name);
249
250 let mut skel = scx_ops_load!(skel, rustland, uei)?;
252
253 Self::init_l2_cache_domains(&mut skel, &topo)?;
255 Self::init_l3_cache_domains(&mut skel, &topo)?;
256
257 let struct_ops = Some(scx_ops_attach!(skel, rustland)?);
258
259 let maps = &skel.maps;
261 let queued_ring_buffer = &maps.queued;
262 let mut rbb = libbpf_rs::RingBufferBuilder::new();
263 rbb.add(queued_ring_buffer, callback)
264 .expect("failed to add ringbuf callback");
265 let queued = rbb.build().expect("failed to build ringbuf");
266
267 let dispatched = libbpf_rs::UserRingBuffer::new(&maps.dispatched)
269 .expect("failed to create user ringbuf");
270
271 ALLOCATOR.lock_memory();
274 ALLOCATOR.disable_mmap().expect("Failed to disable mmap");
275
276 if partial {
278 let err = Self::use_sched_ext();
279 if err < 0 {
280 return Err(anyhow::Error::msg(format!(
281 "sched_setscheduler error: {}",
282 err
283 )));
284 }
285 }
286
287 Ok(Self {
288 skel,
289 shutdown,
290 queued,
291 dispatched,
292 struct_ops,
293 })
294 }
295
296 fn set_scx_ops_name(dst: &mut [i8], s: &str) -> Result<()> {
298 if !s.is_ascii() {
299 bail!("name must be an ASCII string");
300 }
301
302 let bytes = s.as_bytes();
303 let n = bytes.len().min(dst.len().saturating_sub(1));
304
305 dst.fill(0);
306 for i in 0..n {
307 dst[i] = bytes[i] as i8;
308 }
309
310 Ok(())
311 }
312
313 fn khugepaged_pid() -> u32 {
315 let procs = match all_processes() {
316 Ok(p) => p,
317 Err(_) => return 0,
318 };
319
320 for proc in procs {
321 let proc = match proc {
322 Ok(p) => p,
323 Err(_) => continue,
324 };
325
326 if let Ok(stat) = proc.stat() {
327 if proc.exe().is_err() && stat.comm == "khugepaged" {
328 return proc.pid() as u32;
329 }
330 }
331 }
332
333 0
334 }
335
336 fn enable_sibling_cpu(
337 skel: &mut BpfSkel<'_>,
338 lvl: usize,
339 cpu: usize,
340 sibling_cpu: usize,
341 ) -> Result<(), u32> {
342 let prog = &mut skel.progs.enable_sibling_cpu;
343 let mut args = domain_arg {
344 lvl_id: lvl as c_int,
345 cpu_id: cpu as c_int,
346 sibling_cpu_id: sibling_cpu as c_int,
347 };
348 let input = ProgramInput {
349 context_in: Some(unsafe {
350 std::slice::from_raw_parts_mut(
351 &mut args as *mut _ as *mut u8,
352 std::mem::size_of_val(&args),
353 )
354 }),
355 ..Default::default()
356 };
357 let out = prog.test_run(input).unwrap();
358 if out.return_value != 0 {
359 return Err(out.return_value);
360 }
361
362 Ok(())
363 }
364
365 fn init_cache_domains<SiblingCpuFn>(
366 skel: &mut BpfSkel<'_>,
367 topo: &Topology,
368 cache_lvl: usize,
369 enable_sibling_cpu_fn: &SiblingCpuFn,
370 ) -> Result<(), std::io::Error>
371 where
372 SiblingCpuFn: Fn(&mut BpfSkel<'_>, usize, usize, usize) -> Result<(), u32>,
373 {
374 let mut cache_id_map: HashMap<usize, Vec<usize>> = HashMap::new();
376 for core in topo.all_cores.values() {
377 for (cpu_id, cpu) in &core.cpus {
378 let cache_id = match cache_lvl {
379 2 => cpu.l2_id,
380 3 => cpu.l3_id,
381 _ => panic!("invalid cache level {}", cache_lvl),
382 };
383 cache_id_map
384 .entry(cache_id)
385 .or_insert_with(Vec::new)
386 .push(*cpu_id);
387 }
388 }
389
390 for (_cache_id, cpus) in cache_id_map {
392 for cpu in &cpus {
393 for sibling_cpu in &cpus {
394 enable_sibling_cpu_fn(skel, cache_lvl, *cpu, *sibling_cpu).map_err(|e| {
395 io::Error::new(
396 ErrorKind::Other,
397 format!(
398 "enable_sibling_cpu_fn failed for cpu {} sibling {}: err {}",
399 cpu, sibling_cpu, e
400 ),
401 )
402 })?;
403 }
404 }
405 }
406
407 Ok(())
408 }
409
410 fn init_l2_cache_domains(
411 skel: &mut BpfSkel<'_>,
412 topo: &Topology,
413 ) -> Result<(), std::io::Error> {
414 Self::init_cache_domains(skel, topo, 2, &|skel, lvl, cpu, sibling_cpu| {
415 Self::enable_sibling_cpu(skel, lvl, cpu, sibling_cpu)
416 })
417 }
418
419 fn init_l3_cache_domains(
420 skel: &mut BpfSkel<'_>,
421 topo: &Topology,
422 ) -> Result<(), std::io::Error> {
423 Self::init_cache_domains(skel, topo, 3, &|skel, lvl, cpu, sibling_cpu| {
424 Self::enable_sibling_cpu(skel, lvl, cpu, sibling_cpu)
425 })
426 }
427
428 pub fn notify_complete(&mut self, nr_pending: u64) {
435 self.skel.maps.bss_data.as_mut().unwrap().nr_scheduled = nr_pending;
436 std::thread::yield_now();
437 }
438
439 #[allow(dead_code)]
441 pub fn nr_online_cpus_mut(&mut self) -> &mut u64 {
442 &mut self
443 .skel
444 .maps
445 .bss_data
446 .as_mut()
447 .unwrap()
448 .nr_online_cpus
449 }
450
451 #[allow(dead_code)]
453 pub fn nr_running_mut(&mut self) -> &mut u64 {
454 &mut self
455 .skel
456 .maps
457 .bss_data
458 .as_mut()
459 .unwrap()
460 .nr_running
461 }
462
463 #[allow(dead_code)]
465 pub fn nr_queued_mut(&mut self) -> &mut u64 {
466 &mut self
467 .skel
468 .maps
469 .bss_data
470 .as_mut()
471 .unwrap()
472 .nr_queued
473 }
474
475 #[allow(dead_code)]
477 pub fn nr_scheduled_mut(&mut self) -> &mut u64 {
478 &mut self
479 .skel
480 .maps
481 .bss_data
482 .as_mut()
483 .unwrap()
484 .nr_scheduled
485 }
486
487 #[allow(dead_code)]
489 pub fn nr_user_dispatches_mut(&mut self) -> &mut u64 {
490 &mut self
491 .skel
492 .maps
493 .bss_data
494 .as_mut()
495 .unwrap()
496 .nr_user_dispatches
497 }
498
499 #[allow(dead_code)]
501 pub fn nr_kernel_dispatches_mut(&mut self) -> &mut u64 {
502 &mut self
503 .skel
504 .maps
505 .bss_data
506 .as_mut()
507 .unwrap()
508 .nr_kernel_dispatches
509 }
510
511 #[allow(dead_code)]
513 pub fn nr_cancel_dispatches_mut(&mut self) -> &mut u64 {
514 &mut self
515 .skel
516 .maps
517 .bss_data
518 .as_mut()
519 .unwrap()
520 .nr_cancel_dispatches
521 }
522
523 #[allow(dead_code)]
525 pub fn nr_bounce_dispatches_mut(&mut self) -> &mut u64 {
526 &mut self
527 .skel
528 .maps
529 .bss_data
530 .as_mut()
531 .unwrap()
532 .nr_bounce_dispatches
533 }
534
535 #[allow(dead_code)]
537 pub fn nr_failed_dispatches_mut(&mut self) -> &mut u64 {
538 &mut self
539 .skel
540 .maps
541 .bss_data
542 .as_mut()
543 .unwrap()
544 .nr_failed_dispatches
545 }
546
547 #[allow(dead_code)]
549 pub fn nr_sched_congested_mut(&mut self) -> &mut u64 {
550 &mut self
551 .skel
552 .maps
553 .bss_data
554 .as_mut()
555 .unwrap()
556 .nr_sched_congested
557 }
558
559 fn use_sched_ext() -> i32 {
561 #[cfg(target_env = "gnu")]
562 let param: sched_param = sched_param { sched_priority: 0 };
563 #[cfg(target_env = "musl")]
564 let param: sched_param = sched_param {
565 sched_priority: 0,
566 sched_ss_low_priority: 0,
567 sched_ss_repl_period: timespec {
568 tv_sec: 0,
569 tv_nsec: 0,
570 },
571 sched_ss_init_budget: timespec {
572 tv_sec: 0,
573 tv_nsec: 0,
574 },
575 sched_ss_max_repl: 0,
576 };
577
578 unsafe { pthread_setschedparam(pthread_self(), SCHED_EXT, ¶m as *const sched_param) }
579 }
580
581 pub fn select_cpu(&mut self, pid: i32, cpu: i32, flags: u64) -> i32 {
583 let prog = &mut self.skel.progs.rs_select_cpu;
584 let mut args = task_cpu_arg {
585 pid: pid as c_int,
586 cpu: cpu as c_int,
587 flags: flags as c_ulong,
588 };
589 let input = ProgramInput {
590 context_in: Some(unsafe {
591 std::slice::from_raw_parts_mut(
592 &mut args as *mut _ as *mut u8,
593 std::mem::size_of_val(&args),
594 )
595 }),
596 ..Default::default()
597 };
598 let out = prog.test_run(input).unwrap();
599
600 out.return_value as i32
601 }
602
603 #[allow(static_mut_refs)]
605 pub fn dequeue_task(&mut self) -> Result<Option<QueuedTask>, i32> {
606 match self.queued.consume_raw_n(1) {
608 0 => {
609 self.skel.maps.bss_data.as_mut().unwrap().nr_queued = 0;
611 Ok(None)
612 }
613 1 => {
614 let task = unsafe { EnqueuedMessage::from_bytes(&BUF.0).to_queued_task() };
616 self.skel.maps.bss_data.as_mut().unwrap().nr_queued = self
617 .skel
618 .maps
619 .bss_data
620 .as_ref()
621 .unwrap()
622 .nr_queued
623 .saturating_sub(1);
624
625 Ok(Some(task))
626 }
627 res if res < 0 => Err(res),
628 res => panic!(
629 "Unexpected return value from libbpf-rs::consume_raw(): {}",
630 res
631 ),
632 }
633 }
634
635 pub fn dispatch_task(&mut self, task: &DispatchedTask) -> Result<(), libbpf_rs::Error> {
637 let mut urb_sample = self
639 .dispatched
640 .reserve(std::mem::size_of::<bpf_intf::dispatched_task_ctx>())?;
641 let bytes = urb_sample.as_mut();
642 let dispatched_task = plain::from_mut_bytes::<bpf_intf::dispatched_task_ctx>(bytes)
643 .expect("failed to convert bytes");
644
645 let bpf_intf::dispatched_task_ctx {
647 pid,
648 cpu,
649 flags,
650 slice_ns,
651 vtime,
652 ..
653 } = &mut dispatched_task.as_mut();
654
655 *pid = task.pid;
656 *cpu = task.cpu;
657 *flags = task.flags;
658 *slice_ns = task.slice_ns;
659 *vtime = task.vtime;
660
661 self.dispatched
666 .submit(urb_sample)
667 .expect("failed to submit task");
668
669 Ok(())
670 }
671
672 pub fn exited(&mut self) -> bool {
674 self.shutdown.load(Ordering::Relaxed) || uei_exited!(&self.skel, uei)
675 }
676
677 pub fn shutdown_and_report(&mut self) -> Result<UserExitInfo> {
679 let _ = self.struct_ops.take();
680 uei_report!(&self.skel, uei)
681 }
682}
683
684impl Drop for BpfScheduler<'_> {
686 fn drop(&mut self) {
687 if let Some(struct_ops) = self.struct_ops.take() {
688 drop(struct_ops);
689 }
690 ALLOCATOR.unlock_memory();
691 }
692}