Skip to main content

scx_arena_selftests/
main.rs

1// Copyright (c) Meta Platforms, Inc. and affiliates.
2
3// This software may be used and distributed according to the terms of the
4// GNU General Public License version 2.
5mod bpf_skel;
6pub use bpf_skel::*;
7
8use std::mem::MaybeUninit;
9
10use anyhow::bail;
11use anyhow::Context;
12use anyhow::Result;
13
14use std::ffi::c_ulong;
15use std::ffi::c_void;
16use std::io::IsTerminal;
17
18use std::os::fd::AsFd;
19use std::os::fd::AsRawFd;
20use std::sync::Arc;
21
22use clap::Parser;
23
24use scx_utils::init_libbpf_logging;
25use scx_utils::Core;
26use scx_utils::Llc;
27use scx_utils::Topology;
28use scx_utils::NR_CPU_IDS;
29
30use simplelog::{ColorChoice, Config as SimplelogConfig, TermLogger, TerminalMode};
31
32use libbpf_rs::libbpf_sys;
33
34use libbpf_rs::skel::OpenSkel;
35use libbpf_rs::skel::SkelBuilder;
36use libbpf_rs::PrintLevel;
37use libbpf_rs::ProgramInput;
38
/// Stream ID passed to `bpf_prog_stream_read` to read a BPF program's stdout stream.
const BPF_STDOUT: u32 = 1;
/// Stream ID passed to `bpf_prog_stream_read` to read a BPF program's stderr stream.
const BPF_STDERR: u32 = 2;

/// ANSI SGR escape selecting a bright green foreground.
const COLOR_BRIGHT_GREEN: &str = "\x1b[92m";
/// ANSI SGR escape selecting a bright red foreground.
const COLOR_BRIGHT_RED: &str = "\x1b[91m";
/// ANSI SGR escape resetting all text attributes.
const COLOR_RESET: &str = "\x1b[0m";

/// Returns `text` wrapped in the `color` escape and a trailing reset, or `text` unchanged
/// when the destination stream is not a terminal (`is_tty == false`), so that redirected
/// output stays free of escape sequences.
fn colorize(text: &str, color: &str, is_tty: bool) -> String {
    if !is_tty {
        return text.to_owned();
    }
    [color, text, COLOR_RESET].concat()
}
53
// Selftest identifiers, mirroring `enum scx_selftest_id` in lib/selftests/selftest.h.
// The discriminants must stay numerically identical to the C enum: the values listed in
// TEST_CASES are written verbatim into the BPF program's `selftest_run_id`.
// SCX_SELFTEST_ID_ALL (0) is reserved for "run all" and is intentionally absent from
// TEST_CASES, which only lists individually runnable tests.
#[repr(u32)]
#[allow(non_camel_case_types)]
enum SelfTestId {
    // Never referenced from Rust; kept so the numbering visibly matches the C enum.
    #[allow(dead_code)]
    SCX_SELFTEST_ID_ALL = 0,
    SCX_SELFTEST_ID_ATQ = 1,
    SCX_SELFTEST_ID_BTREE = 2,
    SCX_SELFTEST_ID_LVQUEUE = 3,
    SCX_SELFTEST_ID_MINHEAP = 4,
    SCX_SELFTEST_ID_RBTREE = 5,
    SCX_SELFTEST_ID_TOPOLOGY = 6,
}

/// Builds the human-readable list of test names used by `--list` and by unknown-test
/// errors: one name per line, each indented by two spaces, with no trailing newline.
fn available_tests() -> String {
    let mut listing = String::new();
    for (idx, &(name, _)) in TEST_CASES.iter().enumerate() {
        if idx != 0 {
            listing.push('\n');
        }
        listing.push_str("  ");
        listing.push_str(name);
    }
    listing
}

