scx_wd40/
load_balance.rs

// Copyright (c) Meta Platforms, Inc. and affiliates.

// This software may be used and distributed according to the terms of the
// GNU General Public License version 2.

//! # WD40 load balancer
//!
//! The module that includes logic for performing load balancing in the
//! scx_wd40 scheduler.
//!
//! Load Balancing
//! --------------
//!
//! scx_wd40 performs load balancing using the following general workflow:
//!
//! 1. Determine domain load averages from the duty cycle buckets in the
//!    dom_ctx_map_elem map, aggregate the load using
//!    scx_utils::LoadAggregator, and then determine the load distribution
//!    (accounting for infeasible weights) using the scx_utils::LoadLedger
//!    object.
//!
//! 2. Create a hierarchy representing load using NumaNode and Domain objects
//!    as follows:
//!
//!                                 o--------------------------------o
//!                                 |             LB Root            |
//!                                 |                                |
//!                                 | PushNodes: <Load, NumaNode>    |
//!                                 | PullNodes: <Load, NumaNode>    |
//!                                 | BalancedNodes: <Load, NumaNode>|
//!                                 o----------------o---------------o
//!                                                  |
//!                              o-------------------o-------------------o
//!                              |                   |                   |
//!                              |                   |                   |
//!                o-------------o--------------o   ...   o--------------o-------------o
//!                |          NumaNode          |         |           NumaNode         |
//!                | ID    0                    |         | ID    1                    |
//!                | PushDomains <Load, Domain> |         | PushDomains <Load, Domain> |
//!                | PullDomains <Load, Domain> |         | PullDomains <Load, Domain> |
//!                | BalancedDomains <Domain>   |         | BalancedDomains <Domain>   |
//!                | LoadSum f64                |         | LoadSum f64                |
//!                | LoadAvg f64                |         | LoadAvg f64                |
//!                | LoadImbal f64              |         | LoadImbal f64              |
//!                | BalanceCost f64            |         | BalanceCost f64            |
//!                | ...                        |         | ...                        |
//!                o-------------o--------------o         o----------------------------o
//!                              |
//!                              |
//!  o----------------------o   ...   o---------------------o
//!  |        Domain        |         |        Domain       |
//!  | ID    0              |         | ID    1             |
//!  | Tasks   <Load, Task> |         | Tasks  <Load, Task> |
//!  | LoadSum f64          |         | LoadSum f64         |
//!  | LoadAvg f64          |         | LoadAvg f64         |
//!  | LoadImbal f64        |         | LoadImbal f64       |
//!  | BalanceCost f64      |         | BalanceCost f64     |
//!  | ...                  |         | ...                 |
//!  o----------------------o         o----------o----------o
//!                                              |
//!                                              |
//!                        o----------------o   ...   o----------------o
//!                        |      Task      |         |      Task      |
//!                        | PID   0        |         | PID   1        |
//!                        | Load  f64      |         | Load  f64      |
//!                        | Migrated bool  |         | Migrated bool  |
//!                        | IsKworker bool |         | IsKworker bool |
//!                        o----------------o         o----------------o
//!
//! As mentioned above, the hierarchy is created by querying BPF for each
//! domain's duty cycle, and using the infeasible.rs module in scx_utils to
//! determine load averages and load sums for each domain.
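//!
//! As a simplified sketch of that aggregation step (mirroring
//! calculate_load_avgs() below; `nr_cpus`, the domain ID, and the literal
//! weight/duty cycle values are hypothetical):
//!
//! ```
//! let mut aggregator = LoadAggregator::new(nr_cpus, false /* apply weights */);
//! aggregator.init_domain(0);
//! // One record per (bucket weight, duty cycle) pair read from BPF.
//! aggregator.record_dom_load(0, 100, 0.75)?;
//! let ledger = aggregator.calculate();
//! let (dom_loads, total) = (ledger.dom_load_sums().to_vec(), ledger.global_load_sum());
//! ```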
//!
//! 3. From the LB Root, we begin by iterating over all NUMA nodes, and
//!    migrating load from any nodes with an excess of load (push nodes) to
//!    nodes with a lack of load (pull nodes). The cost of migrations here is
//!    higher than when migrating load between domains within a node.
//!    Ultimately, migrations are performed by moving tasks between domains. The
//!    difference in this step is that imbalances are first addressed by moving
//!    tasks between NUMA nodes, and that such migrations only take place when
//!    imbalances are sufficiently high to warrant it.
//!
//! 4. Once load has been migrated between NUMA nodes, we iterate over each NUMA
//!    node and migrate load between the domains inside of each. The cost of
//!    migrations here is lower than between NUMA nodes. Like with load
//!    balancing between NUMA nodes, migrations here are just moving tasks
//!    between domains. The push/pull classification used in both steps is
//!    illustrated in the sketch below.
//!
//! The load hierarchy is always created when load_balance() is called on a
//! LoadBalancer object, but actual load balancing is only performed if the
//! balance_load option is specified.
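//!
//! As a minimal, illustrative sketch (hypothetical numbers), the push/pull
//! classification applied to both nodes and domains mirrors
//! LoadEntity::rebalance() below:
//!
//! ```
//! let (load_sum, load_avg) = (120.0f64, 100.0f64);
//! let imbal = load_sum - load_avg; // +20.0
//! // cost_ratio is 0.17 for NUMA nodes and 0.05 for domains; entities whose
//! // absolute imbalance stays within load_avg * cost_ratio stay balanced.
//! let needs_push = imbal > load_avg * 0.17;
//! assert!(needs_push); // this node is a push node
//! ```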
//!
//! Statistics
//! ----------
//!
//! After load balancing has occurred, statistics may be queried by invoking
//! the get_stats() function on the LoadBalancer object:
//!
//! ```
//! let lb = LoadBalancer::new(...)?;
//! lb.load_balance()?;
//!
//! let stats = lb.get_stats();
//! ...
//! ```
//!
//! Statistics are exported as a BTreeMap of NodeStats objects keyed by NUMA
//! node ID, each of which contains load balancing statistics for that NUMA
//! node, as well as statistics for any Domains contained therein as
//! DomainStats objects.
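//!
//! For example, a sketch of walking the returned map (per-node fields other
//! than doms are elided):
//!
//! ```
//! for (node_id, node_stats) in lb.get_stats() {
//!     for (dom_id, dom_stats) in &node_stats.doms {
//!         // Inspect per-domain load, imbalance, and delta here.
//!     }
//! }
//! ```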
//!
//! Future Improvements
//! -------------------
//!
//! There are a few ways that we could further improve the implementation here:
//!
//! - The logic for load balancing between NUMA nodes, and load balancing within
//!   a specific NUMA node (i.e. between domains in that NUMA node), could
//!   probably be improved to avoid code duplication using traits and/or
//!   generics.
//!
//! - When deciding whether to migrate a task, we're only looking at its impact
//!   on addressing load imbalances. In reality, this is a very complex,
//!   multivariate cost function. For example, a domain whose load is low
//!   enough that it should pull more load perhaps should not do so if it's
//!   running tasks that are much better suited to isolation. Or, a domain may
//!   not want to push a task to another domain if the task is co-located with
//!   other tasks that benefit from shared L3 cache locality.
//!
//!   Coming up with an extensible and clean way to model and implement this is
//!   likely itself a large project.
//!
//!
//! - We're not accounting for cgroups when performing load balancing.

use core::cmp::Ordering;
use std::cell::Cell;
use std::collections::BTreeMap;
use std::collections::VecDeque;
use std::fmt;
use std::sync::Arc;

use anyhow::anyhow;
use anyhow::bail;
use anyhow::Result;
use log::debug;
use log::trace;
use ordered_float::OrderedFloat;
use scx_utils::ravg::ravg_read;
use scx_utils::LoadAggregator;
use scx_utils::LoadLedger;
use sorted_vec::SortedVec;

use crate::bpf_intf;
use crate::bpf_skel::*;
use crate::stats::DomainStats;
use crate::stats::NodeStats;
use crate::DomainGroup;

const DEFAULT_WEIGHT: f64 = bpf_intf::consts_LB_DEFAULT_WEIGHT as f64;
const RAVG_FRAC_BITS: u32 = bpf_intf::ravg_consts_RAVG_FRAC_BITS;

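/// Returns the current CLOCK_MONOTONIC time in nanoseconds.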
fn now_monotonic() -> u64 {
    let time = nix::time::ClockId::CLOCK_MONOTONIC
        .now()
        .expect("Failed getting current monotonic time");
    time.tv_sec() as u64 * 1_000_000_000 + time.tv_nsec() as u64
}

#[derive(Clone, Copy, Debug, PartialEq)]
enum BalanceState {
    Balanced,
    NeedsPush,
    NeedsPull,
}

impl fmt::Display for BalanceState {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            BalanceState::Balanced => write!(f, "BALANCED"),
            BalanceState::NeedsPush => write!(f, "OVER-LOADED"),
            BalanceState::NeedsPull => write!(f, "UNDER-LOADED"),
        }
    }
}

// Derive PartialEq/Eq/PartialOrd/Ord for a type in terms of the load-based
// comparisons provided by `dyn LoadOrdered` below.
macro_rules! impl_ord_for_type {
    ($($t:ty),*) => {
        $(
            impl PartialEq for $t {
                fn eq(&self, other: &Self) -> bool {
                    <dyn LoadOrdered>::eq(self, other)
                }
            }

            impl Eq for $t {}

            impl PartialOrd for $t {
                fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
                    <dyn LoadOrdered>::partial_cmp(self, other)
                }
            }

            impl Ord for $t {
                fn cmp(&self, other: &Self) -> Ordering {
                    <dyn LoadOrdered>::cmp(self, other)
                }
            }
        )*
    };
}

trait LoadOrdered {
    fn get_load(&self) -> OrderedFloat<f64>;
}

impl dyn LoadOrdered {
    #[inline]
    fn eq(&self, other: &Self) -> bool {
        self.get_load().eq(&other.get_load())
    }

    #[inline]
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        self.get_load().partial_cmp(&other.get_load())
    }

    #[inline]
    fn cmp(&self, other: &Self) -> Ordering {
        self.get_load().cmp(&other.get_load())
    }
}

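/// Balance bookkeeping shared by Domains and NumaNodes. cost_ratio is the
/// fraction of load_avg that the absolute imbalance must exceed before the
/// entity is marked NeedsPush/NeedsPull; push_max_ratio caps how much of the
/// imbalance a single balancing round may push away; xfer_ratio scales the
/// transfer target between two imbalanced entities.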
#[derive(Debug, Clone)]
pub struct LoadEntity {
    cost_ratio: f64,
    push_max_ratio: f64,
    xfer_ratio: f64,
    load_sum: OrderedFloat<f64>,
    load_avg: f64,
    load_delta: f64,
    bal_state: BalanceState,
}

impl LoadEntity {
    fn new(
        cost_ratio: f64,
        push_max_ratio: f64,
        xfer_ratio: f64,
        load_sum: f64,
        load_avg: f64,
    ) -> Self {
        let mut entity = Self {
            cost_ratio,
            push_max_ratio,
            xfer_ratio,
            load_sum: OrderedFloat(load_sum),
            load_avg,
            load_delta: 0.0f64,
            bal_state: BalanceState::Balanced,
        };
        entity.add_load(0.0f64);
        entity
    }

    pub fn load_sum(&self) -> f64 {
        *self.load_sum
    }

    pub fn load_avg(&self) -> f64 {
        self.load_avg
    }

    pub fn imbal(&self) -> f64 {
        self.load_sum() - self.load_avg
    }

    pub fn delta(&self) -> f64 {
        self.load_delta
    }

    fn state(&self) -> BalanceState {
        self.bal_state
    }

    fn rebalance(&mut self, new_load: f64) {
        self.load_sum = OrderedFloat(new_load);

        let imbal = self.imbal();
        let needs_balance = imbal.abs() > self.load_avg * self.cost_ratio;

        self.bal_state = if needs_balance {
            if imbal > 0f64 {
                BalanceState::NeedsPush
            } else {
                BalanceState::NeedsPull
            }
        } else {
            BalanceState::Balanced
        };
    }

    fn add_load(&mut self, delta: f64) {
        self.rebalance(self.load_sum() + delta);
        self.load_delta += delta;
    }

    fn push_cutoff(&self) -> f64 {
        self.imbal().abs() * self.push_max_ratio
    }

    fn xfer_between(&self, other: &LoadEntity) -> f64 {
        self.imbal().abs().min(other.imbal().abs()) * self.xfer_ratio
    }
}

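/// Per-task balancing state: a pointer to the BPF-side task_ctx, the task's
/// weighted load, the domains it can run in (dom_mask) and prefers
/// (preferred_dom_mask), whether it has already been migrated in this round,
/// and whether it's a kworker (which the balancer can be told to skip).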
#[derive(Debug)]
struct TaskInfo {
    taskc_p: *mut types::task_ctx,
    load: OrderedFloat<f64>,
    dom_mask: u64,
    preferred_dom_mask: u64,
    migrated: Cell<bool>,
    is_kworker: bool,
}

impl LoadOrdered for TaskInfo {
    fn get_load(&self) -> OrderedFloat<f64> {
        self.load
    }
}
impl_ord_for_type!(TaskInfo);

#[derive(Debug)]
struct Domain {
    id: usize,
    queried_tasks: bool,
    load: LoadEntity,
    tasks: SortedVec<TaskInfo>,
}

impl Domain {
    const LOAD_IMBAL_HIGH_RATIO: f64 = 0.05;
    const LOAD_IMBAL_XFER_TARGET_RATIO: f64 = 0.50;
    const LOAD_IMBAL_PUSH_MAX_RATIO: f64 = 0.50;

    fn new(id: usize, load_sum: f64, load_avg: f64) -> Self {
        Self {
            id,
            queried_tasks: false,
            load: LoadEntity::new(
                Domain::LOAD_IMBAL_HIGH_RATIO,
                Domain::LOAD_IMBAL_PUSH_MAX_RATIO,
                Domain::LOAD_IMBAL_XFER_TARGET_RATIO,
                load_sum,
                load_avg,
            ),
            tasks: SortedVec::new(),
        }
    }

    fn transfer_load(&mut self, load: f64, taskc: &mut types::task_ctx, other: &mut Domain) {
        trace!("XFER pid={} dom={}->{}", taskc.pid, self.id, other.id);

        let dom_id: u32 = other.id.try_into().unwrap();
        taskc.target_dom = dom_id;

        self.load.add_load(-load);
        other.load.add_load(load);
    }

    fn xfer_between(&self, other: &Domain) -> f64 {
        self.load.xfer_between(&other.load)
    }
}

impl LoadOrdered for Domain {
    fn get_load(&self) -> OrderedFloat<f64> {
        self.load.load_sum
    }
}
impl_ord_for_type!(Domain);

#[derive(Debug)]
struct NumaNode {
    id: usize,
    load: LoadEntity,
    domains: SortedVec<Domain>,
}

impl NumaNode {
    const LOAD_IMBAL_HIGH_RATIO: f64 = 0.17;
    const LOAD_IMBAL_XFER_TARGET_RATIO: f64 = 0.50;
    const LOAD_IMBAL_PUSH_MAX_RATIO: f64 = 0.50;

    fn new(id: usize, numa_load_avg: f64) -> Self {
        Self {
            id,
            load: LoadEntity::new(
                NumaNode::LOAD_IMBAL_HIGH_RATIO,
                NumaNode::LOAD_IMBAL_PUSH_MAX_RATIO,
                NumaNode::LOAD_IMBAL_XFER_TARGET_RATIO,
                0.0f64,
                numa_load_avg,
            ),
            domains: SortedVec::new(),
        }
    }

    fn allocate_domain(&mut self, id: usize, load: f64, dom_load_avg: f64) {
        let domain = Domain::new(id, load, dom_load_avg);

        self.insert_domain(domain);
        self.load.rebalance(self.load.load_sum() + load);
    }

    fn xfer_between(&self, other: &NumaNode) -> f64 {
        self.load.xfer_between(&other.load)
    }

    fn insert_domain(&mut self, domain: Domain) {
        self.domains.insert(domain);
    }

    fn update_load(&mut self, delta: f64) {
        self.load.add_load(delta);
    }

    fn stats(&self) -> NodeStats {
        let mut stats = NodeStats::new(
            self.load.load_sum(),
            self.load.imbal(),
            self.load.delta(),
            BTreeMap::new(),
        );
        for dom in self.domains.iter() {
            stats.doms.insert(
                dom.id,
                DomainStats::new(dom.load.load_sum(), dom.load.imbal(), dom.load.delta()),
            );
        }
        stats
    }
}

impl LoadOrdered for NumaNode {
    fn get_load(&self) -> OrderedFloat<f64> {
        self.load.load_sum
    }
}
impl_ord_for_type!(NumaNode);

pub struct LoadBalancer<'a, 'b> {
    skel: &'a mut BpfSkel<'b>,
    dom_group: Arc<DomainGroup>,
    skip_kworkers: bool,

    infeas_threshold: f64,

    nodes: SortedVec<NumaNode>,

    lb_apply_weight: bool,
    balance_load: bool,
}

// Verify that the number of buckets is a factor of the maximum weight to
// ensure that the range of weight can be split evenly amongst every bucket.
const_assert_eq!(
    bpf_intf::consts_LB_MAX_WEIGHT % bpf_intf::consts_LB_LOAD_BUCKETS,
    0
);

impl<'a, 'b> LoadBalancer<'a, 'b> {
    pub fn new(
        skel: &'a mut BpfSkel<'b>,
        dom_group: Arc<DomainGroup>,
        skip_kworkers: bool,
        lb_apply_weight: bool,
        balance_load: bool,
    ) -> Self {
        Self {
            skel,
            skip_kworkers,

            infeas_threshold: bpf_intf::consts_LB_MAX_WEIGHT as f64,

            nodes: SortedVec::new(),

            lb_apply_weight,
            balance_load,

            dom_group,
        }
    }

    /// Perform load balancing calculations. When load balancing is enabled,
    /// also perform rebalances between NUMA nodes (when running on a
    /// multi-socket host) and domains.
    pub fn load_balance(&mut self) -> Result<()> {
        self.create_domain_hierarchy()?;

        if self.balance_load {
            self.perform_balancing()?
        }

        Ok(())
    }

    pub fn get_stats(&self) -> BTreeMap<usize, NodeStats> {
        self.nodes
            .iter()
            .map(|node| (node.id, node.stats()))
            .collect()
    }

    fn create_domain_hierarchy(&mut self) -> Result<()> {
        let ledger = self.calculate_load_avgs()?;

        let (dom_loads, total_load) = if !self.lb_apply_weight {
            (
                ledger
                    .dom_dcycle_sums()
                    .into_iter()
                    .map(|d| DEFAULT_WEIGHT * d)
                    .collect(),
                DEFAULT_WEIGHT * ledger.global_dcycle_sum(),
            )
        } else {
            self.infeas_threshold = ledger.effective_max_weight();
            (ledger.dom_load_sums().to_vec(), ledger.global_load_sum())
        };

        let num_numa_nodes = self.dom_group.nr_nodes();
        let numa_load_avg = total_load / num_numa_nodes as f64;

        let mut nodes: Vec<NumaNode> = (0..num_numa_nodes)
            .map(|id| NumaNode::new(id, numa_load_avg))
            .collect();

        let dom_load_avg = total_load / dom_loads.len() as f64;
        for (dom_id, load) in dom_loads.iter().enumerate() {
            let numa_id = self
                .dom_group
                .dom_numa_id(&dom_id)
                .ok_or_else(|| anyhow!("Failed to get NUMA ID for domain {}", dom_id))?;

            if numa_id >= num_numa_nodes {
                bail!("NUMA ID {} exceeds maximum {}", numa_id, num_numa_nodes);
            }

            let node = &mut nodes[numa_id];
            node.allocate_domain(dom_id, *load, dom_load_avg);
        }

        self.nodes = SortedVec::from_unsorted(nodes);

        Ok(())
    }

    fn calculate_load_avgs(&mut self) -> Result<LoadLedger> {
        const NUM_BUCKETS: u64 = bpf_intf::consts_LB_LOAD_BUCKETS as u64;
        let now_mono = now_monotonic();
        let load_half_life = self.skel.maps.rodata_data.as_ref().unwrap().load_half_life;

        let mut aggregator = LoadAggregator::new(self.dom_group.weight(), !self.lb_apply_weight);

        for (dom_id, dom) in self.dom_group.doms() {
            aggregator.init_domain(*dom_id);

            let dom_ctx = dom.ctx().unwrap();

            for bucket in 0..NUM_BUCKETS {
                let bucket_ctx = &dom_ctx.buckets[bucket as usize];
                let rd = &bucket_ctx.rd;
                let duty_cycle = ravg_read(
                    rd.val,
                    rd.val_at,
                    rd.old,
                    rd.cur,
                    now_mono,
                    load_half_life,
                    RAVG_FRAC_BITS,
                );

                if duty_cycle == 0.0f64 {
                    continue;
                }

                let weight = self.bucket_weight(bucket);
                aggregator.record_dom_load(*dom_id, weight, duty_cycle)?;
            }
        }

        Ok(aggregator.calculate())
    }

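    /// Map a weight bucket index to its (min, max) weight range. As a worked
    /// example (hypothetical sizes): with MAX_WEIGHT = 10000 and NUM_BUCKETS =
    /// 100, bucket 0 spans weights [1, 100], bucket 1 spans [101, 200], and
    /// bucket_weight() below reports their mid-points 51 and 151.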
    fn bucket_range(&self, bucket: u64) -> (f64, f64) {
        const MAX_WEIGHT: u64 = bpf_intf::consts_LB_MAX_WEIGHT as u64;
        const NUM_BUCKETS: u64 = bpf_intf::consts_LB_LOAD_BUCKETS as u64;
        const WEIGHT_PER_BUCKET: u64 = MAX_WEIGHT / NUM_BUCKETS;

        if bucket >= NUM_BUCKETS {
            panic!("Invalid bucket {}, max {}", bucket, NUM_BUCKETS);
        }

        // w_x = [1 + (10000 * x) / N, 10000 * (x + 1) / N]
        let min_w = 1 + (MAX_WEIGHT * bucket) / NUM_BUCKETS;
        let max_w = min_w + WEIGHT_PER_BUCKET - 1;

        (min_w as f64, max_w as f64)
    }

    fn bucket_weight(&self, bucket: u64) -> usize {
        const WEIGHT_PER_BUCKET: f64 = bpf_intf::consts_LB_WEIGHT_PER_BUCKET as f64;
        let (min_weight, _) = self.bucket_range(bucket);

        // Use the mid-point of the bucket when determining weight
        (min_weight + (WEIGHT_PER_BUCKET / 2.0f64)).ceil() as usize
    }

    /// @dom needs to push out tasks to balance loads. Make sure its
    /// tasks_by_load is populated so that the victim tasks can be picked.
    fn populate_tasks_by_load(&mut self, dom: &mut Domain) -> Result<()> {
        if dom.queried_tasks {
            return Ok(());
        }
        dom.queried_tasks = true;

        // Read active_tasks and update read_idx and gen.
        const MAX_TPTRS: u64 = bpf_intf::consts_MAX_DOM_ACTIVE_TPTRS as u64;

        let types::topo_level(index) = types::topo_level::TOPO_LLC;
        let ptr = self.skel.maps.bss_data.as_ref().unwrap().topo_nodes[index as usize][dom.id];
        let dom_ctx = unsafe { std::mem::transmute::<u64, &mut types::dom_ctx>(ptr) };
        let active_tasks = &mut dom_ctx.active_tasks;

        let (mut ridx, widx) = (active_tasks.read_idx, active_tasks.write_idx);
        active_tasks.read_idx = active_tasks.write_idx;
        active_tasks.gen += 1;

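        // If the BPF side has lapped the reader in its fixed-size ring, only
        // the most recent MAX_TPTRS task pointers are still valid, so
        // fast-forward the read index accordingly.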
        if widx - ridx > MAX_TPTRS {
            ridx = widx - MAX_TPTRS;
        }

        // Read task_ctx and load.
        let load_half_life = self.skel.maps.rodata_data.as_ref().unwrap().load_half_life;
        let now_mono = now_monotonic();

        for idx in ridx..widx {
            let taskc_p = active_tasks.tasks[(idx % MAX_TPTRS) as usize];
            let taskc = unsafe { &mut *taskc_p };

            if taskc.target_dom as usize != dom.id {
                continue;
            }

            let rd = &taskc.dcyc_rd;
            let mut load = ravg_read(
                rd.val,
                rd.val_at,
                rd.old,
                rd.cur,
                now_mono,
                load_half_life,
                RAVG_FRAC_BITS,
            );

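            // Scale duty cycle by task weight, clamping at the effective
            // maximum weight so infeasible weights don't skew the load
            // distribution.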
            let weight = if self.lb_apply_weight {
                (taskc.weight as f64).min(self.infeas_threshold)
            } else {
                DEFAULT_WEIGHT
            };
            load *= weight;

            dom.tasks.insert(TaskInfo {
                taskc_p,
                load: OrderedFloat(load),
                dom_mask: taskc.dom_mask,
                preferred_dom_mask: taskc.preferred_dom_mask,
                migrated: Cell::new(false),
                is_kworker: unsafe { taskc.is_kworker.assume_init() },
            });
        }

        Ok(())
    }

    // Find the first candidate task which hasn't already been migrated and
    // can run in @pull_dom. The iterator passed in by the caller has already
    // been filtered accordingly.
    fn find_first_candidate<'d, I>(tasks_by_load: I) -> Option<&'d TaskInfo>
    where
        I: IntoIterator<Item = &'d TaskInfo>,
    {
        tasks_by_load.into_iter().next()
    }

    /// Try to find a task in @push_dom to be moved into @pull_dom. If a task is
    /// found, move the task between the domains, and return the amount of load
    /// transferred between the two.
    fn try_find_move_task(
        &mut self,
        (push_dom, to_push): (&mut Domain, f64),
        (pull_dom, to_pull): (&mut Domain, f64),
        task_filter: impl Fn(&TaskInfo, u32) -> bool,
        to_xfer: f64,
    ) -> Result<Option<f64>> {
        let to_pull = to_pull.abs();
        let calc_new_imbal = |xfer: f64| (to_push - xfer).abs() + (to_pull - xfer).abs();

        self.populate_tasks_by_load(push_dom)?;

        // We want to pick a task to transfer from push_dom to pull_dom whose
        // load comes closest to $to_xfer, since that best reduces the load
        // imbalance between the two. Find such a task by locating the first
        // migratable task while scanning left from $to_xfer and the
        // counterpart while scanning right, and picking the better of the
        // two.
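        //
        // For example (hypothetical loads): with $to_xfer = 5.0 and
        // migratable task loads {2.0, 4.0, 7.0}, the left scan finds 4.0 and
        // the right scan finds 7.0; whichever yields the smaller
        // post-transfer imbalance wins.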
        let pull_dom_id: u32 = pull_dom.id.try_into().unwrap();
        let tasks: Vec<TaskInfo> = std::mem::take(&mut push_dom.tasks)
            .into_vec()
            .into_iter()
            .filter(|task| {
                task.dom_mask & (1 << pull_dom_id) != 0
                    && !(self.skip_kworkers && task.is_kworker)
                    && !task.migrated.get()
            })
            .collect();

        let (task, new_imbal) = match (
            Self::find_first_candidate(
                tasks
                    .as_slice()
                    .iter()
                    .filter(|x| x.load <= OrderedFloat(to_xfer) && task_filter(x, pull_dom_id))
                    .rev(),
            ),
            Self::find_first_candidate(
                tasks
                    .as_slice()
                    .iter()
                    .filter(|x| x.load >= OrderedFloat(to_xfer) && task_filter(x, pull_dom_id)),
            ),
        ) {
            (None, None) => {
                std::mem::swap(&mut push_dom.tasks, &mut SortedVec::from_unsorted(tasks));
                return Ok(None);
            }
            (Some(task), None) | (None, Some(task)) => (task, calc_new_imbal(*task.load)),
            (Some(task0), Some(task1)) => {
                let (new_imbal0, new_imbal1) =
                    (calc_new_imbal(*task0.load), calc_new_imbal(*task1.load));
                if new_imbal0 <= new_imbal1 {
                    (task0, new_imbal0)
                } else {
                    (task1, new_imbal1)
                }
            }
        };

        // If the best candidate can't reduce the imbalance, there's nothing
        // to do for this pair.
        let old_imbal = to_push + to_pull;
        if old_imbal < new_imbal {
            std::mem::swap(&mut push_dom.tasks, &mut SortedVec::from_unsorted(tasks));
            return Ok(None);
        }

        let load = *(task.load);
        let taskc_p = task.taskc_p;
        task.migrated.set(true);
        std::mem::swap(&mut push_dom.tasks, &mut SortedVec::from_unsorted(tasks));

        push_dom.transfer_load(load, unsafe { &mut *taskc_p }, pull_dom);
        Ok(Some(load))
    }

    fn transfer_between_nodes(
        &mut self,
        push_node: &mut NumaNode,
        pull_node: &mut NumaNode,
    ) -> Result<f64> {
        debug!("Inter node {} -> {} started", push_node.id, pull_node.id);

        let push_imbal = push_node.load.imbal();
        let pull_imbal = pull_node.load.imbal();
        let xfer = push_node.xfer_between(pull_node);

        if push_imbal <= 0.0f64 || pull_imbal >= 0.0f64 {
            bail!(
                "push node {}:{}, pull node {}:{}",
                push_node.id,
                push_imbal,
                pull_node.id,
                pull_imbal
            );
        }
        let mut pushers = VecDeque::with_capacity(push_node.domains.len());
        let mut pullers = Vec::with_capacity(pull_node.domains.len());
        let mut pushed = 0f64;

        while push_node.domains.len() > 0 {
            // Push from the busiest domain
            let mut push_dom = push_node.domains.pop().unwrap();
            if push_dom.load.state() != BalanceState::NeedsPush {
                push_node.domains.insert(push_dom);
                break;
            }

            while pull_node.domains.len() > 0 {
                let mut pull_dom = pull_node.domains.remove_index(0);
                if pull_dom.load.state() != BalanceState::NeedsPull {
                    pull_node.domains.insert(pull_dom);
                    break;
                }
                // Try a task that prefers the pull domain first; failing
                // that, settle for any migratable task.
                let mut transferred = self.try_find_move_task(
                    (&mut push_dom, push_imbal),
                    (&mut pull_dom, pull_imbal),
                    |task: &TaskInfo, pull_dom: u32| -> bool {
                        (task.preferred_dom_mask & (1 << pull_dom)) > 0
                    },
                    xfer,
                )?;
                if transferred.is_none() {
                    transferred = self.try_find_move_task(
                        (&mut push_dom, push_imbal),
                        (&mut pull_dom, pull_imbal),
                        |_task: &TaskInfo, _pull_dom: u32| -> bool { true },
                        xfer,
                    )?;
                }

                pullers.push(pull_dom);
                if let Some(transferred) = transferred {
                    pushed = transferred;
                    push_node.update_load(-transferred);
                    pull_node.update_load(transferred);
                    break;
                }
            }
            while let Some(puller) = pullers.pop() {
                pull_node.domains.insert(puller);
            }
            pushers.push_back(push_dom);
            if pushed > 0.0f64 {
                break;
            }
        }
        while let Some(pusher) = pushers.pop_front() {
            push_node.domains.insert(pusher);
        }

        Ok(pushed)
    }

    fn balance_between_nodes(&mut self) -> Result<()> {
        if self.nodes.len() < 2 {
            return Ok(());
        }

        debug!("Node <-> Node LB started");

        // Keep track of the nodes we're pushing load from, and pulling load to,
        // respectively. We use separate vectors like this to allow us to
        // mutably iterate over the same list, and pull nodes from the front and
        // back in a nested fashion. The load algorithm looks roughly like this:
        //
        // In sorted order from most -> least loaded:
        //
        // For each "push node" (i.e. node with a positive load imbalance):
        // restart_push:
        //      For each "pull node" (i.e. node with a negative load imbalance):
        //              For each "push domain" (i.e. each domain in "push node"
        //              with a positive load imbalance):
        //                      For each "pull domain" (i.e. each domain in
        //                      "pull node" with a negative load imbalance):
        //                              load = try_move_load(push_dom -> pull_dom)
        //                              if load > 0
        //                                      goto restart_push
        //
        // There are four levels of nesting here, but in practice these are very
        // shallow loops, as a system doesn't usually have many nodes or domains
        // per node, only a subset of them will be imbalanced, and the
        // imbalanced nodes and domains will only ever be in push imbalance, or
        // pull imbalance at any given time.
        //
        // Because we're iterating mutably over these lists, we pop nodes and
        // domains off of their lists, and then re-insert them after we're done
        // doing migrations. The lists below are how we keep track of
        // already-visited nodes while we're still iterating over the lists.
        // Note that we immediately go back to iterating over every pull node
        // any time we successfully transfer load, so that we ensure that we're
        // always sending load to the least-loaded node.
        //
        // Note that we use a VecDeque for the pushers because we're iterating
        // over self.nodes in descending-load order. Thus, when we're done
        // iterating and we're adding the popped nodes back into self.nodes, we
        // want to add them back in _ascending_ order so that we don't have to
        // unnecessarily shift any already-re-added nodes to the right in the
        // backing vector. In other words, this lets us do a true append in the
        // SortedVec, rather than doing an insert(list.len() - 2, node). This
        // applies both to iterating over push-imbalanced nodes, and iterating
        // over push-imbalanced domains in the inner loops.
        let mut pushers = VecDeque::with_capacity(self.nodes.len());
        let mut pullers = Vec::with_capacity(self.nodes.len());

        while self.nodes.len() >= 2 {
            // Push from the busiest node
            let mut push_node = self.nodes.pop().unwrap();
            if push_node.load.state() != BalanceState::NeedsPush {
                self.nodes.insert(push_node);
                break;
            }

            let push_cutoff = push_node.load.push_cutoff();
            let mut pushed = 0f64;
            while self.nodes.len() > 0 && pushed < push_cutoff {
                // To the least busy node
                let mut pull_node = self.nodes.remove_index(0);
                let pull_id = pull_node.id;
                if pull_node.load.state() != BalanceState::NeedsPull {
                    self.nodes.insert(pull_node);
                    break;
                }
                let migrated = self.transfer_between_nodes(&mut push_node, &mut pull_node)?;
                pullers.push(pull_node);
                if migrated > 0.0f64 {
                    // Break after a successful migration so that we can
                    // rebalance the pulling domains before the next
                    // transfer attempt, and ensure that we're trying to
                    // pull from domains in descending-imbalance order.
                    pushed += migrated;
                    debug!(
                        "NODE {} sending {:.06} --> NODE {}",
                        push_node.id, migrated, pull_id
                    );
                }
            }
            while let Some(puller) = pullers.pop() {
                self.nodes.insert(puller);
            }

            if pushed > 0.0f64 {
                debug!("NODE {} pushed {:.06} total load", push_node.id, pushed);
            }
            pushers.push_back(push_node);
        }

        while !pushers.is_empty() {
            self.nodes.insert(pushers.pop_front().unwrap());
        }

        Ok(())
    }

    fn balance_within_node(&mut self, node: &mut NumaNode) -> Result<()> {
        if node.domains.len() < 2 {
            return Ok(());
        }

        debug!("Intra node {} LB started", node.id);

        // See the comment in balance_between_nodes() for the purpose of these
        // lists. Everything is roughly the same here as in that comment block,
        // with the notable exception that we're only iterating over domains
        // inside of a single node.
        let mut pushers = VecDeque::with_capacity(node.domains.len());
        let mut pullers = Vec::new();

        while node.domains.len() >= 2 {
            let mut push_dom = node.domains.pop().unwrap();
            if node.domains.len() == 0 || push_dom.load.state() != BalanceState::NeedsPush {
                node.domains.insert(push_dom);
                break;
            }

            let mut pushed = 0.0f64;
            let push_cutoff = push_dom.load.push_cutoff();
            let push_imbal = push_dom.load.imbal();
            if push_imbal < 0.0f64 {
                bail!(
                    "Node {} push dom {} had imbal {}",
                    node.id,
                    push_dom.id,
                    push_imbal
                );
            }

            while node.domains.len() > 0 && pushed < push_cutoff {
                let mut pull_dom = node.domains.remove_index(0);
                if pull_dom.load.state() != BalanceState::NeedsPull {
                    node.domains.insert(pull_dom);
                    break;
                }
                let pull_imbal = pull_dom.load.imbal();
                if pull_imbal >= 0.0f64 {
                    bail!(
                        "Node {} pull dom {} had imbal {}",
                        node.id,
                        pull_dom.id,
                        pull_imbal
                    );
                }
                let xfer = push_dom.xfer_between(&pull_dom);
                let mut transferred = self.try_find_move_task(
                    (&mut push_dom, push_imbal),
                    (&mut pull_dom, pull_imbal),
                    |task: &TaskInfo, pull_dom: u32| -> bool {
                        (task.preferred_dom_mask & (1 << pull_dom)) > 0
                    },
                    xfer,
                )?;
                if transferred.is_none() {
                    transferred = self.try_find_move_task(
                        (&mut push_dom, push_imbal),
                        (&mut pull_dom, pull_imbal),
                        |_task: &TaskInfo, _pull_dom: u32| -> bool { true },
                        xfer,
                    )?;
                }

                if let Some(transferred) = transferred {
                    if transferred <= 0.0f64 {
                        bail!("Expected nonzero load transfer")
                    }
                    pushed += transferred;
                    // We've pushed load to pull_dom, and have already updated
                    // its load (in try_find_move_task()). Re-insert it into the
                    // sorted list (thus ensuring we're still iterating from
                    // least load -> most load in the loop above), and try to
                    // push more load.
                    node.domains.insert(pull_dom);
                    continue;
                }

                // Couldn't push any load to this domain, try the next one.
                pullers.push(pull_dom);
            }
            while let Some(puller) = pullers.pop() {
                node.domains.insert(puller);
            }

            if pushed > 0.0f64 {
                debug!("DOM {} pushed {:.06} total load", push_dom.id, pushed);
            }
            pushers.push_back(push_dom);
        }
        while let Some(pusher) = pushers.pop_front() {
            node.domains.insert(pusher);
        }

        Ok(())
    }

    fn perform_balancing(&mut self) -> Result<()> {
        // First balance load between the NUMA nodes. Balancing here has a
        // higher cost function than balancing between domains inside of NUMA
        // nodes, but the mechanics are the same. Adjustments made here are
        // reflected in intra-node balancing decisions made next.
        if self.dom_group.nr_nodes() > 1 {
            self.balance_between_nodes()?;
        }

        // Now that the NUMA nodes have been balanced, do another balance round
        // amongst the domains in each node.

        debug!("Intra node LBs started");

        // Assume all nodes are now balanced. Take them out of the sorted list
        // so each can be mutably balanced internally, then rebuild the list.
        let mut nodes = std::mem::take(&mut self.nodes).into_vec();
        for node in nodes.iter_mut() {
            self.balance_within_node(node)?;
        }
        std::mem::swap(&mut self.nodes, &mut SortedVec::from_unsorted(nodes));

        Ok(())
    }
}