// xtask/bump_versions.rs

// Copyright (c) Meta Platforms, Inc. and affiliates.
//
// This software may be used and distributed according to the terms of the
// GNU General Public License version 2.

6use anyhow::Result;
7use regex::Regex;
8use std::collections::{HashMap, HashSet};
9use std::fs;
10use std::path::PathBuf;
11
12use crate::get_cargo_metadata;
13use crate::get_rust_paths;
14
15pub fn bump_versions_command(
16    packages: Vec<String>,
17    all: bool,
18    min_version: Option<String>,
19) -> Result<()> {
20    // Determine target crates
21    let target_crates = if all {
22        get_all_workspace_crates()?
23    } else {
24        packages
25    };
26
27    if target_crates.is_empty() {
28        log::info!("No crates to bump.");
29        return Ok(());
30    }
31
32    log::info!("Analyzing workspace dependencies...");
33
34    // Get cargo metadata
35    let metadata = get_cargo_metadata()?;
36
37    // Build map of workspace crates
38    let workspace_member_ids: HashSet<String> = metadata
39        .workspace_members
40        .iter()
41        .map(|id| id.to_string())
42        .collect();
43
44    let mut workspace_members = HashSet::new();
45    let mut crate_paths = HashMap::new();
46
47    for pkg in &metadata.packages {
48        if workspace_member_ids.contains(&pkg.id.to_string()) {
49            workspace_members.insert(pkg.name.to_string());
50            crate_paths.insert(
51                pkg.name.to_string(),
52                pkg.manifest_path.as_std_path().to_path_buf(),
53            );
54        }
55    }
56
57    // Validate target crates exist
58    for crate_name in &target_crates {
59        if !workspace_members.contains(crate_name) {
60            return Err(anyhow::anyhow!(
61                "Crate '{}' not found in workspace",
62                crate_name
63            ));
64        }
65    }
66
67    // Find all crates that need to be bumped
68    let mut crates_to_bump = HashSet::new();
69    let mut version_updates = HashMap::new();
70
71    // Start with target crates
72    for target in &target_crates {
73        crates_to_bump.insert(target.clone());
74    }
75
76    // Find dependencies of target crates (what the target crates depend on)
77    for target_crate in &target_crates {
78        // Find the target crate's package in metadata
79        for pkg in &metadata.packages {
80            let pkg_name = pkg.name.as_str();
81
82            if pkg_name == target_crate && workspace_members.contains(pkg_name) {
83                // Add all workspace dependencies of this target crate (exclude dev dependencies)
84                for dep in &pkg.dependencies {
85                    let dep_name = dep.name.as_str();
86                    let is_workspace_dep = dep.source.is_none(); // workspace dependency has null source
87                                                                 // Only include regular dependencies and build dependencies
88                                                                 // Exclude dev dependencies
89                    if is_workspace_dep
90                        && workspace_members.contains(dep_name)
91                        && !matches!(dep.kind, cargo_metadata::DependencyKind::Development)
92                    {
93                        crates_to_bump.insert(dep_name.to_string());
94                    }
95                }
96                break;
97            }
98        }
99    }
100
101    let sorted_crates: Vec<String> = crates_to_bump.iter().cloned().collect();
102    log::info!("Bumping versions for: {}", sorted_crates.join(", "));
103
104    // Show dependencies being bumped
105    let target_set: HashSet<String> = target_crates.iter().cloned().collect();
106    let deps: Vec<String> = crates_to_bump.difference(&target_set).cloned().collect();
107    if !deps.is_empty() {
108        log::info!("Found dependencies: {}", deps.join(", "));
109    }
110
111    // Bump all versions
112    for crate_name in &crates_to_bump {
113        if let Some(crate_path) = crate_paths.get(crate_name) {
114            let (old_version, new_version) =
115                bump_crate_version(crate_path, min_version.as_deref())?;
116            version_updates.insert(crate_name.clone(), new_version.clone());
117            log::info!("Bumping {crate_name}: {old_version} → {new_version}");
118        }
119    }
120
121    // Update dependency references in all affected files
122    update_dependent_versions(&version_updates)?;
123
124    log::info!("\nUpdated {} crates successfully.", crates_to_bump.len());
125    Ok(())
126}
127
128pub fn get_all_workspace_crates() -> Result<Vec<String>> {
129    let metadata = get_cargo_metadata()?;
130    let mut crates = Vec::new();
131
132    let workspace_member_ids: HashSet<String> = metadata
133        .workspace_members
134        .iter()
135        .map(|id| id.to_string())
136        .collect();
137
138    for pkg in &metadata.packages {
139        if workspace_member_ids.contains(&pkg.id.to_string()) {
140            crates.push(pkg.name.to_string());
141        }
142    }
143
144    Ok(crates)
145}
146
147pub fn bump_crate_version(
148    crate_path: &PathBuf,
149    min_version: Option<&str>,
150) -> Result<(String, String)> {
151    let content = fs::read_to_string(crate_path)?;
152    let lines: Vec<&str> = content.lines().collect();
153
154    let version_re = Regex::new(r#"(^\s*version\s*=\s*")([^"]*)(".*$)"#)?;
155
156    for (line_no, line) in lines.iter().enumerate() {
157        if let Some(captures) = version_re.captures(line) {
158            let current_version = captures.get(2).unwrap().as_str();
159            let new_version = if let Some(min) = min_version {
160                if version_tuple(current_version)? < version_tuple(min)? {
161                    min.to_string()
162                } else {
163                    increment_patch_version(current_version)?
164                }
165            } else {
166                increment_patch_version(current_version)?
167            };
168
169            // Update the file
170            let mut new_lines: Vec<String> = lines.iter().map(|s| s.to_string()).collect();
171            new_lines[line_no] = format!(
172                "{}{}{}",
173                captures.get(1).unwrap().as_str(),
174                new_version,
175                captures.get(3).unwrap().as_str()
176            );
177
178            let new_content = new_lines.join("\n") + "\n";
179            fs::write(crate_path, new_content)?;
180
181            return Ok((current_version.to_string(), new_version));
182        }
183    }
184
185    Err(anyhow::anyhow!(
186        "Could not find version in {:?}",
187        crate_path
188    ))
189}
190
191fn version_tuple(version: &str) -> Result<(u32, u32, u32)> {
192    let parts: Vec<&str> = version.split('.').collect();
193    if parts.len() < 3 {
194        return Err(anyhow::anyhow!("Invalid version format: {}", version));
195    }
196    Ok((parts[0].parse()?, parts[1].parse()?, parts[2].parse()?))
197}
198
199fn increment_patch_version(version: &str) -> Result<String> {
200    let parts: Vec<&str> = version.split('.').collect();
201    if parts.len() >= 3 {
202        let major = parts[0];
203        let minor = parts[1];
204        let patch: u32 = parts[2]
205            .parse()
206            .map_err(|_| anyhow::anyhow!("Invalid patch version: {}", parts[2]))?;
207        let new_patch = patch + 1;
208
209        // Handle any additional parts (like pre-release identifiers)
210        if parts.len() > 3 {
211            let extra: Vec<&str> = parts[3..].to_vec();
212            Ok(format!(
213                "{}.{}.{}.{}",
214                major,
215                minor,
216                new_patch,
217                extra.join(".")
218            ))
219        } else {
220            Ok(format!("{major}.{minor}.{new_patch}"))
221        }
222    } else {
223        Err(anyhow::anyhow!("Invalid version format: {}", version))
224    }
225}
226
227pub fn update_dependent_versions(updates: &HashMap<String, String>) -> Result<()> {
228    let rust_paths = get_rust_paths()?;
229    let section_re = Regex::new(r"^\s*\[([^\[\]]*)\]\s*$")?;
230
231    for path in rust_paths {
232        let content = fs::read_to_string(&path)?;
233        let lines: Vec<&str> = content.lines().collect();
234        let mut new_lines: Vec<String> = lines.iter().map(|s| s.to_string()).collect();
235        let mut modified = false;
236
237        let mut in_dep_section = false;
238        let mut block_depth = 0;
239
240        for (line_no, line) in lines.iter().enumerate() {
241            // Check for dependency sections
242            if let Some(captures) = section_re.captures(line) {
243                if block_depth != 0 {
244                    continue;
245                }
246                let section = captures.get(1).unwrap().as_str().trim();
247                // Include all dependency sections
248                in_dep_section = section == "dependencies"
249                    || section == "build-dependencies"
250                    || section == "dev-dependencies";
251                continue;
252            }
253
254            if !in_dep_section {
255                continue;
256            }
257
258            // Track nesting depth
259            block_depth += line.matches('{').count() as i32 - line.matches('}').count() as i32;
260            block_depth += line.matches('[').count() as i32 - line.matches(']').count() as i32;
261
262            if block_depth == 0 {
263                // Look for workspace dependencies that need version updates
264                for (crate_name, new_version) in updates {
265                    let pattern = format!(
266                        r#"(^\s*{}\s*=.*version\s*=\s*")([^"]*)(".*$)"#,
267                        regex::escape(crate_name)
268                    );
269                    if let Some(captures) = Regex::new(&pattern)?.captures(line) {
270                        new_lines[line_no] = format!(
271                            "{}{}{}",
272                            captures.get(1).unwrap().as_str(),
273                            new_version,
274                            captures.get(3).unwrap().as_str()
275                        );
276                        modified = true;
277                        break;
278                    }
279                }
280            }
281        }
282
283        if modified {
284            let new_content = new_lines.join("\n") + "\n";
285            fs::write(&path, new_content)?;
286        }
287    }
288
289    Ok(())
290}