Skip to main content

scx_utils/
compat.rs

1// Copyright (c) Meta Platforms, Inc. and affiliates.
2//
3// This software may be used and distributed according to the terms of the
4// GNU General Public License version 2.
5
6use anyhow::{anyhow, bail, Context, Result};
7use libbpf_rs::libbpf_sys::*;
8use libbpf_rs::{AsRawLibbpf, OpenProgramImpl, ProgramImpl};
9use log::warn;
10use std::env;
11use std::ffi::c_void;
12use std::ffi::CStr;
13use std::ffi::CString;
14use std::io;
15use std::io::BufRead;
16use std::io::BufReader;
17use std::mem::size_of;
18use std::slice::from_raw_parts;
19
/// Path of the procfs file that lists the currently mounted filesystems.
const PROCFS_MOUNTS: &str = "/proc/mounts";
/// Filesystem type string used when looking up tracefs mounts.
const TRACEFS: &str = "tracefs";
/// Filesystem type string used when looking up debugfs mounts.
const DEBUGFS: &str = "debugfs";
23
24lazy_static::lazy_static! {
25    pub static ref SCX_OPS_KEEP_BUILTIN_IDLE: u64 =
26        read_enum("scx_ops_flags", "SCX_OPS_KEEP_BUILTIN_IDLE").unwrap_or(0);
27    pub static ref SCX_OPS_ENQ_LAST: u64 =
28        read_enum("scx_ops_flags", "SCX_OPS_ENQ_LAST").unwrap_or(0);
29    pub static ref SCX_OPS_ENQ_EXITING: u64 =
30        read_enum("scx_ops_flags", "SCX_OPS_ENQ_EXITING").unwrap_or(0);
31    pub static ref SCX_OPS_SWITCH_PARTIAL: u64 =
32        read_enum("scx_ops_flags", "SCX_OPS_SWITCH_PARTIAL").unwrap_or(0);
33    pub static ref SCX_OPS_ENQ_MIGRATION_DISABLED: u64 =
34        read_enum("scx_ops_flags", "SCX_OPS_ENQ_MIGRATION_DISABLED").unwrap_or(0);
35    pub static ref SCX_OPS_ALLOW_QUEUED_WAKEUP: u64 =
36        read_enum("scx_ops_flags", "SCX_OPS_ALLOW_QUEUED_WAKEUP").unwrap_or(0);
37    pub static ref SCX_OPS_BUILTIN_IDLE_PER_NODE: u64 =
38        read_enum("scx_ops_flags", "SCX_OPS_BUILTIN_IDLE_PER_NODE").unwrap_or(0);
39    pub static ref SCX_OPS_ALWAYS_ENQ_IMMED: u64 =
40        read_enum("scx_ops_flags", "SCX_OPS_ALWAYS_ENQ_IMMED").unwrap_or(0);
41
42    pub static ref SCX_PICK_IDLE_CORE: u64 =
43        read_enum("scx_pick_idle_cpu_flags", "SCX_PICK_IDLE_CORE").unwrap_or(0);
44    pub static ref SCX_PICK_IDLE_IN_NODE: u64 =
45        read_enum("scx_pick_idle_cpu_flags", "SCX_PICK_IDLE_IN_NODE").unwrap_or(0);
46
47    pub static ref ROOT_PREFIX: String =
48        env::var("SCX_SYSFS_PREFIX").unwrap_or("".to_string());
49}
50
51fn load_vmlinux_btf() -> &'static mut btf {
52    let btf = unsafe { btf__load_vmlinux_btf() };
53    if btf.is_null() {
54        panic!("btf__load_vmlinux_btf() returned NULL, was CONFIG_DEBUG_INFO_BTF enabled?")
55    }
56    unsafe { &mut *btf }
57}
58
59lazy_static::lazy_static! {
60    static ref VMLINUX_BTF: &'static mut btf = load_vmlinux_btf();
61}
62
63fn btf_kind(t: &btf_type) -> u32 {
64    (t.info >> 24) & 0x1f
65}
66
67fn btf_vlen(t: &btf_type) -> u32 {
68    t.info & 0xffff
69}
70
71fn btf_type_plus_1(t: &btf_type) -> *const c_void {
72    let ptr_val = t as *const btf_type as usize;
73    (ptr_val + size_of::<btf_type>()) as *const c_void
74}
75
76fn btf_enum(t: &btf_type) -> &[btf_enum] {
77    let ptr = btf_type_plus_1(t);
78    unsafe { from_raw_parts(ptr as *const btf_enum, btf_vlen(t) as usize) }
79}
80
81fn btf_enum64(t: &btf_type) -> &[btf_enum64] {
82    let ptr = btf_type_plus_1(t);
83    unsafe { from_raw_parts(ptr as *const btf_enum64, btf_vlen(t) as usize) }
84}
85
86fn btf_members(t: &btf_type) -> &[btf_member] {
87    let ptr = btf_type_plus_1(t);
88    unsafe { from_raw_parts(ptr as *const btf_member, btf_vlen(t) as usize) }
89}
90
91fn btf_name_str_by_offset(btf: &btf, name_off: u32) -> Result<&str> {
92    let n = unsafe { btf__name_by_offset(btf, name_off) };
93    if n.is_null() {
94        bail!("btf__name_by_offset() returned NULL");
95    }
96    Ok(unsafe { CStr::from_ptr(n) }
97        .to_str()
98        .with_context(|| format!("Failed to convert {:?} to string", n))?)
99}
100
101pub fn read_enum(type_name: &str, name: &str) -> Result<u64> {
102    let btf: &btf = *VMLINUX_BTF;
103
104    let type_name = CString::new(type_name).unwrap();
105    let tid = unsafe { btf__find_by_name(btf, type_name.as_ptr()) };
106    if tid < 0 {
107        bail!("type {:?} doesn't exist, ret={}", type_name, tid);
108    }
109
110    let t = unsafe { btf__type_by_id(btf, tid as _) };
111    if t.is_null() {
112        bail!("btf__type_by_id({}) returned NULL", tid);
113    }
114    let t = unsafe { &*t };
115
116    match btf_kind(t) {
117        BTF_KIND_ENUM => {
118            for e in btf_enum(t).iter() {
119                if btf_name_str_by_offset(btf, e.name_off)? == name {
120                    return Ok(e.val as u64);
121                }
122            }
123        }
124        BTF_KIND_ENUM64 => {
125            for e in btf_enum64(t).iter() {
126                if btf_name_str_by_offset(btf, e.name_off)? == name {
127                    return Ok(((e.val_hi32 as u64) << 32) | (e.val_lo32) as u64);
128                }
129            }
130        }
131        _ => (),
132    }
133
134    Err(anyhow!("{:?} doesn't exist in {:?}", name, type_name))
135}
136
137pub fn struct_has_field(type_name: &str, field: &str) -> Result<bool> {
138    let btf: &btf = *VMLINUX_BTF;
139
140    let type_name = CString::new(type_name).unwrap();
141    let tid = unsafe { btf__find_by_name_kind(btf, type_name.as_ptr(), BTF_KIND_STRUCT) };
142    if tid < 0 {
143        bail!("type {:?} doesn't exist, ret={}", type_name, tid);
144    }
145
146    let t = unsafe { btf__type_by_id(btf, tid as _) };
147    if t.is_null() {
148        bail!("btf__type_by_id({}) returned NULL", tid);
149    }
150    let t = unsafe { &*t };
151
152    for m in btf_members(t).iter() {
153        if btf_name_str_by_offset(btf, m.name_off)? == field {
154            return Ok(true);
155        }
156    }
157
158    Ok(false)
159}
160
161pub fn ksym_exists(ksym: &str) -> Result<bool> {
162    let btf: &btf = *VMLINUX_BTF;
163
164    let ksym_name = CString::new(ksym).unwrap();
165    let tid = unsafe { btf__find_by_name(btf, ksym_name.as_ptr()) };
166    Ok(tid >= 0)
167}
168
169pub fn in_kallsyms(ksym: &str) -> Result<bool> {
170    let file = std::fs::File::open("/proc/kallsyms")?;
171    let reader = std::io::BufReader::new(file);
172
173    for line in reader.lines() {
174        for sym in line.unwrap().split_whitespace() {
175            if ksym == sym {
176                return Ok(true);
177            }
178        }
179    }
180
181    Ok(false)
182}
183
184/// Returns the mount point for a filesystem type.
185pub fn get_fs_mount(mount_type: &str) -> Result<Vec<std::path::PathBuf>> {
186    let proc_mounts_path = std::path::Path::new(PROCFS_MOUNTS);
187
188    let file = std::fs::File::open(proc_mounts_path)
189        .with_context(|| format!("Failed to open {}", proc_mounts_path.display()))?;
190
191    let reader = BufReader::new(file);
192
193    let mut mounts = Vec::new();
194    for line in reader.lines() {
195        let line = line.context("Failed to read line from /proc/mounts")?;
196        let mount_info: Vec<&str> = line.split_whitespace().collect();
197
198        if mount_info.len() > 3 && mount_info[2] == mount_type {
199            let mount_path = std::path::PathBuf::from(mount_info[1]);
200            mounts.push(mount_path);
201        }
202    }
203
204    Ok(mounts)
205}
206
207/// Returns the tracefs mount point.
208pub fn tracefs_mount() -> Result<std::path::PathBuf> {
209    let mounts = get_fs_mount(TRACEFS)?;
210    mounts.into_iter().next().context("No tracefs mount found")
211}
212
213/// Returns the debugfs mount point.
214pub fn debugfs_mount() -> Result<std::path::PathBuf> {
215    let mounts = get_fs_mount(DEBUGFS)?;
216    mounts.into_iter().next().context("No debugfs mount found")
217}
218
219pub fn tracer_available(tracer: &str) -> Result<bool> {
220    let base_path = tracefs_mount().unwrap_or_else(|_| debugfs_mount().unwrap().join("tracing"));
221    let file = match std::fs::File::open(base_path.join("available_tracers")) {
222        Ok(f) => f,
223        Err(_) => return Ok(false),
224    };
225    let reader = std::io::BufReader::new(file);
226
227    for line in reader.lines() {
228        for tc in line.unwrap().split_whitespace() {
229            if tracer == tc {
230                return Ok(true);
231            }
232        }
233    }
234
235    Ok(false)
236}
237
238pub fn tracepoint_exists(tracepoint: &str) -> Result<bool> {
239    let base_path = tracefs_mount().unwrap_or_else(|_| debugfs_mount().unwrap().join("tracing"));
240    let file = match std::fs::File::open(base_path.join("available_events")) {
241        Ok(f) => f,
242        Err(_) => return Ok(false),
243    };
244    let reader = std::io::BufReader::new(file);
245
246    for line in reader.lines() {
247        for tp in line.unwrap().split_whitespace() {
248            if tracepoint == tp {
249                return Ok(true);
250            }
251        }
252    }
253
254    Ok(false)
255}
256
257pub fn cond_kprobe_enable<T>(sym: &str, prog_ptr: &OpenProgramImpl<T>) -> Result<bool> {
258    if in_kallsyms(sym)? {
259        unsafe {
260            bpf_program__set_autoload(prog_ptr.as_libbpf_object().as_ptr(), true);
261        }
262        return Ok(true);
263    } else {
264        warn!("symbol {sym} is missing, kprobe not loaded");
265    }
266
267    Ok(false)
268}
269
270pub fn cond_kprobes_enable<T>(kprobes: Vec<(&str, &OpenProgramImpl<T>)>) -> Result<bool> {
271    // Check if all the symbols exist.
272    for (sym, _) in kprobes.iter() {
273        if in_kallsyms(sym)? == false {
274            warn!("symbol {sym} is missing, kprobe not loaded");
275            return Ok(false);
276        }
277    }
278
279    // Enable all the tracepoints.
280    for (_, ptr) in kprobes.iter() {
281        unsafe {
282            bpf_program__set_autoload(ptr.as_libbpf_object().as_ptr(), true);
283        }
284    }
285
286    Ok(true)
287}
288
289pub fn cond_kprobe_load<T>(sym: &str, prog_ptr: &OpenProgramImpl<T>) -> Result<bool> {
290    if in_kallsyms(sym)? {
291        unsafe {
292            bpf_program__set_autoload(prog_ptr.as_libbpf_object().as_ptr(), true);
293            bpf_program__set_autoattach(prog_ptr.as_libbpf_object().as_ptr(), false);
294        }
295        return Ok(true);
296    } else {
297        warn!("symbol {sym} is missing, kprobe not loaded");
298    }
299
300    Ok(false)
301}
302
303pub fn cond_kprobe_attach<T>(sym: &str, prog_ptr: &ProgramImpl<T>) -> Result<bool> {
304    if in_kallsyms(sym)? {
305        unsafe {
306            bpf_program__attach(prog_ptr.as_libbpf_object().as_ptr());
307        }
308        return Ok(true);
309    } else {
310        warn!("symbol {sym} is missing, kprobe not loaded");
311    }
312
313    Ok(false)
314}
315
316pub fn cond_tracepoint_enable<T>(tracepoint: &str, prog_ptr: &OpenProgramImpl<T>) -> Result<bool> {
317    if tracepoint_exists(tracepoint)? {
318        unsafe {
319            bpf_program__set_autoload(prog_ptr.as_libbpf_object().as_ptr(), true);
320        }
321        return Ok(true);
322    } else {
323        warn!("tracepoint {tracepoint} is missing, tracepoint not loaded");
324    }
325
326    Ok(false)
327}
328
329pub fn cond_tracepoints_enable<T>(tracepoints: Vec<(&str, &OpenProgramImpl<T>)>) -> Result<bool> {
330    // Check if all the tracepoints exist.
331    for (tp, _) in tracepoints.iter() {
332        if tracepoint_exists(tp)? == false {
333            warn!("tracepoint {tp} is missing, tracepoint not loaded");
334            return Ok(false);
335        }
336    }
337
338    // Enable all the tracepoints.
339    for (_, ptr) in tracepoints.iter() {
340        unsafe {
341            bpf_program__set_autoload(ptr.as_libbpf_object().as_ptr(), true);
342        }
343    }
344
345    Ok(true)
346}
347
/// Reads `/sys/kernel/sched_ext/state` and maps its trimmed content to a
/// boolean: `"enabled"` gives `true`, `"disabled"` gives `false`.
///
/// # Errors
///
/// Returns the underlying I/O error when the file cannot be read, and an
/// `InvalidData` error for any other content.
pub fn is_sched_ext_enabled() -> io::Result<bool> {
    let state = std::fs::read_to_string("/sys/kernel/sched_ext/state")?;
    let state = state.trim();

    if state == "enabled" {
        return Ok(true);
    }
    if state == "disabled" {
        return Ok(false);
    }

    // Anything other than the two known states is reported as invalid data.
    Err(io::Error::new(
        io::ErrorKind::InvalidData,
        "Unexpected content in /sys/kernel/sched_ext/state",
    ))
}
363
/// Unwraps a `Result` inside a labeled block.
///
/// Evaluates to the `Ok` payload of the first argument. On `Err`, breaks out
/// of the block named by the second argument (a label such as `'block`) with
/// that same error wrapped in `Err`.
#[macro_export]
macro_rules! unwrap_or_break {
    ($result: expr, $target: lifetime) => {{
        match $result {
            Err(cause) => break $target Err(cause),
            Ok(value) => value,
        }
    }};
}
373
374pub fn check_min_requirements() -> Result<()> {
375    // ec7e3b0463e1 ("implement-ops") in https://github.com/sched-ext/sched_ext
376    // is the current minimum required kernel version.
377    if let Ok(false) | Err(_) = struct_has_field("sched_ext_ops", "dump") {
378        bail!("sched_ext_ops.dump() missing, kernel too old?");
379    }
380    Ok(())
381}
382
/// Opens a BPF skeleton and pre-populates its sched_ext struct_ops.
///
/// struct sched_ext_ops can change over time. If compat.bpf.h::SCX_OPS_DEFINE()
/// is used to define ops, and scx_ops_open!(), scx_ops_load!(), and
/// scx_ops_attach!() are used to open, load and attach it, backward
/// compatibility is automatically maintained where reasonable.
///
/// Arguments:
/// - `$builder`: skeleton builder on which `open()` / `open_opts()` is called.
/// - `$obj_ref`: forwarded unchanged to `open()` / `open_opts()`.
/// - `$ops`: identifier of the struct_ops; `<$ops>_mut()` is called on
///   `skel.struct_ops`.
/// - `$open_opts`: an `Option`; `Some(opts)` selects `open_opts()`, `None`
///   selects `open()`.
///
/// The macro evaluates to an `anyhow::Result` holding the opened skeleton.
/// Before returning it:
/// 1. checks `check_min_requirements()`,
/// 2. stores the value read from `/sys/kernel/sched_ext/hotplug_seq` in
///    `ops.hotplug_seq`,
/// 3. overrides `ops.timeout_ms` from the `SCX_TIMEOUT_MS` environment
///    variable when it is set,
/// 4. appends the version suffix to `ops.name`,
/// 5. runs `import_enums!` on the skeleton.
#[rustfmt::skip]
#[macro_export]
macro_rules! scx_ops_open {
    ($builder: expr, $obj_ref: expr, $ops: ident, $open_opts: expr) => { 'block: {
        scx_utils::paste! {
            scx_utils::unwrap_or_break!(scx_utils::compat::check_min_requirements(), 'block);
            use ::anyhow::Context;
            use ::libbpf_rs::skel::SkelBuilder;

            let mut skel = match $open_opts {
                Some(opts_ref) => { // Match a reference directly
                    match $builder.open_opts(opts_ref, $obj_ref).context("Failed to open BPF program with options") {
                        Ok(val) => val,
                        Err(e) => break 'block Err(e),
                    }
                }
                None => {
                    match $builder.open($obj_ref).context("Failed to open BPF program") {
                        Ok(val) => val,
                        Err(e) => break 'block Err(e),
                    }
                }
            };

            let ops = skel.struct_ops.[<$ops _mut>]();
            let path = std::path::Path::new("/sys/kernel/sched_ext/hotplug_seq");

            // Keep the underlying I/O error as the cause instead of dropping it.
            let val = match std::fs::read_to_string(&path)
                .with_context(|| format!("Failed to open or read file {:?}", path))
            {
                Ok(val) => val,
                Err(e) => break 'block Err(e),
            };

            // Keep the underlying parse error as the cause instead of dropping it.
            ops.hotplug_seq = match val.trim().parse::<u64>()
                .with_context(|| format!("Failed to parse hotplug seq {}", val))
            {
                Ok(parsed) => parsed,
                Err(e) => break 'block Err(e),
            };

            if let Ok(s) = ::std::env::var("SCX_TIMEOUT_MS") {
                skel.struct_ops.[<$ops _mut>]().timeout_ms = match s.parse::<u32>() {
                    Ok(ms) => {
                        ::scx_utils::info!("Setting timeout_ms to {} based on environment", ms);
                        ms
                    },
                    Err(e) => {
                        break 'block anyhow::Result::Err(e).context("SCX_TIMEOUT_MS has invalid value");
                    },
                };
            }

            {
                let ops = skel.struct_ops.[<$ops _mut>]();

                let name_field = &mut ops.name;

                let version_suffix = ::scx_utils::build_id::ops_version_suffix(env!("CARGO_PKG_VERSION"));
                let bytes = version_suffix.as_bytes();
                let mut i = 0;
                let mut bytes_idx = 0;
                let mut found_null = false;

                // Skip ahead to the first NUL in the name, then copy the
                // version suffix over the following bytes. The last element
                // is never written by the loop so the terminator always fits.
                while i < name_field.len() - 1 {
                    found_null |= name_field[i] == 0;
                    if !found_null {
                        i += 1;
                        continue;
                    }

                    if bytes_idx < bytes.len() {
                        name_field[i] = bytes[bytes_idx] as i8;
                        bytes_idx += 1;
                    } else {
                        break;
                    }
                    i += 1;
                }
                // Re-terminate after the (possibly truncated) suffix.
                name_field[i] = 0;
            }

            $crate::import_enums!(skel);

            ::anyhow::Result::Ok(skel)
        }
    }};
}
477
/// Loads a skeleton previously opened with scx_ops_open!().
///
/// struct sched_ext_ops can change over time. If compat.bpf.h::SCX_OPS_DEFINE()
/// is used to define ops, and scx_ops_open!(), scx_ops_load!(), and
/// scx_ops_attach!() are used to open, load and attach it, backward
/// compatibility is automatically maintained where reasonable.
///
/// Before calling `$skel.load()` the macro clears a non-zero
/// `ops.sub_cgroup_id` (with a warning) unless `struct sched_ext_ops` is
/// confirmed to have a `sub_cgroup_id` member, and invokes `uei_set_size!`
/// with `$skel`, `$ops` and `$uei`.
#[rustfmt::skip]
#[macro_export]
macro_rules! scx_ops_load {
    ($skel: expr, $ops: ident, $uei: ident) => { 'block: {
        scx_utils::paste! {
            use ::anyhow::Context;
            use ::libbpf_rs::skel::OpenSkel;

            {
                let ops = $skel.struct_ops.[<$ops _mut>]();
                // The kernel lookup only runs when a sub-cgroup id was requested.
                if ops.sub_cgroup_id > 0
                    && !matches!(
                        scx_utils::compat::struct_has_field("sched_ext_ops", "sub_cgroup_id"),
                        Ok(true)
                    )
                {
                    ::scx_utils::warn!("kernel doesn't support ops.sub_cgroup_id");
                    ops.sub_cgroup_id = 0;
                }
            }

            scx_utils::uei_set_size!($skel, $ops, $uei);
            $skel.load().context("Failed to load BPF program")
        }
    }};
}
505
/// Attaches a loaded skeleton. Must be used together with scx_ops_load!().
/// See there.
///
/// Forms:
/// - `scx_ops_attach!($skel, $ops)` is shorthand for passing `false` as
///   `$is_subsched`.
/// - `scx_ops_attach!($skel, $ops, $is_subsched)`: unless `$is_subsched` is
///   true, the macro first evaluates to an error when
///   `is_sched_ext_enabled()` reports `Ok(true)`. A failure to read the state
///   is treated as "not enabled".
///
/// It then calls `$skel.attach()` and, on success,
/// `$skel.maps.$ops.attach_struct_ops()`, evaluating to the result of the
/// latter.
#[rustfmt::skip]
#[macro_export]
macro_rules! scx_ops_attach {
    ($skel: expr, $ops: ident) => {
        // `$crate::` keeps the recursive call resolvable even when the caller
        // invoked the macro by path without importing it by name.
        $crate::scx_ops_attach!($skel, $ops, false)
    };
    ($skel: expr, $ops: ident, $is_subsched: expr) => { 'block: {
        use ::anyhow::Context;
        use ::libbpf_rs::skel::Skel;

        if !$is_subsched && $crate::compat::is_sched_ext_enabled().unwrap_or(false) {
            break 'block Err(anyhow::anyhow!(
                "another sched_ext scheduler is already running"
            ));
        }
        $skel
            .attach()
            .context("Failed to attach non-struct_ops BPF programs")
            .and_then(|_| {
                $skel
                    .maps
                    .$ops
                    .attach_struct_ops()
                    .context("Failed to attach struct_ops BPF programs")
            })
    }};
}
534
#[cfg(test)]
mod tests {
    use super::{ksym_exists, read_enum, struct_has_field};

    /// `PIDTYPE_TGID` is expected to resolve to 1 in `pid_type`.
    #[test]
    fn test_read_enum() {
        let tgid = read_enum("pid_type", "PIDTYPE_TGID").unwrap();
        assert_eq!(tgid, 1);
    }

    /// Covers an existing field, a missing field and a missing struct.
    #[test]
    fn test_struct_has_field() {
        let present = struct_has_field("task_struct", "flags").unwrap();
        assert!(present);

        let absent = struct_has_field("task_struct", "NO_SUCH_FIELD").unwrap();
        assert!(!absent);

        let unknown_struct = struct_has_field("NO_SUCH_STRUCT", "NO_SUCH_FIELD");
        assert!(unknown_struct.is_err());
    }

    /// Covers one name that is expected to exist and one that is not.
    #[test]
    fn test_ksym_exists() {
        let known = ksym_exists("bpf_task_acquire").unwrap();
        assert!(known);

        let unknown = ksym_exists("NO_SUCH_KFUNC").unwrap();
        assert!(!unknown);
    }
}