use std::mem::MaybeUninit;

use crate::bpf_intf;
use crate::bpf_intf::*;
use crate::bpf_skel::*;

use std::ffi::c_int;
use std::ffi::c_ulong;

use std::collections::HashMap;
use std::sync::atomic::AtomicBool;
use std::sync::atomic::Ordering;
use std::sync::Arc;
use std::sync::Once;

use anyhow::Context;
use anyhow::Result;

use plain::Plain;

use libbpf_rs::OpenObject;
use libbpf_rs::ProgramInput;

use libc::{pthread_self, pthread_setschedparam, sched_param};

#[cfg(target_env = "musl")]
use libc::timespec;

use scx_utils::compat;
use scx_utils::scx_ops_attach;
use scx_utils::scx_ops_load;
use scx_utils::scx_ops_open;
use scx_utils::uei_exited;
use scx_utils::uei_report;
use scx_utils::Topology;
use scx_utils::UserExitInfo;

use scx_rustland_core::ALLOCATOR;

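// Scheduling class (policy) number assigned to sched_ext by the kernel.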
const SCHED_EXT: i32 = 7;

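// Special CPU value used to dispatch a task on any available CPU.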
#[allow(dead_code)]
pub const RL_CPU_ANY: i32 = bpf_intf::RL_CPU_ANY as i32;

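// Task queued for scheduling from the BPF component (see bpf_intf::queued_task_ctx).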
#[derive(Debug, PartialEq, Eq, PartialOrd, Clone)]
pub struct QueuedTask {
    pub pid: i32,
    pub cpu: i32,
    pub flags: u64,
    pub exec_runtime: u64,
    pub sum_exec_runtime: u64,
    pub nvcsw: u64,
    pub weight: u64,
    pub slice: u64,
    pub vtime: u64,
    cpumask_cnt: u64,
}

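// Task to be dispatched to the BPF component (see bpf_intf::dispatched_task_ctx).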
#[derive(Debug, PartialEq, Eq, PartialOrd, Clone)]
pub struct DispatchedTask {
    pub pid: i32,
    pub cpu: i32,
    pub flags: u64,
    pub slice_ns: u64,
    pub vtime: u64,
    cpumask_cnt: u64,
}

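// Create a DispatchedTask from a QueuedTask: slice_ns and vtime start at 0 and can be
// adjusted by the user-space scheduler before the task is dispatched.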
impl DispatchedTask {
    pub fn new(task: &QueuedTask) -> Self {
        DispatchedTask {
            pid: task.pid,
            cpu: task.cpu,
            flags: task.flags,
            slice_ns: 0,
            vtime: 0,
            cpumask_cnt: task.cpumask_cnt,
        }
    }
}

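// Allow the raw dispatched_task_ctx to be filled in place from the plain bytes reserved in
// the dispatched user ring buffer.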
unsafe impl Plain for bpf_intf::dispatched_task_ctx {}

impl AsMut<bpf_intf::dispatched_task_ctx> for bpf_intf::dispatched_task_ctx {
    fn as_mut(&mut self) -> &mut bpf_intf::dispatched_task_ctx {
        self
    }
}

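// Message received from the BPF component through the "queued" ring buffer.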
struct EnqueuedMessage {
    inner: bpf_intf::queued_task_ctx,
}

impl EnqueuedMessage {
    fn from_bytes(bytes: &[u8]) -> Self {
        let queued_task_struct = unsafe { *(bytes.as_ptr() as *const bpf_intf::queued_task_ctx) };
        EnqueuedMessage {
            inner: queued_task_struct,
        }
    }

    fn to_queued_task(&self) -> QueuedTask {
        QueuedTask {
            pid: self.inner.pid,
            cpu: self.inner.cpu,
            flags: self.inner.flags,
            exec_runtime: self.inner.exec_runtime,
            sum_exec_runtime: self.inner.sum_exec_runtime,
            nvcsw: self.inner.nvcsw,
            weight: self.inner.weight,
            slice: self.inner.slice,
            vtime: self.inner.vtime,
            cpumask_cnt: self.inner.cpumask_cnt,
        }
    }
}

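// Main interface to the BPF component:
//  - skel: BPF skeleton
//  - shutdown: flag set by the Ctrl-C handler to request a clean exit
//  - queued: ring buffer of tasks queued by the BPF component to the user-space scheduler
//  - dispatched: user ring buffer of tasks dispatched by the user-space scheduler to BPF
//  - struct_ops: link to the attached sched_ext struct_ops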
pub struct BpfScheduler<'cb> {
    pub skel: BpfSkel<'cb>,
    shutdown: Arc<AtomicBool>,
    queued: libbpf_rs::RingBuffer<'cb>,
    dispatched: libbpf_rs::UserRingBuffer,
    struct_ops: Option<libbpf_rs::Link>,
}

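// Statically allocated, 8-byte aligned buffer where the ring buffer callback stashes a
// single queued task before it is converted into a QueuedTask.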
const BUFSIZE: usize = std::mem::size_of::<QueuedTask>();

#[repr(align(8))]
struct AlignedBuffer([u8; BUFSIZE]);

static mut BUF: AlignedBuffer = AlignedBuffer([0; BUFSIZE]);

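// Special negative value returned by the ring buffer callback to stop consuming after a
// single item has been copied out.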
const LIBBPF_STOP: i32 = -255;

static SET_HANDLER: Once = Once::new();

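// Install the Ctrl-C handler at most once; it only sets the shared shutdown flag.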
fn set_ctrlc_handler(shutdown: Arc<AtomicBool>) -> Result<(), anyhow::Error> {
    SET_HANDLER.call_once(|| {
        let shutdown_clone = shutdown.clone();
        ctrlc::set_handler(move || {
            shutdown_clone.store(true, Ordering::Relaxed);
        })
        .expect("Error setting Ctrl-C handler");
    });
    Ok(())
}

impl<'cb> BpfScheduler<'cb> {
    // Open, load and attach the BPF component, returning a fully initialized scheduler
    // instance.
    pub fn init(
        open_object: &'cb mut MaybeUninit<OpenObject>,
        exit_dump_len: u32,
        partial: bool,
        debug: bool,
        builtin_idle: bool,
    ) -> Result<Self> {
        let shutdown = Arc::new(AtomicBool::new(false));
        set_ctrlc_handler(shutdown.clone()).context("Error setting Ctrl-C handler")?;

        // Open the BPF program.
        let mut skel_builder = BpfSkelBuilder::default();
        skel_builder.obj_builder.debug(debug);
        let mut skel = scx_ops_open!(skel_builder, open_object, rustland)?;

        // Lock all memory to prevent page faults in the scheduling path.
        ALLOCATOR.lock_memory();

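        // Callback invoked for every item pushed by the BPF component into the "queued"
        // ring buffer: copy the raw bytes into the static buffer and return LIBBPF_STOP so
        // that only one task is consumed per consume_raw() call.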
        fn callback(data: &[u8]) -> i32 {
            #[allow(static_mut_refs)]
            unsafe {
                BUF.0.copy_from_slice(data);
            }

            LIBBPF_STOP
        }


        // Check the host topology to determine whether SMT is enabled.
        let topo = Topology::new().unwrap();
        skel.maps.rodata_data.smt_enabled = topo.smt_enabled;

        skel.struct_ops.rustland_mut().flags = 0
            | *compat::SCX_OPS_ALLOW_QUEUED_WAKEUP
            | *compat::SCX_OPS_ENQ_MIGRATION_DISABLED
            | *compat::SCX_OPS_ENQ_LAST
            | *compat::SCX_OPS_KEEP_BUILTIN_IDLE;
        if partial {
            skel.struct_ops.rustland_mut().flags |= *compat::SCX_OPS_SWITCH_PARTIAL;
        }
        skel.struct_ops.rustland_mut().exit_dump_len = exit_dump_len;

        // Register the PID of the user-space scheduler with the BPF component.
        skel.maps.bss_data.usersched_pid = std::process::id();
        skel.maps.rodata_data.builtin_idle = builtin_idle;
        skel.maps.rodata_data.debug = debug;

        // Load the BPF program.
        let mut skel = scx_ops_load!(skel, rustland, uei)?;

        // Initialize the L2 / L3 cache domain information in the BPF component.
        Self::init_l2_cache_domains(&mut skel, &topo)?;
        Self::init_l3_cache_domains(&mut skel, &topo)?;

        // Attach the scheduler.
        let struct_ops = Some(scx_ops_attach!(skel, rustland)?);

        // Build the ring buffer of queued tasks (BPF -> user-space).
        let maps = &skel.maps;
        let queued_ring_buffer = &maps.queued;
        let mut rbb = libbpf_rs::RingBufferBuilder::new();
        rbb.add(queued_ring_buffer, callback)
            .expect("failed to add ringbuf callback");
        let queued = rbb.build().expect("failed to build ringbuf");

        // Build the user ring buffer of dispatched tasks (user-space -> BPF).
        let dispatched = libbpf_rs::UserRingBuffer::new(&maps.dispatched)
            .expect("failed to create user ringbuf");

        // Make sure the user-space scheduler thread runs in the SCHED_EXT class.
        match Self::use_sched_ext() {
            0 => Ok(Self {
                skel,
                shutdown,
                queued,
                dispatched,
                struct_ops,
            }),
            err => Err(anyhow::Error::msg(format!(
                "sched_setscheduler error: {}",
                err
            ))),
        }
    }

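    // Notify the BPF component that cpu and sibling_cpu belong to the same cache domain at
    // the given level, by running the enable_sibling_cpu BPF program.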
    fn enable_sibling_cpu(
        skel: &mut BpfSkel<'_>,
        lvl: usize,
        cpu: usize,
        sibling_cpu: usize,
    ) -> Result<(), u32> {
        let prog = &mut skel.progs.enable_sibling_cpu;
        let mut args = domain_arg {
            lvl_id: lvl as c_int,
            cpu_id: cpu as c_int,
            sibling_cpu_id: sibling_cpu as c_int,
        };
        let input = ProgramInput {
            context_in: Some(unsafe {
                std::slice::from_raw_parts_mut(
                    &mut args as *mut _ as *mut u8,
                    std::mem::size_of_val(&args),
                )
            }),
            ..Default::default()
        };
        let out = prog.test_run(input).unwrap();
        if out.return_value != 0 {
            return Err(out.return_value);
        }

        Ok(())
    }

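    // Group CPUs by cache ID at the given cache level (L2 or L3) and register each pair of
    // sibling CPUs with the BPF component.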
    fn init_cache_domains<SiblingCpuFn>(
        skel: &mut BpfSkel<'_>,
        topo: &Topology,
        cache_lvl: usize,
        enable_sibling_cpu_fn: &SiblingCpuFn,
    ) -> Result<(), std::io::Error>
    where
        SiblingCpuFn: Fn(&mut BpfSkel<'_>, usize, usize, usize) -> Result<(), u32>,
    {
        // Collect the CPU IDs belonging to each cache domain at the requested level.
        let mut cache_id_map: HashMap<usize, Vec<usize>> = HashMap::new();
        for core in topo.all_cores.values() {
            for (cpu_id, cpu) in &core.cpus {
                let cache_id = match cache_lvl {
                    2 => cpu.l2_id,
                    3 => cpu.l3_id,
                    _ => panic!("invalid cache level {}", cache_lvl),
                };
                cache_id_map
                    .entry(cache_id)
                    .or_insert_with(Vec::new)
                    .push(*cpu_id);
            }
        }

        // Register every (cpu, sibling_cpu) pair within the same cache domain.
        for (_cache_id, cpus) in cache_id_map {
            for cpu in &cpus {
                for sibling_cpu in &cpus {
                    // Enabling a sibling CPU is best-effort: errors are ignored.
                    let _ = enable_sibling_cpu_fn(skel, cache_lvl, *cpu, *sibling_cpu);
                }
            }
        }

        Ok(())
    }

    fn init_l2_cache_domains(
        skel: &mut BpfSkel<'_>,
        topo: &Topology,
    ) -> Result<(), std::io::Error> {
        Self::init_cache_domains(skel, topo, 2, &|skel, lvl, cpu, sibling_cpu| {
            Self::enable_sibling_cpu(skel, lvl, cpu, sibling_cpu)
        })
    }

    fn init_l3_cache_domains(
        skel: &mut BpfSkel<'_>,
        topo: &Topology,
    ) -> Result<(), std::io::Error> {
        Self::init_cache_domains(skel, topo, 3, &|skel, lvl, cpu, sibling_cpu| {
            Self::enable_sibling_cpu(skel, lvl, cpu, sibling_cpu)
        })
    }

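    // Notify the BPF component that the user-space scheduler has completed its scheduling
    // cycle, publishing the number of tasks still pending, then yield the CPU.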
    pub fn notify_complete(&mut self, nr_pending: u64) {
        self.skel.maps.bss_data.nr_scheduled = nr_pending;
        std::thread::yield_now();
    }

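    // Counters shared with the BPF component, exposed as mutable references so that the
    // user-space scheduler can read and update them.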
    #[allow(dead_code)]
    pub fn nr_online_cpus_mut(&mut self) -> &mut u64 {
        &mut self.skel.maps.bss_data.nr_online_cpus
    }

    #[allow(dead_code)]
    pub fn nr_running_mut(&mut self) -> &mut u64 {
        &mut self.skel.maps.bss_data.nr_running
    }

    #[allow(dead_code)]
    pub fn nr_queued_mut(&mut self) -> &mut u64 {
        &mut self.skel.maps.bss_data.nr_queued
    }

    #[allow(dead_code)]
    pub fn nr_scheduled_mut(&mut self) -> &mut u64 {
        &mut self.skel.maps.bss_data.nr_scheduled
    }

    #[allow(dead_code)]
    pub fn nr_user_dispatches_mut(&mut self) -> &mut u64 {
        &mut self.skel.maps.bss_data.nr_user_dispatches
    }

    #[allow(dead_code)]
    pub fn nr_kernel_dispatches_mut(&mut self) -> &mut u64 {
        &mut self.skel.maps.bss_data.nr_kernel_dispatches
    }

    #[allow(dead_code)]
    pub fn nr_cancel_dispatches_mut(&mut self) -> &mut u64 {
        &mut self.skel.maps.bss_data.nr_cancel_dispatches
    }

    #[allow(dead_code)]
    pub fn nr_bounce_dispatches_mut(&mut self) -> &mut u64 {
        &mut self.skel.maps.bss_data.nr_bounce_dispatches
    }

    #[allow(dead_code)]
    pub fn nr_failed_dispatches_mut(&mut self) -> &mut u64 {
        &mut self.skel.maps.bss_data.nr_failed_dispatches
    }

    #[allow(dead_code)]
    pub fn nr_sched_congested_mut(&mut self) -> &mut u64 {
        &mut self.skel.maps.bss_data.nr_sched_congested
    }

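    // Move the current thread into the SCHED_EXT scheduling class; returns the
    // pthread_setschedparam() result (0 on success).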
    fn use_sched_ext() -> i32 {
        #[cfg(target_env = "gnu")]
        let param: sched_param = sched_param { sched_priority: 0 };
        #[cfg(target_env = "musl")]
        let param: sched_param = sched_param {
            sched_priority: 0,
            sched_ss_low_priority: 0,
            sched_ss_repl_period: timespec {
                tv_sec: 0,
                tv_nsec: 0,
            },
            sched_ss_init_budget: timespec {
                tv_sec: 0,
                tv_nsec: 0,
            },
            sched_ss_max_repl: 0,
        };

        unsafe { pthread_setschedparam(pthread_self(), SCHED_EXT, &param as *const sched_param) }
    }

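    // Pick a CPU for the target task by invoking the rs_select_cpu BPF program via
    // test_run, returning its result to the caller.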
    pub fn select_cpu(&mut self, pid: i32, cpu: i32, flags: u64) -> i32 {
        let prog = &mut self.skel.progs.rs_select_cpu;
        let mut args = task_cpu_arg {
            pid: pid as c_int,
            cpu: cpu as c_int,
            flags: flags as c_ulong,
        };
        let input = ProgramInput {
            context_in: Some(unsafe {
                std::slice::from_raw_parts_mut(
                    &mut args as *mut _ as *mut u8,
                    std::mem::size_of_val(&args),
                )
            }),
            ..Default::default()
        };
        let out = prog.test_run(input).unwrap();

        out.return_value as i32
    }

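    // Receive a task queued by the BPF component, if any; Ok(None) means the ring buffer is
    // currently empty.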
    #[allow(static_mut_refs)]
    pub fn dequeue_task(&mut self) -> Result<Option<QueuedTask>, i32> {
        match self.queued.consume_raw() {
            0 => {
                // Ring buffer is empty.
                self.skel.maps.bss_data.nr_queued = 0;
                Ok(None)
            }
            LIBBPF_STOP => {
                // A valid task was received: convert the raw bytes into a QueuedTask and
                // decrement the number of queued tasks (saturating at zero).
                let task = unsafe { EnqueuedMessage::from_bytes(&BUF.0).to_queued_task() };
                self.skel.maps.bss_data.nr_queued =
                    self.skel.maps.bss_data.nr_queued.saturating_sub(1);

                Ok(Some(task))
            }
            res if res < 0 => Err(res),
            res => panic!(
                "Unexpected return value from libbpf-rs::consume_raw(): {}",
                res
            ),
        }
    }

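    // Send a task to the BPF component to be dispatched, writing it into the "dispatched"
    // user ring buffer.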
    pub fn dispatch_task(&mut self, task: &DispatchedTask) -> Result<(), libbpf_rs::Error> {
        // Reserve a slot in the user ring buffer.
        let mut urb_sample = self
            .dispatched
            .reserve(std::mem::size_of::<bpf_intf::dispatched_task_ctx>())?;
        let bytes = urb_sample.as_mut();
        let dispatched_task = plain::from_mut_bytes::<bpf_intf::dispatched_task_ctx>(bytes)
            .expect("failed to convert bytes");

        // Fill the low-level dispatched task context in place.
        let bpf_intf::dispatched_task_ctx {
            pid,
            cpu,
            flags,
            slice_ns,
            vtime,
            cpumask_cnt,
            ..
        } = &mut dispatched_task.as_mut();

        *pid = task.pid;
        *cpu = task.cpu;
        *flags = task.flags;
        *slice_ns = task.slice_ns;
        *vtime = task.vtime;
        *cpumask_cnt = task.cpumask_cnt;

        // Submit the task, making it visible to the BPF component.
        self.dispatched
            .submit(urb_sample)
            .expect("failed to submit task");

        Ok(())
    }

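    // Return true when the scheduler should exit: either Ctrl-C was received or the BPF
    // component reported an exit event.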
    pub fn exited(&mut self) -> bool {
        self.shutdown.load(Ordering::Relaxed) || uei_exited!(&self.skel, uei)
    }

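    // Detach the BPF scheduler and report the exit information collected by the BPF
    // component.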
    pub fn shutdown_and_report(&mut self) -> Result<UserExitInfo> {
        self.struct_ops.take();
        uei_report!(&self.skel, uei)
    }
}

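// Detach the scheduler (if still attached) and release the locked memory when the
// BpfScheduler instance goes out of scope.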
impl Drop for BpfScheduler<'_> {
    fn drop(&mut self) {
        if let Some(struct_ops) = self.struct_ops.take() {
            drop(struct_ops);
        }
        ALLOCATOR.unlock_memory();
    }
}