scx_utils/
compat.rs

// Copyright (c) Meta Platforms, Inc. and affiliates.
//
// This software may be used and distributed according to the terms of the
// GNU General Public License version 2.
5
6use anyhow::{anyhow, bail, Context, Result};
7use libbpf_rs::libbpf_sys::*;
8use libbpf_rs::{AsRawLibbpf, OpenProgramImpl, ProgramImpl};
9use log::warn;
10use std::env;
11use std::ffi::c_void;
12use std::ffi::CStr;
13use std::ffi::CString;
14use std::io;
15use std::io::BufRead;
16use std::io::BufReader;
17use std::mem::size_of;
18use std::slice::from_raw_parts;
19
20const PROCFS_MOUNTS: &str = "/proc/mounts";
21const TRACEFS: &str = "tracefs";
22const DEBUGFS: &str = "debugfs";
23
24lazy_static::lazy_static! {
25    pub static ref SCX_OPS_KEEP_BUILTIN_IDLE: u64 =
26        read_enum("scx_ops_flags", "SCX_OPS_KEEP_BUILTIN_IDLE").unwrap_or(0);
27    pub static ref SCX_OPS_ENQ_LAST: u64 =
28        read_enum("scx_ops_flags", "SCX_OPS_ENQ_LAST").unwrap_or(0);
29    pub static ref SCX_OPS_ENQ_EXITING: u64 =
30        read_enum("scx_ops_flags", "SCX_OPS_ENQ_EXITING").unwrap_or(0);
31    pub static ref SCX_OPS_SWITCH_PARTIAL: u64 =
32        read_enum("scx_ops_flags", "SCX_OPS_SWITCH_PARTIAL").unwrap_or(0);
33    pub static ref SCX_OPS_ENQ_MIGRATION_DISABLED: u64 =
34        read_enum("scx_ops_flags", "SCX_OPS_ENQ_MIGRATION_DISABLED").unwrap_or(0);
35    pub static ref SCX_OPS_ALLOW_QUEUED_WAKEUP: u64 =
36        read_enum("scx_ops_flags", "SCX_OPS_ALLOW_QUEUED_WAKEUP").unwrap_or(0);
37    pub static ref SCX_OPS_BUILTIN_IDLE_PER_NODE: u64 =
38        read_enum("scx_ops_flags", "SCX_OPS_BUILTIN_IDLE_PER_NODE").unwrap_or(0);
39
40    pub static ref SCX_PICK_IDLE_CORE: u64 =
41        read_enum("scx_pick_idle_cpu_flags", "SCX_PICK_IDLE_CORE").unwrap_or(0);
42    pub static ref SCX_PICK_IDLE_IN_NODE: u64 =
43        read_enum("scx_pick_idle_cpu_flags", "SCX_PICK_IDLE_IN_NODE").unwrap_or(0);
44
45    pub static ref ROOT_PREFIX: String =
46        env::var("SCX_SYSFS_PREFIX").unwrap_or("".to_string());
47}
48
49fn load_vmlinux_btf() -> &'static mut btf {
50    let btf = unsafe { btf__load_vmlinux_btf() };
51    if btf.is_null() {
52        panic!("btf__load_vmlinux_btf() returned NULL, was CONFIG_DEBUG_INFO_BTF enabled?")
53    }
54    unsafe { &mut *btf }
55}
56
57lazy_static::lazy_static! {
58    static ref VMLINUX_BTF: &'static mut btf = load_vmlinux_btf();
59}
60
61fn btf_kind(t: &btf_type) -> u32 {
62    (t.info >> 24) & 0x1f
63}
64
65fn btf_vlen(t: &btf_type) -> u32 {
66    t.info & 0xffff
67}
68
69fn btf_type_plus_1(t: &btf_type) -> *const c_void {
70    let ptr_val = t as *const btf_type as usize;
71    (ptr_val + size_of::<btf_type>()) as *const c_void
72}
73
74fn btf_enum(t: &btf_type) -> &[btf_enum] {
75    let ptr = btf_type_plus_1(t);
76    unsafe { from_raw_parts(ptr as *const btf_enum, btf_vlen(t) as usize) }
77}
78
79fn btf_enum64(t: &btf_type) -> &[btf_enum64] {
80    let ptr = btf_type_plus_1(t);
81    unsafe { from_raw_parts(ptr as *const btf_enum64, btf_vlen(t) as usize) }
82}
83
84fn btf_members(t: &btf_type) -> &[btf_member] {
85    let ptr = btf_type_plus_1(t);
86    unsafe { from_raw_parts(ptr as *const btf_member, btf_vlen(t) as usize) }
87}
88
89fn btf_name_str_by_offset(btf: &btf, name_off: u32) -> Result<&str> {
90    let n = unsafe { btf__name_by_offset(btf, name_off) };
91    if n.is_null() {
92        bail!("btf__name_by_offset() returned NULL");
93    }
94    Ok(unsafe { CStr::from_ptr(n) }
95        .to_str()
96        .with_context(|| format!("Failed to convert {:?} to string", n))?)
97}
98
99pub fn read_enum(type_name: &str, name: &str) -> Result<u64> {
100    let btf: &btf = *VMLINUX_BTF;
101
102    let type_name = CString::new(type_name).unwrap();
103    let tid = unsafe { btf__find_by_name(btf, type_name.as_ptr()) };
104    if tid < 0 {
105        bail!("type {:?} doesn't exist, ret={}", type_name, tid);
106    }
107
108    let t = unsafe { btf__type_by_id(btf, tid as _) };
109    if t.is_null() {
110        bail!("btf__type_by_id({}) returned NULL", tid);
111    }
112    let t = unsafe { &*t };
113
114    match btf_kind(t) {
115        BTF_KIND_ENUM => {
116            for e in btf_enum(t).iter() {
117                if btf_name_str_by_offset(btf, e.name_off)? == name {
118                    return Ok(e.val as u64);
119                }
120            }
121        }
122        BTF_KIND_ENUM64 => {
123            for e in btf_enum64(t).iter() {
124                if btf_name_str_by_offset(btf, e.name_off)? == name {
125                    return Ok(((e.val_hi32 as u64) << 32) | (e.val_lo32) as u64);
126                }
127            }
128        }
129        _ => (),
130    }
131
132    Err(anyhow!("{:?} doesn't exist in {:?}", name, type_name))
133}
134
135pub fn struct_has_field(type_name: &str, field: &str) -> Result<bool> {
136    let btf: &btf = *VMLINUX_BTF;
137
138    let type_name = CString::new(type_name).unwrap();
139    let tid = unsafe { btf__find_by_name_kind(btf, type_name.as_ptr(), BTF_KIND_STRUCT) };
140    if tid < 0 {
141        bail!("type {:?} doesn't exist, ret={}", type_name, tid);
142    }
143
144    let t = unsafe { btf__type_by_id(btf, tid as _) };
145    if t.is_null() {
146        bail!("btf__type_by_id({}) returned NULL", tid);
147    }
148    let t = unsafe { &*t };
149
150    for m in btf_members(t).iter() {
151        if btf_name_str_by_offset(btf, m.name_off)? == field {
152            return Ok(true);
153        }
154    }
155
156    Ok(false)
157}
158
159pub fn ksym_exists(ksym: &str) -> Result<bool> {
160    let btf: &btf = *VMLINUX_BTF;
161
162    let ksym_name = CString::new(ksym).unwrap();
163    let tid = unsafe { btf__find_by_name(btf, ksym_name.as_ptr()) };
164    Ok(tid >= 0)
165}
166
167pub fn in_kallsyms(ksym: &str) -> Result<bool> {
168    let file = std::fs::File::open("/proc/kallsyms")?;
169    let reader = std::io::BufReader::new(file);
170
171    for line in reader.lines() {
172        for sym in line.unwrap().split_whitespace() {
173            if ksym == sym {
174                return Ok(true);
175            }
176        }
177    }
178
179    Ok(false)
180}
181
182/// Returns the mount point for a filesystem type.
183pub fn get_fs_mount(mount_type: &str) -> Result<Vec<std::path::PathBuf>> {
184    let proc_mounts_path = std::path::Path::new(PROCFS_MOUNTS);
185
186    let file = std::fs::File::open(proc_mounts_path)
187        .with_context(|| format!("Failed to open {}", proc_mounts_path.display()))?;
188
189    let reader = BufReader::new(file);
190
191    let mut mounts = Vec::new();
192    for line in reader.lines() {
193        let line = line.context("Failed to read line from /proc/mounts")?;
194        let mount_info: Vec<&str> = line.split_whitespace().collect();
195
196        if mount_info.len() > 3 && mount_info[2] == mount_type {
197            let mount_path = std::path::PathBuf::from(mount_info[1]);
198            mounts.push(mount_path);
199        }
200    }
201
202    Ok(mounts)
203}
204
205/// Returns the tracefs mount point.
206pub fn tracefs_mount() -> Result<std::path::PathBuf> {
207    let mounts = get_fs_mount(TRACEFS)?;
208    mounts.into_iter().next().context("No tracefs mount found")
209}
210
211/// Returns the debugfs mount point.
212pub fn debugfs_mount() -> Result<std::path::PathBuf> {
213    let mounts = get_fs_mount(DEBUGFS)?;
214    mounts.into_iter().next().context("No debugfs mount found")
215}
216
217pub fn tracer_available(tracer: &str) -> Result<bool> {
218    let base_path = tracefs_mount().unwrap_or_else(|_| debugfs_mount().unwrap().join("tracing"));
219    let file = match std::fs::File::open(base_path.join("available_tracers")) {
220        Ok(f) => f,
221        Err(_) => return Ok(false),
222    };
223    let reader = std::io::BufReader::new(file);
224
225    for line in reader.lines() {
226        for tc in line.unwrap().split_whitespace() {
227            if tracer == tc {
228                return Ok(true);
229            }
230        }
231    }
232
233    Ok(false)
234}
235
236pub fn tracepoint_exists(tracepoint: &str) -> Result<bool> {
237    let base_path = tracefs_mount().unwrap_or_else(|_| debugfs_mount().unwrap().join("tracing"));
238    let file = match std::fs::File::open(base_path.join("available_events")) {
239        Ok(f) => f,
240        Err(_) => return Ok(false),
241    };
242    let reader = std::io::BufReader::new(file);
243
244    for line in reader.lines() {
245        for tp in line.unwrap().split_whitespace() {
246            if tracepoint == tp {
247                return Ok(true);
248            }
249        }
250    }
251
252    Ok(false)
253}
254
255pub fn cond_kprobe_enable<T>(sym: &str, prog_ptr: &OpenProgramImpl<T>) -> Result<bool> {
256    if in_kallsyms(sym)? {
257        unsafe {
258            bpf_program__set_autoload(prog_ptr.as_libbpf_object().as_ptr(), true);
259        }
260        return Ok(true);
261    } else {
262        warn!("symbol {sym} is missing, kprobe not loaded");
263    }
264
265    Ok(false)
266}
267
268pub fn cond_kprobes_enable<T>(kprobes: Vec<(&str, &OpenProgramImpl<T>)>) -> Result<bool> {
269    // Check if all the symbols exist.
270    for (sym, _) in kprobes.iter() {
271        if in_kallsyms(sym)? == false {
272            warn!("symbol {sym} is missing, kprobe not loaded");
273            return Ok(false);
274        }
275    }
276
277    // Enable all the tracepoints.
278    for (_, ptr) in kprobes.iter() {
279        unsafe {
280            bpf_program__set_autoload(ptr.as_libbpf_object().as_ptr(), true);
281        }
282    }
283
284    Ok(true)
285}
286
287pub fn cond_kprobe_load<T>(sym: &str, prog_ptr: &OpenProgramImpl<T>) -> Result<bool> {
288    if in_kallsyms(sym)? {
289        unsafe {
290            bpf_program__set_autoload(prog_ptr.as_libbpf_object().as_ptr(), true);
291            bpf_program__set_autoattach(prog_ptr.as_libbpf_object().as_ptr(), false);
292        }
293        return Ok(true);
294    } else {
295        warn!("symbol {sym} is missing, kprobe not loaded");
296    }
297
298    Ok(false)
299}
300
301pub fn cond_kprobe_attach<T>(sym: &str, prog_ptr: &ProgramImpl<T>) -> Result<bool> {
302    if in_kallsyms(sym)? {
303        unsafe {
304            bpf_program__attach(prog_ptr.as_libbpf_object().as_ptr());
305        }
306        return Ok(true);
307    } else {
308        warn!("symbol {sym} is missing, kprobe not loaded");
309    }
310
311    Ok(false)
312}
313
314pub fn cond_tracepoint_enable<T>(tracepoint: &str, prog_ptr: &OpenProgramImpl<T>) -> Result<bool> {
315    if tracepoint_exists(tracepoint)? {
316        unsafe {
317            bpf_program__set_autoload(prog_ptr.as_libbpf_object().as_ptr(), true);
318        }
319        return Ok(true);
320    } else {
321        warn!("tracepoint {tracepoint} is missing, tracepoint not loaded");
322    }
323
324    Ok(false)
325}
326
327pub fn cond_tracepoints_enable<T>(tracepoints: Vec<(&str, &OpenProgramImpl<T>)>) -> Result<bool> {
328    // Check if all the tracepoints exist.
329    for (tp, _) in tracepoints.iter() {
330        if tracepoint_exists(tp)? == false {
331            warn!("tracepoint {tp} is missing, tracepoint not loaded");
332            return Ok(false);
333        }
334    }
335
336    // Enable all the tracepoints.
337    for (_, ptr) in tracepoints.iter() {
338        unsafe {
339            bpf_program__set_autoload(ptr.as_libbpf_object().as_ptr(), true);
340        }
341    }
342
343    Ok(true)
344}
345
/// Reads `/sys/kernel/sched_ext/state` and maps its trimmed content to a
/// boolean: `"enabled"` gives `true`, `"disabled"` gives `false`.
///
/// # Errors
///
/// Returns the underlying I/O error when the file cannot be read, and an
/// `InvalidData` error for any other content.
pub fn is_sched_ext_enabled() -> io::Result<bool> {
    let raw = std::fs::read_to_string("/sys/kernel/sched_ext/state")?;
    let state = raw.trim();

    if state == "enabled" {
        return Ok(true);
    }
    if state == "disabled" {
        return Ok(false);
    }

    // Anything else is reported as malformed data.
    Err(io::Error::new(
        io::ErrorKind::InvalidData,
        "Unexpected content in /sys/kernel/sched_ext/state",
    ))
}
361
/// Unwraps a `Result` inside a labeled block.
///
/// Evaluates to the `Ok` payload of `$expr`. On `Err`, breaks out of the
/// block labeled `$label` with that error re-wrapped in `Err`, so the
/// labeled block must itself evaluate to a `Result` with a matching error
/// type.
#[macro_export]
macro_rules! unwrap_or_break {
    ($expr: expr, $label: lifetime) => {{
        match $expr {
            ::std::result::Result::Ok(value) => value,
            ::std::result::Result::Err(err) => {
                break $label ::std::result::Result::Err(err)
            }
        }
    }};
}
371
372pub fn check_min_requirements() -> Result<()> {
373    // ec7e3b0463e1 ("implement-ops") in https://github.com/sched-ext/sched_ext
374    // is the current minimum required kernel version.
375    if let Ok(false) | Err(_) = struct_has_field("sched_ext_ops", "dump") {
376        bail!("sched_ext_ops.dump() missing, kernel too old?");
377    }
378    Ok(())
379}
380
/// struct sched_ext_ops can change over time. If compat.bpf.h::SCX_OPS_DEFINE()
/// is used to define ops, and scx_ops_open!(), scx_ops_load!(), and
/// scx_ops_attach!() are used to open, load and attach it, backward
/// compatibility is automatically maintained where reasonable.
///
/// Evaluates to an `anyhow::Result` holding the opened skeleton. In order,
/// the macro:
///
/// 1. runs `check_min_requirements()`;
/// 2. opens the skeleton from `$builder`, using `open_opts()` when
///    `$open_opts` is `Some(..)` and `open()` otherwise;
/// 3. copies the number read from `/sys/kernel/sched_ext/hotplug_seq` into
///    the ops' `hotplug_seq`;
/// 4. if the `SCX_TIMEOUT_MS` environment variable is set, parses it as
///    `u32` and stores it in the ops' `timeout_ms`;
/// 5. appends the version suffix to the ops' `name`, truncating it to fit
///    and keeping the field NUL-terminated;
/// 6. invokes `import_enums!` on the skeleton.
///
/// Any failing step ends the macro early with an `Err`. Failures to read or
/// parse `hotplug_seq` keep the underlying error as the cause.
#[rustfmt::skip]
#[macro_export]
macro_rules! scx_ops_open {
    ($builder: expr, $obj_ref: expr, $ops: ident, $open_opts: expr) => { 'block: {
        scx_utils::paste! {
            scx_utils::unwrap_or_break!(scx_utils::compat::check_min_requirements(), 'block);
            use ::anyhow::Context;
            use ::libbpf_rs::skel::SkelBuilder;

            let mut skel = match $open_opts {
                Some(opts_ref) => { // Match a reference directly
                    match $builder.open_opts(opts_ref, $obj_ref).context("Failed to open BPF program with options") {
                        Ok(val) => val,
                        Err(e) => break 'block Err(e),
                    }
                }
                None => {
                    match $builder.open($obj_ref).context("Failed to open BPF program") {
                        Ok(val) => val,
                        Err(e) => break 'block Err(e),
                    }
                }
            };

            let ops = skel.struct_ops.[<$ops _mut>]();
            let path = std::path::Path::new("/sys/kernel/sched_ext/hotplug_seq");

            let val = match std::fs::read_to_string(&path) {
                Ok(val) => val,
                Err(e) => {
                    // Keep the I/O error as the cause instead of dropping it.
                    break 'block Err(::anyhow::Error::new(e)
                        .context(format!("Failed to open or read file {:?}", path)));
                }
            };

            ops.hotplug_seq = match val.trim().parse::<u64>() {
                Ok(parsed) => parsed,
                Err(e) => {
                    // Keep the parse error as the cause instead of dropping it.
                    break 'block Err(::anyhow::Error::new(e)
                        .context(format!("Failed to parse hotplug seq {}", val)));
                }
            };

            if let Ok(s) = ::std::env::var("SCX_TIMEOUT_MS") {
                skel.struct_ops.[<$ops _mut>]().timeout_ms = match s.parse::<u32>() {
                    Ok(ms) => {
                        ::scx_utils::info!("Setting timeout_ms to {} based on environment", ms);
                        ms
                    },
                    Err(e) => {
                        break 'block anyhow::Result::Err(e).context("SCX_TIMEOUT_MS has invalid value");
                    },
                };
            }

            {
                let ops = skel.struct_ops.[<$ops _mut>]();

                let name_field = &mut ops.name;

                let version_suffix = ::scx_utils::build_id::ops_version_suffix(env!("CARGO_PKG_VERSION"));
                let bytes = version_suffix.as_bytes();
                let mut i = 0;
                let mut bytes_idx = 0;
                let mut found_null = false;

                // Skip to the first NUL of the existing name, then copy the
                // suffix over it. The last slot is never written by the loop
                // so the terminator below always fits.
                while i < name_field.len() - 1 {
                    found_null |= name_field[i] == 0;
                    if !found_null {
                        i += 1;
                        continue;
                    }

                    if bytes_idx < bytes.len() {
                        name_field[i] = bytes[bytes_idx] as i8;
                        bytes_idx += 1;
                    } else {
                        break;
                    }
                    i += 1;
                }
                name_field[i] = 0;
            }

            $crate::import_enums!(skel);

            let result = ::anyhow::Result::Ok(skel);

            result
        }
    }};
}
475
/// Loads a skeleton previously opened with scx_ops_open!().
///
/// Invokes `uei_set_size!` with `$skel`, `$ops` and `$uei`, then calls
/// `load()` on `$skel`. Evaluates to an `anyhow::Result` of the loaded
/// skeleton; a load failure carries the context "Failed to load BPF program".
///
/// struct sched_ext_ops can change over time. Using this macro together with
/// compat.bpf.h::SCX_OPS_DEFINE(), scx_ops_open!() and scx_ops_attach!()
/// keeps backward compatibility where reasonable.
#[rustfmt::skip]
#[macro_export]
macro_rules! scx_ops_load {
    ($skel: expr, $ops: ident, $uei: ident) => { 'block: {
        scx_utils::paste! {
            use ::anyhow::Context;
            use ::libbpf_rs::skel::OpenSkel;

            scx_utils::uei_set_size!($skel, $ops, $uei);

            let load_result = $skel.load();
            load_result.context("Failed to load BPF program")
        }
    }};
}
493
/// Must be used together with scx_ops_load!(). See there.
///
/// Refuses to attach when `is_sched_ext_enabled()` returns `Ok(true)`; a
/// failure to read the state is treated as "not enabled". Otherwise it
/// calls `attach()` on `$skel` and, on success, `attach_struct_ops()` on the
/// `$ops` map. Evaluates to an `anyhow::Result` of the struct_ops
/// attachment.
#[rustfmt::skip]
#[macro_export]
macro_rules! scx_ops_attach {
    ($skel: expr, $ops: ident) => { 'block: {
        use ::anyhow::Context;
        use ::libbpf_rs::skel::Skel;

        if scx_utils::compat::is_sched_ext_enabled().unwrap_or(false) {
            break 'block Err(anyhow::anyhow!(
                "another sched_ext scheduler is already running"
            ));
        }

        // The struct_ops map is attached only after the other programs
        // attached successfully.
        match $skel.attach().context("Failed to attach non-struct_ops BPF programs") {
            Ok(_) => $skel
                .maps
                .$ops
                .attach_struct_ops()
                .context("Failed to attach struct_ops BPF programs"),
            Err(e) => Err(e),
        }
    }};
}
519
// These tests query the vmlinux BTF of the machine they run on.
#[cfg(test)]
mod tests {
    use super::{ksym_exists, read_enum, struct_has_field};

    #[test]
    fn test_read_enum() {
        let tgid = read_enum("pid_type", "PIDTYPE_TGID").unwrap();
        assert_eq!(tgid, 1);
    }

    #[test]
    fn test_struct_has_field() {
        // Existing field, missing field, then missing struct.
        assert!(struct_has_field("task_struct", "flags").unwrap());
        assert!(!struct_has_field("task_struct", "NO_SUCH_FIELD").unwrap());
        assert!(struct_has_field("NO_SUCH_STRUCT", "NO_SUCH_FIELD").is_err());
    }

    #[test]
    fn test_ksym_exists() {
        assert!(ksym_exists("bpf_task_acquire").unwrap());
        assert!(!ksym_exists("NO_SUCH_KFUNC").unwrap());
    }
}