1use crate::NR_CPU_IDS;
59use anyhow::bail;
60use anyhow::Context;
61use anyhow::Result;
62use bitvec::prelude::*;
63use sscanf::sscanf;
64use std::fmt;
65use std::ops::BitAndAssign;
66use std::ops::BitOrAssign;
67use std::ops::BitXorAssign;
68
#[cfg(any(test, feature = "testutils"))]
thread_local! {
    /// Per-thread width override consulted by `mask_width`.
    ///
    /// The initial value is zero, which `mask_width` treats as "no override".
    static MASK_WIDTH_OVERRIDE: std::cell::Cell<usize> = const { std::cell::Cell::new(0usize) };
}
75
76fn mask_width() -> usize {
78 #[cfg(any(test, feature = "testutils"))]
79 {
80 let ovr = MASK_WIDTH_OVERRIDE.with(|c| c.get());
81 if ovr > 0 {
82 return ovr;
83 }
84 }
85 *NR_CPU_IDS
86}
87
#[cfg(any(test, feature = "testutils"))]
/// Overrides the cpumask width for the calling thread only.
///
/// Passing `0` removes the override so `NR_CPU_IDS` is used again.
pub fn set_cpumask_test_width(width: usize) {
    MASK_WIDTH_OVERRIDE.set(width);
}
95
96#[derive(Debug, Eq, Clone, Hash, Ord, PartialEq, PartialOrd)]
97pub struct Cpumask {
98 mask: BitVec<u64, Lsb0>,
99}
100
101impl Cpumask {
102 fn check_cpu(&self, cpu: usize) -> Result<()> {
103 if cpu >= mask_width() {
104 bail!("Invalid CPU {} passed, max {}", cpu, mask_width());
105 }
106
107 Ok(())
108 }
109
110 pub fn new() -> Cpumask {
112 Cpumask {
113 mask: bitvec![u64, Lsb0; 0; mask_width()],
114 }
115 }
116
117 pub fn from_str(cpumask: &str) -> Result<Cpumask> {
119 match cpumask {
120 "none" => {
121 let mask = bitvec![u64, Lsb0; 0; mask_width()];
122 return Ok(Self { mask });
123 }
124 "all" => {
125 let mask = bitvec![u64, Lsb0; 1; mask_width()];
126 return Ok(Self { mask });
127 }
128 _ => {}
129 }
130 let hex_str = {
131 let mut tmp_str = cpumask
132 .strip_prefix("0x")
133 .unwrap_or(cpumask)
134 .replace('_', "");
135 if tmp_str.len() % 2 != 0 {
136 tmp_str = "0".to_string() + &tmp_str;
137 }
138 tmp_str
139 };
140 let byte_vec =
141 hex::decode(&hex_str).with_context(|| format!("Failed to parse cpumask: {cpumask}"))?;
142
143 let mut mask = bitvec![u64, Lsb0; 0; mask_width()];
144 for (index, &val) in byte_vec.iter().rev().enumerate() {
145 let mut v = val;
146 while v != 0 {
147 let lsb = v.trailing_zeros() as usize;
148 v &= !(1 << lsb);
149 let cpu = index * 8 + lsb;
150 if cpu >= mask_width() {
151 bail!(
152 concat!(
153 "Found cpu ({}) in cpumask ({}) which is larger",
154 " than the number of cpus on the machine ({})"
155 ),
156 cpu,
157 cpumask,
158 mask_width()
159 );
160 }
161 mask.set(cpu, true);
162 }
163 }
164
165 Ok(Self { mask })
166 }
167
168 pub fn from_cpulist(cpulist: &str) -> Result<Cpumask> {
169 let mut mask = Cpumask::new();
170 for cpu_id in read_cpulist(cpulist)? {
171 let _ = mask.set_cpu(cpu_id);
172 }
173
174 Ok(mask)
175 }
176
177 pub fn to_cpulist(&self) -> String {
180 let cpus: Vec<usize> = self.iter().collect();
181 if cpus.is_empty() {
182 return String::from("none");
183 }
184
185 let mut ranges = Vec::new();
186 let mut start = cpus[0];
187 let mut end = cpus[0];
188
189 for &cpu in &cpus[1..] {
190 if cpu == end + 1 {
191 end = cpu;
192 } else {
193 ranges.push(if start == end {
194 format!("{}", start)
195 } else {
196 format!("{}-{}", start, end)
197 });
198 start = cpu;
199 end = cpu;
200 }
201 }
202
203 ranges.push(if start == end {
204 format!("{}", start)
205 } else {
206 format!("{}-{}", start, end)
207 });
208
209 ranges.join(",")
210 }
211
212 pub fn from_vec(vec: Vec<u64>) -> Self {
213 Self {
214 mask: BitVec::from_vec(vec),
215 }
216 }
217
218 pub fn from_bitvec(bitvec: BitVec<u64, Lsb0>) -> Self {
219 Self { mask: bitvec }
220 }
221
222 pub fn as_raw_slice(&self) -> &[u64] {
224 self.mask.as_raw_slice()
225 }
226
227 pub fn as_raw_bitvec_mut(&mut self) -> &mut BitVec<u64, Lsb0> {
229 &mut self.mask
230 }
231
232 pub fn as_raw_bitvec(&self) -> &BitVec<u64, Lsb0> {
234 &self.mask
235 }
236
237 pub fn set_all(&mut self) {
239 self.mask.fill(true);
240 }
241
242 pub fn clear_all(&mut self) {
244 self.mask.fill(false);
245 }
246
247 pub fn set_cpu(&mut self, cpu: usize) -> Result<()> {
250 self.check_cpu(cpu)?;
251 self.mask.set(cpu, true);
252 Ok(())
253 }
254
255 pub fn clear_cpu(&mut self, cpu: usize) -> Result<()> {
258 self.check_cpu(cpu)?;
259 self.mask.set(cpu, false);
260 Ok(())
261 }
262
263 pub fn test_cpu(&self, cpu: usize) -> bool {
266 match self.mask.get(cpu) {
267 Some(bit) => *bit,
268 None => false,
269 }
270 }
271
272 pub fn weight(&self) -> usize {
274 self.mask.count_ones()
275 }
276
277 pub fn is_empty(&self) -> bool {
279 self.mask.count_ones() == 0
280 }
281
282 pub fn is_full(&self) -> bool {
284 self.mask.count_ones() == mask_width()
285 }
286
287 pub fn len(&self) -> usize {
289 mask_width()
290 }
291
292 pub fn not(&self) -> Cpumask {
294 let mut new = self.clone();
295 new.mask = !new.mask;
296 new
297 }
298
299 pub fn and(&self, other: &Cpumask) -> Cpumask {
301 let mut new = self.clone();
302 new.mask &= other.mask.clone();
303 new
304 }
305
306 pub fn or(&self, other: &Cpumask) -> Cpumask {
308 let mut new = self.clone();
309 new.mask |= other.mask.clone();
310 new
311 }
312
313 pub fn xor(&self, other: &Cpumask) -> Cpumask {
315 let mut new = self.clone();
316 new.mask ^= other.mask.clone();
317 new
318 }
319
320 pub fn iter(&self) -> CpumaskIterator<'_> {
335 CpumaskIterator {
336 mask: self,
337 index: 0,
338 }
339 }
340
341 pub unsafe fn write_to_ptr(&self, bpfptr: *mut u64, len: usize) -> Result<()> {
349 let cpumask_slice = self.as_raw_slice();
350 if len != cpumask_slice.len() {
351 bail!(
352 "BPF CPU mask has length {} u64s, Cpumask size is {}",
353 len,
354 cpumask_slice.len()
355 );
356 }
357
358 let ptr = bpfptr as *mut [u64; 64];
359 let bpfmask: &mut [u64; 64] = unsafe { &mut *ptr };
360 let (left, _) = bpfmask.split_at_mut(cpumask_slice.len());
361 left.clone_from_slice(cpumask_slice);
362
363 Ok(())
364 }
365
366 fn fmt_with(&self, f: &mut fmt::Formatter<'_>, case: char) -> fmt::Result {
367 let mut masks: Vec<u32> = self
368 .as_raw_slice()
369 .iter()
370 .flat_map(|x| [*x as u32, (x >> 32) as u32])
371 .collect();
372
373 masks.truncate((mask_width()).div_ceil(32));
375
376 let width = match (mask_width()).div_ceil(4) % 8 {
378 0 => 8,
379 v => v,
380 };
381 match case {
382 'x' => write!(f, "{:0width$x}", masks.pop().unwrap(), width = width)?,
383 'X' => write!(f, "{:0width$X}", masks.pop().unwrap(), width = width)?,
384 _ => unreachable!(),
385 }
386
387 for submask in masks.iter().rev() {
389 match case {
390 'x' => write!(f, ",{submask:08x}")?,
391 'X' => write!(f, ",{submask:08X}")?,
392 _ => unreachable!(),
393 }
394 }
395 Ok(())
396 }
397}
398
399pub fn read_cpulist(cpulist: &str) -> Result<Vec<usize>> {
400 let cpulist = cpulist.trim_end_matches('\0');
401 let cpu_groups: Vec<&str> = cpulist.split(',').collect();
402 let mut cpu_ids = vec![];
403 for group in cpu_groups.iter() {
404 let (min, max) = match sscanf!(group.trim(), "{usize}-{usize}") {
405 Some((x, y)) => (x, y),
406 None => match sscanf!(group.trim(), "{usize}") {
407 Some(x) => (x, x),
408 None => {
409 bail!("Failed to parse cpulist {}", group.trim());
410 }
411 },
412 };
413 for i in min..(max + 1) {
414 cpu_ids.push(i);
415 }
416 }
417
418 Ok(cpu_ids)
419}
420
421pub struct CpumaskIterator<'a> {
422 mask: &'a Cpumask,
423 index: usize,
424}
425
426impl Iterator for CpumaskIterator<'_> {
427 type Item = usize;
428
429 fn next(&mut self) -> Option<Self::Item> {
430 while self.index < mask_width() {
431 let index = self.index;
432 self.index += 1;
433 let bit_val = self.mask.test_cpu(index);
434 if bit_val {
435 return Some(index);
436 }
437 }
438
439 None
440 }
441}
442
443impl fmt::Display for Cpumask {
444 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
445 self.fmt_with(f, 'x')
446 }
447}
448
449impl fmt::LowerHex for Cpumask {
450 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
451 self.fmt_with(f, 'x')
452 }
453}
454
455impl fmt::UpperHex for Cpumask {
456 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
457 self.fmt_with(f, 'X')
458 }
459}
460
461impl BitAndAssign<&Self> for Cpumask {
462 fn bitand_assign(&mut self, rhs: &Self) {
463 self.mask &= &rhs.mask;
464 }
465}
466
467impl BitOrAssign<&Self> for Cpumask {
468 fn bitor_assign(&mut self, rhs: &Self) {
469 self.mask |= &rhs.mask;
470 }
471}
472
473impl BitXorAssign<&Self> for Cpumask {
474 fn bitxor_assign(&mut self, rhs: &Self) {
475 self.mask ^= &rhs.mask;
476 }
477}
478
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_to_cpulist_empty() {
        // Pin the width so the test does not depend on the host's CPU count.
        set_cpumask_test_width(16);
        let mask = Cpumask::new();
        assert_eq!(mask.to_cpulist(), "none");
    }

    #[test]
    fn test_to_cpulist_single_cpu() {
        set_cpumask_test_width(16);
        let mut mask = Cpumask::new();
        mask.set_cpu(5).unwrap();
        assert_eq!(mask.to_cpulist(), "5");
    }

    #[test]
    fn test_to_cpulist_contiguous_range() {
        set_cpumask_test_width(16);
        let mut mask = Cpumask::new();
        for cpu in 0..8 {
            mask.set_cpu(cpu).unwrap();
        }
        assert_eq!(mask.to_cpulist(), "0-7");
    }

    #[test]
    fn test_to_cpulist_multiple_ranges() {
        set_cpumask_test_width(16);
        let mut mask = Cpumask::new();
        for cpu in 0..4 {
            mask.set_cpu(cpu).unwrap();
        }
        for cpu in 8..12 {
            mask.set_cpu(cpu).unwrap();
        }
        assert_eq!(mask.to_cpulist(), "0-3,8-11");
    }

    #[test]
    fn test_to_cpulist_scattered() {
        set_cpumask_test_width(16);
        let mut mask = Cpumask::new();
        mask.set_cpu(1).unwrap();
        mask.set_cpu(3).unwrap();
        mask.set_cpu(5).unwrap();
        assert_eq!(mask.to_cpulist(), "1,3,5");
    }

    #[test]
    fn test_to_cpulist_mixed() {
        set_cpumask_test_width(16);
        let mut mask = Cpumask::new();
        mask.set_cpu(0).unwrap();
        mask.set_cpu(1).unwrap();
        mask.set_cpu(2).unwrap();
        mask.set_cpu(5).unwrap();
        mask.set_cpu(10).unwrap();
        mask.set_cpu(11).unwrap();
        assert_eq!(mask.to_cpulist(), "0-2,5,10-11");
    }

    #[test]
    fn test_to_cpulist_roundtrip() {
        set_cpumask_test_width(32);
        let original = "0-3,8-11,16";
        let mask = Cpumask::from_cpulist(original).unwrap();
        assert_eq!(mask.to_cpulist(), original);
    }

    #[test]
    fn test_hex_format_single_group() {
        set_cpumask_test_width(16);
        let mut mask = Cpumask::new();
        for cpu in 0..8 {
            mask.set_cpu(cpu).unwrap();
        }
        // 16 bits -> a single 4-digit leading group.
        assert_eq!(format!("{mask}"), "00ff");
        assert_eq!(format!("{mask:X}"), "00FF");
    }

    #[test]
    fn test_hex_format_multiple_groups() {
        set_cpumask_test_width(40);
        let mut mask = Cpumask::new();
        for cpu in 0..40 {
            mask.set_cpu(cpu).unwrap();
        }
        // 40 bits -> a 2-digit leading group followed by one full 32-bit group.
        assert_eq!(mask.to_string(), "ff,ffffffff");
    }

    #[test]
    fn test_from_str_hex() {
        set_cpumask_test_width(16);
        let mask = Cpumask::from_str("0x00ff").unwrap();
        assert_eq!(mask.to_cpulist(), "0-7");
        let mask = Cpumask::from_str("0x8_01").unwrap();
        assert_eq!(mask.to_cpulist(), "0,11");
    }

    #[test]
    fn test_from_str_rejects_out_of_range_cpu() {
        set_cpumask_test_width(8);
        assert!(Cpumask::from_str("0x100").is_err());
        assert!(Cpumask::from_str("zz").is_err());
    }
}