scx_utils/
compat.rs

1// Copyright (c) Meta Platforms, Inc. and affiliates.
2//
3// This software may be used and distributed according to the terms of the
4// GNU General Public License version 2.
5
6use anyhow::{Context, Result, anyhow, bail};
7use libbpf_rs::libbpf_sys::*;
8use libbpf_rs::{AsRawLibbpf, OpenProgramImpl};
9use log::warn;
10use std::ffi::CStr;
11use std::ffi::CString;
12use std::ffi::c_void;
13use std::io;
14use std::io::BufRead;
15use std::io::BufReader;
16use std::mem::size_of;
17use std::slice::from_raw_parts;
18
19const PROCFS_MOUNTS: &str = "/proc/mounts";
20const TRACEFS: &str = "tracefs";
21const DEBUGFS: &str = "debugfs";
22
lazy_static::lazy_static! {
    // Members of `enum scx_ops_flags`, resolved by name through `read_enum()`.
    // When the lookup fails for any reason the constant evaluates to 0, so
    // OR-ing it into a flag word has no effect.
    pub static ref SCX_OPS_KEEP_BUILTIN_IDLE: u64 =
        read_enum("scx_ops_flags", "SCX_OPS_KEEP_BUILTIN_IDLE").unwrap_or(0);
    pub static ref SCX_OPS_ENQ_LAST: u64 =
        read_enum("scx_ops_flags", "SCX_OPS_ENQ_LAST").unwrap_or(0);
    pub static ref SCX_OPS_ENQ_EXITING: u64 =
        read_enum("scx_ops_flags", "SCX_OPS_ENQ_EXITING").unwrap_or(0);
    pub static ref SCX_OPS_SWITCH_PARTIAL: u64 =
        read_enum("scx_ops_flags", "SCX_OPS_SWITCH_PARTIAL").unwrap_or(0);
    pub static ref SCX_OPS_ENQ_MIGRATION_DISABLED: u64 =
        read_enum("scx_ops_flags", "SCX_OPS_ENQ_MIGRATION_DISABLED").unwrap_or(0);
    pub static ref SCX_OPS_ALLOW_QUEUED_WAKEUP: u64 =
        read_enum("scx_ops_flags", "SCX_OPS_ALLOW_QUEUED_WAKEUP").unwrap_or(0);
    pub static ref SCX_OPS_BUILTIN_IDLE_PER_NODE: u64 =
        read_enum("scx_ops_flags", "SCX_OPS_BUILTIN_IDLE_PER_NODE").unwrap_or(0);

    // Members of `enum scx_pick_idle_cpu_flags`, resolved the same way and
    // likewise falling back to 0 when the lookup fails.
    pub static ref SCX_PICK_IDLE_CORE: u64 =
        read_enum("scx_pick_idle_cpu_flags", "SCX_PICK_IDLE_CORE").unwrap_or(0);
    pub static ref SCX_PICK_IDLE_IN_NODE: u64 =
        read_enum("scx_pick_idle_cpu_flags", "SCX_PICK_IDLE_IN_NODE").unwrap_or(0);
}
44
45fn load_vmlinux_btf() -> &'static mut btf {
46    let btf = unsafe { btf__load_vmlinux_btf() };
47    if btf.is_null() {
48        panic!("btf__load_vmlinux_btf() returned NULL, was CONFIG_DEBUG_INFO_BTF enabled?")
49    }
50    unsafe { &mut *btf }
51}
52
lazy_static::lazy_static! {
    // Shared BTF handle used by every lookup helper in this file. It is
    // produced by `load_vmlinux_btf()`, which panics if libbpf returns NULL.
    static ref VMLINUX_BTF: &'static mut btf = load_vmlinux_btf();
}
56
57fn btf_kind(t: &btf_type) -> u32 {
58    (t.info >> 24) & 0x1f
59}
60
61fn btf_vlen(t: &btf_type) -> u32 {
62    t.info & 0xffff
63}
64
65fn btf_type_plus_1(t: &btf_type) -> *const c_void {
66    let ptr_val = t as *const btf_type as usize;
67    (ptr_val + size_of::<btf_type>()) as *const c_void
68}
69
70fn btf_enum(t: &btf_type) -> &[btf_enum] {
71    let ptr = btf_type_plus_1(t);
72    unsafe { from_raw_parts(ptr as *const btf_enum, btf_vlen(t) as usize) }
73}
74
75fn btf_enum64(t: &btf_type) -> &[btf_enum64] {
76    let ptr = btf_type_plus_1(t);
77    unsafe { from_raw_parts(ptr as *const btf_enum64, btf_vlen(t) as usize) }
78}
79
80fn btf_members(t: &btf_type) -> &[btf_member] {
81    let ptr = btf_type_plus_1(t);
82    unsafe { from_raw_parts(ptr as *const btf_member, btf_vlen(t) as usize) }
83}
84
85fn btf_name_str_by_offset(btf: &btf, name_off: u32) -> Result<&str> {
86    let n = unsafe { btf__name_by_offset(btf, name_off) };
87    if n.is_null() {
88        bail!("btf__name_by_offset() returned NULL");
89    }
90    Ok(unsafe { CStr::from_ptr(n) }
91        .to_str()
92        .with_context(|| format!("Failed to convert {:?} to string", n))?)
93}
94
95pub fn read_enum(type_name: &str, name: &str) -> Result<u64> {
96    let btf: &btf = *VMLINUX_BTF;
97
98    let type_name = CString::new(type_name).unwrap();
99    let tid = unsafe { btf__find_by_name(btf, type_name.as_ptr()) };
100    if tid < 0 {
101        bail!("type {:?} doesn't exist, ret={}", type_name, tid);
102    }
103
104    let t = unsafe { btf__type_by_id(btf, tid as _) };
105    if t.is_null() {
106        bail!("btf__type_by_id({}) returned NULL", tid);
107    }
108    let t = unsafe { &*t };
109
110    match btf_kind(t) {
111        BTF_KIND_ENUM => {
112            for e in btf_enum(t).iter() {
113                if btf_name_str_by_offset(btf, e.name_off)? == name {
114                    return Ok(e.val as u64);
115                }
116            }
117        }
118        BTF_KIND_ENUM64 => {
119            for e in btf_enum64(t).iter() {
120                if btf_name_str_by_offset(btf, e.name_off)? == name {
121                    return Ok(((e.val_hi32 as u64) << 32) | (e.val_lo32) as u64);
122                }
123            }
124        }
125        _ => (),
126    }
127
128    Err(anyhow!("{:?} doesn't exist in {:?}", name, type_name))
129}
130
131pub fn struct_has_field(type_name: &str, field: &str) -> Result<bool> {
132    let btf: &btf = *VMLINUX_BTF;
133
134    let type_name = CString::new(type_name).unwrap();
135    let tid = unsafe { btf__find_by_name_kind(btf, type_name.as_ptr(), BTF_KIND_STRUCT) };
136    if tid < 0 {
137        bail!("type {:?} doesn't exist, ret={}", type_name, tid);
138    }
139
140    let t = unsafe { btf__type_by_id(btf, tid as _) };
141    if t.is_null() {
142        bail!("btf__type_by_id({}) returned NULL", tid);
143    }
144    let t = unsafe { &*t };
145
146    for m in btf_members(t).iter() {
147        if btf_name_str_by_offset(btf, m.name_off)? == field {
148            return Ok(true);
149        }
150    }
151
152    Ok(false)
153}
154
155pub fn ksym_exists(ksym: &str) -> Result<bool> {
156    let btf: &btf = *VMLINUX_BTF;
157
158    let ksym_name = CString::new(ksym).unwrap();
159    let tid = unsafe { btf__find_by_name(btf, ksym_name.as_ptr()) };
160    Ok(tid >= 0)
161}
162
163pub fn in_kallsyms(ksym: &str) -> Result<bool> {
164    let file = std::fs::File::open("/proc/kallsyms")?;
165    let reader = std::io::BufReader::new(file);
166
167    for line in reader.lines() {
168        for sym in line.unwrap().split_whitespace() {
169            if ksym == sym {
170                return Ok(true);
171            }
172        }
173    }
174
175    Ok(false)
176}
177
178/// Returns the mount point for a filesystem type.
179pub fn get_fs_mount(mount_type: &str) -> Result<Vec<std::path::PathBuf>> {
180    let proc_mounts_path = std::path::Path::new(PROCFS_MOUNTS);
181
182    let file = std::fs::File::open(proc_mounts_path)
183        .with_context(|| format!("Failed to open {}", proc_mounts_path.display()))?;
184
185    let reader = BufReader::new(file);
186
187    let mut mounts = Vec::new();
188    for line in reader.lines() {
189        let line = line.context("Failed to read line from /proc/mounts")?;
190        let mount_info: Vec<&str> = line.split_whitespace().collect();
191
192        if mount_info.len() > 3 && mount_info[2] == mount_type {
193            let mount_path = std::path::PathBuf::from(mount_info[1]);
194            mounts.push(mount_path);
195        }
196    }
197
198    Ok(mounts)
199}
200
201/// Returns the tracefs mount point.
202pub fn tracefs_mount() -> Result<std::path::PathBuf> {
203    let mounts = get_fs_mount(TRACEFS)?;
204    mounts.into_iter().next().context("No tracefs mount found")
205}
206
207/// Returns the debugfs mount point.
208pub fn debugfs_mount() -> Result<std::path::PathBuf> {
209    let mounts = get_fs_mount(DEBUGFS)?;
210    mounts.into_iter().next().context("No debugfs mount found")
211}
212
213pub fn tracepoint_exists(tracepoint: &str) -> Result<bool> {
214    let base_path = tracefs_mount().unwrap_or_else(|_| debugfs_mount().unwrap().join("tracing"));
215    let file = std::fs::File::open(base_path.join("available_events"))?;
216    let reader = std::io::BufReader::new(file);
217
218    for line in reader.lines() {
219        for tp in line.unwrap().split_whitespace() {
220            if tracepoint == tp {
221                return Ok(true);
222            }
223        }
224    }
225
226    Ok(false)
227}
228
229pub fn cond_kprobe_enable<T>(sym: &str, prog_ptr: &OpenProgramImpl<T>) -> Result<bool> {
230    if in_kallsyms(sym)? {
231        unsafe {
232            bpf_program__set_autoload(prog_ptr.as_libbpf_object().as_ptr(), true);
233        }
234        return Ok(true);
235    } else {
236        warn!("symbol {} is missing, kprobe not loaded", sym);
237    }
238
239    Ok(false)
240}
241
242pub fn cond_tracepoint_enable<T>(tracepoint: &str, prog_ptr: &OpenProgramImpl<T>) -> Result<bool> {
243    if tracepoint_exists(tracepoint)? {
244        unsafe {
245            bpf_program__set_autoload(prog_ptr.as_libbpf_object().as_ptr(), true);
246        }
247        return Ok(true);
248    } else {
249        warn!(
250            "tradepoint {} is missing, tracepoint not loaded",
251            tracepoint
252        );
253    }
254
255    Ok(false)
256}
/// Reads `/sys/kernel/sched_ext/state` and maps its trimmed content to a
/// boolean: `"enabled"` gives `true`, `"disabled"` gives `false`.
///
/// # Errors
///
/// Returns the underlying I/O error when the file cannot be read, and an
/// `InvalidData` error for any other content.
pub fn is_sched_ext_enabled() -> io::Result<bool> {
    let raw = std::fs::read_to_string("/sys/kernel/sched_ext/state")?;
    let state = raw.trim();

    if state == "enabled" {
        return Ok(true);
    }
    if state == "disabled" {
        return Ok(false);
    }

    // Anything other than the two known states is reported as invalid data.
    Err(io::Error::new(
        io::ErrorKind::InvalidData,
        "Unexpected content in /sys/kernel/sched_ext/state",
    ))
}
272
/// Unwraps a `Result` inside a labeled block.
///
/// On `Ok(v)` the macro evaluates to `v`. On `Err(e)` it leaves the block
/// named by `$label` with the value `Err(e)`, so the labeled block itself must
/// evaluate to a `Result` carrying the same error type.
#[macro_export]
macro_rules! unwrap_or_break {
    ($expr: expr, $label: lifetime) => {{
        let outcome = $expr;
        match outcome {
            Err(failure) => break $label Err(failure),
            Ok(unwrapped) => unwrapped,
        }
    }};
}
282
283pub fn check_min_requirements() -> Result<()> {
284    // ec7e3b0463e1 ("implement-ops") in https://github.com/sched-ext/sched_ext
285    // is the current minimum required kernel version.
286    if let Ok(false) | Err(_) = struct_has_field("sched_ext_ops", "dump") {
287        bail!("sched_ext_ops.dump() missing, kernel too old?");
288    }
289    Ok(())
290}
291
/// Opens a BPF skeleton and prepares its struct_ops before loading.
///
/// Arguments: `$builder` is the skeleton builder, `$obj_ref` is passed to its
/// `open()` call, and `$ops` names the struct_ops whose `<$ops>_mut()`
/// accessor is used.
///
/// The expansion evaluates to an `anyhow::Result` holding the opened
/// skeleton. In order, it:
///
/// 1. returns the error from `check_min_requirements()`, if any;
/// 2. opens the skeleton;
/// 3. reads `/sys/kernel/sched_ext/hotplug_seq`, parses it as `u64` and stores
///    it in the `hotplug_seq` field of the struct_ops;
/// 4. if the `SCX_TIMEOUT_MS` environment variable is set, parses it as `u32`
///    and stores it in the `timeout_ms` field;
/// 5. runs `import_enums!` on the skeleton.
///
/// Failures to read or parse `hotplug_seq` keep the underlying I/O or parse
/// error as the cause of the returned error.
///
/// struct sched_ext_ops can change over time. If compat.bpf.h::SCX_OPS_DEFINE()
/// is used to define ops, and scx_ops_open!(), scx_ops_load!(), and
/// scx_ops_attach!() are used to open, load and attach it, backward
/// compatibility is automatically maintained where reasonable.
#[rustfmt::skip]
#[macro_export]
macro_rules! scx_ops_open {
    ($builder: expr, $obj_ref: expr, $ops: ident) => { 'block: {
        scx_utils::paste! {
            scx_utils::unwrap_or_break!(scx_utils::compat::check_min_requirements(), 'block);
            use ::anyhow::Context;
            use ::libbpf_rs::skel::SkelBuilder;

            let mut skel = match $builder.open($obj_ref).context("Failed to open BPF program") {
                Ok(val) => val,
                Err(e) => break 'block Err(e),
            };

            let ops = skel.struct_ops.[<$ops _mut>]();
            let path = std::path::Path::new("/sys/kernel/sched_ext/hotplug_seq");

            // Keep the I/O error as the cause instead of discarding it.
            let val = match std::fs::read_to_string(&path)
                .with_context(|| format!("Failed to open or read file {:?}", path))
            {
                Ok(val) => val,
                Err(e) => break 'block Err(e),
            };

            // Keep the parse error as the cause instead of discarding it.
            ops.hotplug_seq = match val.trim().parse::<u64>()
                .with_context(|| format!("Failed to parse hotplug seq {}", val))
            {
                Ok(parsed) => parsed,
                Err(e) => break 'block Err(e),
            };

            if let Ok(s) = ::std::env::var("SCX_TIMEOUT_MS") {
                skel.struct_ops.[<$ops _mut>]().timeout_ms = match s.parse::<u32>() {
                    Ok(ms) => {
                        ::scx_utils::info!("Setting timeout_ms to {} based on environment", ms);
                        ms
                    },
                    Err(e) => {
                        break 'block anyhow::Result::Err(e).context("SCX_TIMEOUT_MS has invalid value");
                    },
                };
            }

            $crate::import_enums!(skel);

            let result = ::anyhow::Result::Ok(skel);

            result
        }
    }};
}
347
/// Loads an opened BPF skeleton.
///
/// Arguments: `$skel` is the opened skeleton; `$ops` and `$uei` are forwarded
/// to `uei_set_size!` together with `$skel`.
///
/// The expansion first invokes `uei_set_size!($skel, $ops, $uei)` and then
/// evaluates to the result of `$skel.load()`, with the context
/// "Failed to load BPF program" attached on failure.
///
/// struct sched_ext_ops can change over time. If compat.bpf.h::SCX_OPS_DEFINE()
/// is used to define ops, and scx_ops_open!(), scx_ops_load!(), and
/// scx_ops_attach!() are used to open, load and attach it, backward
/// compatibility is automatically maintained where reasonable.
#[rustfmt::skip]
#[macro_export]
macro_rules! scx_ops_load {
    ($skel: expr, $ops: ident, $uei: ident) => { 'block: {
        scx_utils::paste! {
            use ::anyhow::Context;
            use ::libbpf_rs::skel::OpenSkel;

            // Must run before load(): it receives the still-unloaded skeleton.
            scx_utils::uei_set_size!($skel, $ops, $uei);
            $skel.load().context("Failed to load BPF program")
        }
    }};
}
365
/// Attaches a loaded skeleton and its struct_ops map.
///
/// Arguments: `$skel` is the loaded skeleton and `$ops` names the struct_ops
/// map inside `$skel.maps`.
///
/// The expansion evaluates to an error when `is_sched_ext_enabled()` reports
/// `true`. Otherwise it calls `$skel.attach()` and, only if that succeeds,
/// `attach_struct_ops()` on the `$ops` map, evaluating to the result of the
/// latter.
///
/// Must be used together with scx_ops_load!(). See there.
#[rustfmt::skip]
#[macro_export]
macro_rules! scx_ops_attach {
    ($skel: expr, $ops: ident) => { 'block: {
        use ::anyhow::Context;
        use ::libbpf_rs::skel::Skel;

        // A failure to read the state is treated the same as "not enabled",
        // so attachment proceeds in that case.
        if scx_utils::compat::is_sched_ext_enabled().unwrap_or(false) {
            break 'block Err(anyhow::anyhow!(
                "another sched_ext scheduler is already running"
            ));
        }
        // Attach the regular programs first; the struct_ops map is attached
        // only when that succeeded.
        $skel
            .attach()
            .context("Failed to attach non-struct_ops BPF programs")
            .and_then(|_| {
                $skel
                    .maps
                    .$ops
                    .attach_struct_ops()
                    .context("Failed to attach struct_ops BPF programs")
            })
    }};
}
391
// These tests go through `read_enum`, `struct_has_field` and `ksym_exists`,
// all of which query `VMLINUX_BTF`. Their outcome therefore depends on the
// BTF available on the machine that runs them.
#[cfg(test)]
mod tests {
    // Expects `PIDTYPE_TGID` in `pid_type` to have the value 1.
    #[test]
    fn test_read_enum() {
        assert_eq!(super::read_enum("pid_type", "PIDTYPE_TGID").unwrap(), 1);
    }

    // Covers the three outcomes: member present, member absent, and unknown
    // struct (which is an error rather than `Ok(false)`).
    #[test]
    fn test_struct_has_field() {
        assert!(super::struct_has_field("task_struct", "flags").unwrap());
        assert!(!super::struct_has_field("task_struct", "NO_SUCH_FIELD").unwrap());
        assert!(super::struct_has_field("NO_SUCH_STRUCT", "NO_SUCH_FIELD").is_err());
    }

    // An unknown name yields `Ok(false)`, not an error.
    #[test]
    fn test_ksym_exists() {
        assert!(super::ksym_exists("bpf_task_acquire").unwrap());
        assert!(!super::ksym_exists("NO_SUCH_KFUNC").unwrap());
    }
}
409        assert!(!super::ksym_exists("NO_SUCH_KFUNC").unwrap());
410    }
411}