/// Individually runnable selftests as `(command-line name, selftest ID)` pairs, in the
/// order they are listed and run when no `--test` filter is given.
const TEST_CASES: &[(&str, u32)] = &[
    ("atq", SelfTestId::SCX_SELFTEST_ID_ATQ as u32),
    ("btree", SelfTestId::SCX_SELFTEST_ID_BTREE as u32),
    ("lvqueue", SelfTestId::SCX_SELFTEST_ID_LVQUEUE as u32),
    ("minheap", SelfTestId::SCX_SELFTEST_ID_MINHEAP as u32),
    ("rbtree", SelfTestId::SCX_SELFTEST_ID_RBTREE as u32),
    ("topology", SelfTestId::SCX_SELFTEST_ID_TOPOLOGY as u32),
];
86
// Command-line options for the selftest runner.
// NOTE: the `///` comments on the fields below are rendered by clap as the `--help` text
// for each flag, so they are user-visible output and should be edited as such.
#[derive(Debug, Parser)]
#[clap(about = "scx_arena library selftests")]
struct Opts {
    /// List all available test cases and exit.
    #[clap(long)]
    list: bool,

    // Test names are checked against TEST_CASES in main() before the BPF object is loaded;
    // an empty vector means "run every entry of TEST_CASES".
    /// Run one or more specific test cases. Multiple names can be given after a
    /// single --test flag (e.g. --test rbtree atq), or the flag can be repeated.
    /// If not specified, all tests are run.
    #[clap(long = "test", value_name = "NAME", num_args(1..))]
    tests: Vec<String>,
}
100
101fn setup_arenas(skel: &mut BpfSkel<'_>) -> Result<()> {
102    const STATIC_ALLOC_PAGES_GRANULARITY: c_ulong = 512;
103    const TASK_SIZE: c_ulong = 42;
104
105    // Allocate the arena memory from the BPF side so userspace initializes it before starting
106    // the scheduler. Despite the function call's name this is neither a test nor a test run,
107    // it's the recommended way of executing SEC("syscall") probes.
108    let mut args = types::arena_init_args {
109        static_pages: STATIC_ALLOC_PAGES_GRANULARITY,
110        task_ctx_size: TASK_SIZE,
111    };
112
113    let input = ProgramInput {
114        context_in: Some(unsafe {
115            std::slice::from_raw_parts_mut(
116                &mut args as *mut _ as *mut u8,
117                std::mem::size_of_val(&args),
118            )
119        }),
120        ..Default::default()
121    };
122
123    let output = skel.progs.arena_init.test_run(input)?;
124    if output.return_value != 0 {
125        bail!(
126            "Could not initialize arenas, arena_init returned {}",
127            output.return_value as i32
128        );
129    }
130
131    Ok(())
132}
133
134fn setup_topology_node(skel: &mut BpfSkel<'_>, mask: &[u64]) -> Result<()> {
135    let mut args = types::arena_alloc_mask_args {
136        bitmap: 0 as c_ulong,
137    };
138
139    let input = ProgramInput {
140        context_in: Some(unsafe {
141            std::slice::from_raw_parts_mut(
142                &mut args as *mut _ as *mut u8,
143                std::mem::size_of_val(&args),
144            )
145        }),
146        ..Default::default()
147    };
148
149    let output = skel.progs.arena_alloc_mask.test_run(input)?;
150    if output.return_value != 0 {
151        bail!(
152            "Could not initialize arenas, setup_topology_node returned {}",
153            output.return_value as i32
154        );
155    }
156
157    let ptr = unsafe {
158        &mut *std::ptr::with_exposed_provenance_mut::<[u64; 10]>(args.bitmap.try_into().unwrap())
159    };
160
161    let (valid_mask, _) = ptr.split_at_mut(mask.len());
162    valid_mask.clone_from_slice(mask);
163
164    let mut args = types::arena_topology_node_init_args {
165        bitmap: args.bitmap as c_ulong,
166        data_size: 0 as c_ulong,
167        id: 0 as c_ulong,
168    };
169
170    let input = ProgramInput {
171        context_in: Some(unsafe {
172            std::slice::from_raw_parts_mut(
173                &mut args as *mut _ as *mut u8,
174                std::mem::size_of_val(&args),
175            )
176        }),
177        ..Default::default()
178    };
179
180    let output = skel.progs.arena_topology_node_init.test_run(input)?;
181    if output.return_value != 0 {
182        bail!(
183            "arena_topology_node_init returned {}",
184            output.return_value as i32
185        );
186    }
187
188    Ok(())
189}
190
191fn setup_topology(skel: &mut BpfSkel<'_>) -> Result<()> {
192    let topo = Topology::new().expect("Failed to build host topology");
193
194    // Set per-level max children before registering any topology nodes.
195    // NOTE: rust/scx_arena/scx_arena/src/arenalib.rs::setup_topology_max_children()
196    // contains equivalent logic and must be kept in sync with this block.
197    let max_children: [u32; 5] = [
198        topo.nodes.len() as u32,
199        topo.nodes.values().map(|n| n.llcs.len()).max().unwrap_or(0) as u32,
200        topo.all_llcs
201            .values()
202            .map(|l| l.cores.len())
203            .max()
204            .unwrap_or(0) as u32,
205        topo.all_cores
206            .values()
207            .map(|c| c.cpus.len())
208            .max()
209            .unwrap_or(0) as u32,
210        0,
211    ];
212    let mut init_args = types::arena_topology_init_args { max_children };
213    let init_input = ProgramInput {
214        context_in: Some(unsafe {
215            std::slice::from_raw_parts_mut(
216                &mut init_args as *mut _ as *mut u8,
217                std::mem::size_of_val(&init_args),
218            )
219        }),
220        ..Default::default()
221    };
222    let output = skel.progs.arena_topology_init.test_run(init_input)?;
223    if output.return_value != 0 {
224        bail!(
225            "arena_topology_init returned {}",
226            output.return_value as i32
227        );
228    }
229
230    setup_topology_node(skel, topo.span.as_raw_slice())?;
231
232    for (_, node) in topo.nodes {
233        setup_topology_node(skel, node.span.as_raw_slice())?;
234    }
235
236    for (_, llc) in topo.all_llcs {
237        setup_topology_node(
238            skel,
239            Arc::<Llc>::into_inner(llc)
240                .expect("missing llc")
241                .span
242                .as_raw_slice(),
243        )?;
244    }
245
246    for (_, core) in topo.all_cores {
247        setup_topology_node(
248            skel,
249            Arc::<Core>::into_inner(core)
250                .expect("missing core")
251                .span
252                .as_raw_slice(),
253        )?;
254    }
255    for (_, cpu) in topo.all_cpus {
256        let mut mask = [0; 9];
257        mask[cpu.id / 64] |= 1 << (cpu.id % 64);
258        setup_topology_node(skel, &mask)?;
259    }
260
261    Ok(())
262}
263
264fn print_stream(skel: &mut BpfSkel<'_>, stream_id: u32) -> () {
265    let prog_fd = skel.progs.arena_selftest.as_fd().as_raw_fd();
266    let mut buf = vec![0u8; 4096];
267    let name = if stream_id == 1 { "OUTPUT" } else { "ERROR" };
268    let mut started = false;
269
270    loop {
271        let ret = unsafe {
272            libbpf_sys::bpf_prog_stream_read(
273                prog_fd,
274                stream_id,
275                buf.as_mut_ptr() as *mut c_void,
276                buf.len() as u32,
277                std::ptr::null_mut(),
278            )
279        };
280        if ret < 0 {
281            eprintln!("STREAM {} UNAVAILABLE (REQUIRES >= v6.17)", name);
282            return;
283        }
284
285        if !started {
286            println!("===BEGIN STREAM {}===", name);
287            started = true;
288        }
289
290        if ret == 0 {
291            break;
292        }
293
294        print!("{}", String::from_utf8_lossy(&buf[..ret as usize]));
295    }
296
297    println!("\n====END STREAM  {}====", name);
298}
299
300// Run the named test by setting selftest_run_id in the BPF bss and calling
301// arena_selftest. The ID comes from enum scx_selftest_id in selftest.h.
302fn run_test_by_name(skel: &mut BpfSkel<'_>, name: &str) -> Result<i32> {
303    let id = TEST_CASES
304        .iter()
305        .find(|(n, _)| *n == name)
306        .map(|(_, id)| *id)
307        .ok_or_else(|| {
308            anyhow::anyhow!(
309                "Unknown test: '{}'. Use --list to see available tests.",
310                name
311            )
312        })?;
313
314    skel.maps.bss_data.as_mut().unwrap().selftest_run_id = id;
315
316    let input = ProgramInput {
317        ..Default::default()
318    };
319    let output = skel.progs.arena_selftest.test_run(input)?;
320
321    Ok(output.return_value as i32)
322}
323
/// Entry point: parses options, loads the BPF selftest object, sets up the arenas and host
/// topology, runs the requested selftests, and exits with status 1 if any of them failed.
///
/// Each test's BPF stdout/stderr streams are printed right after its PASS/FAIL line.
fn main() {
    TermLogger::init(
        simplelog::LevelFilter::Info,
        SimplelogConfig::default(),
        TerminalMode::Mixed,
        ColorChoice::Auto,
    )
    .unwrap();

    let opts = Opts::parse();

    // `--list` only prints the table and never touches BPF.
    if opts.list {
        println!("Available test cases:\n{}", available_tests());
        return;
    }

    // Validate test names before loading BPF.
    for name in &opts.tests {
        if !TEST_CASES.iter().any(|(n, _)| *n == name.as_str()) {
            eprintln!(
                "Unknown test: '{}'.\nAvailable tests:\n{}",
                name,
                available_tests()
            );
            std::process::exit(1);
        }
    }

    // Backing storage for the opened (not yet loaded) BPF object; the skeleton borrows it.
    let mut open_object = MaybeUninit::uninit();
    let mut builder = BpfSkelBuilder::default();

    // Ask for verbose libbpf output and route libbpf log messages at Debug level
    // (presumably through the logger initialized above — see scx_utils).
    builder.obj_builder.debug(true);
    init_libbpf_logging(Some(PrintLevel::Debug));

    let mut skel = builder
        .open(&mut open_object)
        .context("Failed to open BPF program")
        .unwrap();

    // .rodata must be filled in between open() and load(); it is fixed once loaded.
    skel.maps.rodata_data.as_mut().unwrap().nr_cpu_ids = *NR_CPU_IDS as u32;

    let mut skel = skel.load().context("Failed to load BPF program").unwrap();

    // Arena memory must exist before the topology nodes are allocated in it.
    setup_arenas(&mut skel).unwrap();
    setup_topology(&mut skel).unwrap();

    // No --test filter means "run everything in TEST_CASES order".
    let to_run: Vec<&str> = if opts.tests.is_empty() {
        TEST_CASES.iter().map(|(n, _)| *n).collect()
    } else {
        opts.tests.iter().map(String::as_str).collect()
    };

    // PASS lines go to stdout and FAIL lines to stderr, so each label is colored based on
    // whether its own stream is a terminal.
    let stdout_tty = std::io::stdout().is_terminal();
    let stderr_tty = std::io::stderr().is_terminal();
    let pass_label = colorize("[ PASS ]", COLOR_BRIGHT_GREEN, stdout_tty);
    let fail_label = colorize("[ FAIL ]", COLOR_BRIGHT_RED, stderr_tty);

    // Keep running after a failure so every requested test is reported.
    let mut any_failed = false;
    for &name in &to_run {
        match run_test_by_name(&mut skel, name) {
            Ok(0) => println!("{} {}", pass_label, name),
            Ok(ret) => {
                eprintln!("{} {} (returned {})", fail_label, name, ret);
                any_failed = true;
            }
            Err(e) => {
                eprintln!("{} {} (error: {})", fail_label, name, e);
                any_failed = true;
            }
        }

        // Drain whatever the BPF side wrote during this test.
        print_stream(&mut skel, BPF_STDOUT);
        print_stream(&mut skel, BPF_STDERR);
    }

    if any_failed {
        eprintln!(
            "{}",
            colorize(
                "One or more selftests failed.",
                COLOR_BRIGHT_RED,
                stderr_tty
            )
        );
        std::process::exit(1);
    } else {
        println!(
            "{}",
            colorize("All selftests passed.", COLOR_BRIGHT_GREEN, stdout_tty)
        );
    }
}