1use crate::StatsErrno;
2use crate::StatsRequest;
3use crate::StatsResponse;
4use anyhow::anyhow;
5use anyhow::bail;
6use anyhow::Result;
7use log::trace;
8use serde::Deserialize;
9use std::io::BufRead;
10use std::io::BufReader;
11use std::io::Write;
12use std::io::{self};
13use std::os::unix::net::UnixStream;
14use std::path::Path;
15use std::path::PathBuf;
16use std::time::Duration;
17
/// Client for a stats server listening on a Unix domain socket.
///
/// Build it with [`StatsClient::new`] and the `set_*` builders, then call
/// [`StatsClient::connect`]. Requests and responses are exchanged as
/// newline-terminated JSON lines.
pub struct StatsClient {
    /// Directory prefix of the socket path (`/var/run/scx` by default).
    base_path: PathBuf,
    /// Path component joined onto `base_path` (`root` by default).
    sched_path: PathBuf,
    /// Path component joined after `sched_path` (`stats` by default).
    stats_path: PathBuf,
    /// Full socket path. When set explicitly it overrides the composed
    /// `base_path/sched_path/stats_path`; `connect` fills it in otherwise.
    path: Option<PathBuf>,

    /// Socket handle used for writing requests; `None` when not connected.
    stream: Option<UnixStream>,
    /// Buffered reader over a clone of the same socket, used to read
    /// response lines; `None` when not connected.
    reader: Option<BufReader<UnixStream>>,
}
27
28impl StatsClient {
29 pub fn new() -> Self {
30 Self {
31 base_path: PathBuf::from("/var/run/scx"),
32 sched_path: PathBuf::from("root"),
33 stats_path: PathBuf::from("stats"),
34 path: None,
35
36 stream: None,
37 reader: None,
38 }
39 }
40
41 pub fn set_base_path<P: AsRef<Path>>(mut self, path: P) -> Self {
42 self.base_path = PathBuf::from(path.as_ref());
43 self
44 }
45
46 pub fn set_sched_path<P: AsRef<Path>>(mut self, path: P) -> Self {
47 self.sched_path = PathBuf::from(path.as_ref());
48 self
49 }
50
51 pub fn set_stats_path<P: AsRef<Path>>(mut self, path: P) -> Self {
52 self.stats_path = PathBuf::from(path.as_ref());
53 self
54 }
55
56 pub fn set_path<P: AsRef<Path>>(mut self, path: P) -> Self {
57 self.path = Some(PathBuf::from(path.as_ref()));
58 self
59 }
60
61 pub fn connect(mut self, timeout_ms: Option<u64>) -> Result<Self> {
62 if self.path.is_none() {
63 self.path = Some(self.base_path.join(&self.sched_path).join(&self.stats_path));
64 }
65 let path = self.path.as_ref().unwrap();
66
67 let stream = UnixStream::connect(path)?;
68 if let Some(ms) = timeout_ms {
70 let dur = Duration::from_millis(ms);
71 stream.set_write_timeout(Some(dur))?;
72 stream.set_read_timeout(Some(dur))?;
73 }
74
75 self.stream = Some(stream.try_clone()?);
76 self.reader = Some(BufReader::new(stream));
77 Ok(self)
78 }
79
80 pub fn send_request<T>(&mut self, req: &StatsRequest) -> Result<T>
81 where
82 T: for<'a> Deserialize<'a>,
83 {
84 if self.stream.is_none() {
85 bail!("not connected");
86 }
87
88 let req = serde_json::to_string(&req)? + "\n";
89 trace!("Sending: {}", req.trim());
90 if let Err(e) = self.stream.as_ref().unwrap().write_all(req.as_bytes()) {
92 if e.kind() == io::ErrorKind::TimedOut || e.kind() == io::ErrorKind::WouldBlock {
93 return Err(anyhow!("write timed out"));
94 } else {
95 return Err(e.into());
96 }
97 }
98
99 let mut line = String::new();
100 match self.reader.as_mut().unwrap().read_line(&mut line) {
101 Ok(0) => return Err(anyhow!("connection closed")),
102 Ok(_) => { }
103 Err(e) => {
104 if e.kind() == io::ErrorKind::TimedOut || e.kind() == io::ErrorKind::WouldBlock {
105 return Err(anyhow!("read timed out"));
106 } else {
107 return Err(e.into());
108 }
109 }
110 }
111
112 trace!("Received: {}", line.trim());
113 let mut resp: StatsResponse = serde_json::from_str(&line)?;
114
115 let (errno, resp) = (
116 resp.errno,
117 resp.args.remove("resp").unwrap_or(serde_json::Value::Null),
118 );
119
120 if errno != 0 {
121 Err(anyhow!("{}", &resp).context(StatsErrno(errno)))?;
122 }
123
124 Ok(serde_json::from_value(resp)?)
125 }
126
127 pub fn request<T>(&mut self, req: &str, args: Vec<(String, String)>) -> Result<T>
128 where
129 T: for<'a> Deserialize<'a>,
130 {
131 self.send_request(&StatsRequest::new(req, args))
132 }
133}