Skip to main content

scx_stats/
client.rs

1use crate::StatsErrno;
2use crate::StatsRequest;
3use crate::StatsResponse;
4use anyhow::anyhow;
5use anyhow::bail;
6use anyhow::Result;
7use log::trace;
8use serde::Deserialize;
9use std::io::BufRead;
10use std::io::BufReader;
11use std::io::Write;
12use std::io::{self};
13use std::os::unix::net::UnixStream;
14use std::path::Path;
15use std::path::PathBuf;
16use std::time::Duration;
17
/// Client that exchanges JSON-encoded stats requests and responses with a
/// server over a Unix domain socket.
///
/// Configure it with `StatsClient::new()` and the `set_*` builder methods,
/// then open the socket with `connect`.
#[derive(Debug)]
pub struct StatsClient {
    /// Leading component of the default socket path (`/var/run/scx` by default).
    base_path: PathBuf,
    /// Middle component of the default socket path (`root` by default).
    sched_path: PathBuf,
    /// Final component of the default socket path (`stats` by default).
    stats_path: PathBuf,
    /// Explicit socket path; when `None`, the path is built as
    /// `base_path/sched_path/stats_path`.
    path: Option<PathBuf>,

    /// Handle used to write requests; `None` when not connected.
    stream: Option<UnixStream>,
    /// Buffered handle used to read responses; `None` when not connected.
    reader: Option<BufReader<UnixStream>>,
}
27
28impl StatsClient {
29    pub fn new() -> Self {
30        Self {
31            base_path: PathBuf::from("/var/run/scx"),
32            sched_path: PathBuf::from("root"),
33            stats_path: PathBuf::from("stats"),
34            path: None,
35
36            stream: None,
37            reader: None,
38        }
39    }
40
41    pub fn set_base_path<P: AsRef<Path>>(mut self, path: P) -> Self {
42        self.base_path = PathBuf::from(path.as_ref());
43        self
44    }
45
46    pub fn set_sched_path<P: AsRef<Path>>(mut self, path: P) -> Self {
47        self.sched_path = PathBuf::from(path.as_ref());
48        self
49    }
50
51    pub fn set_stats_path<P: AsRef<Path>>(mut self, path: P) -> Self {
52        self.stats_path = PathBuf::from(path.as_ref());
53        self
54    }
55
56    pub fn set_path<P: AsRef<Path>>(mut self, path: P) -> Self {
57        self.path = Some(PathBuf::from(path.as_ref()));
58        self
59    }
60
61    pub fn connect(mut self, timeout_ms: Option<u64>) -> Result<Self> {
62        if self.path.is_none() {
63            self.path = Some(self.base_path.join(&self.sched_path).join(&self.stats_path));
64        }
65        let path = self.path.as_ref().unwrap();
66
67        let stream = UnixStream::connect(path)?;
68        // Apply the same timeout to both writer and reader sides if provided
69        if let Some(ms) = timeout_ms {
70            let dur = Duration::from_millis(ms);
71            stream.set_write_timeout(Some(dur))?;
72            stream.set_read_timeout(Some(dur))?;
73        }
74
75        self.stream = Some(stream.try_clone()?);
76        self.reader = Some(BufReader::new(stream));
77        Ok(self)
78    }
79
80    pub fn send_request<T>(&mut self, req: &StatsRequest) -> Result<T>
81    where
82        T: for<'a> Deserialize<'a>,
83    {
84        if self.stream.is_none() {
85            bail!("not connected");
86        }
87
88        let req = serde_json::to_string(&req)? + "\n";
89        trace!("Sending: {}", req.trim());
90        // Attempt write with timeout
91        if let Err(e) = self.stream.as_ref().unwrap().write_all(req.as_bytes()) {
92            if e.kind() == io::ErrorKind::TimedOut || e.kind() == io::ErrorKind::WouldBlock {
93                return Err(anyhow!("write timed out"));
94            } else {
95                return Err(e.into());
96            }
97        }
98
99        let mut line = String::new();
100        match self.reader.as_mut().unwrap().read_line(&mut line) {
101            Ok(0) => return Err(anyhow!("connection closed")),
102            Ok(_) => { /* proceed */ }
103            Err(e) => {
104                if e.kind() == io::ErrorKind::TimedOut || e.kind() == io::ErrorKind::WouldBlock {
105                    return Err(anyhow!("read timed out"));
106                } else {
107                    return Err(e.into());
108                }
109            }
110        }
111
112        trace!("Received: {}", line.trim());
113        let mut resp: StatsResponse = serde_json::from_str(&line)?;
114
115        let (errno, resp) = (
116            resp.errno,
117            resp.args.remove("resp").unwrap_or(serde_json::Value::Null),
118        );
119
120        if errno != 0 {
121            Err(anyhow!("{}", &resp).context(StatsErrno(errno)))?;
122        }
123
124        Ok(serde_json::from_value(resp)?)
125    }
126
127    pub fn request<T>(&mut self, req: &str, args: Vec<(String, String)>) -> Result<T>
128    where
129        T: for<'a> Deserialize<'a>,
130    {
131        self.send_request(&StatsRequest::new(req, args))
132    }
133}