From 6fe22f5383b6ac54368892b2aa4b126c1e8870b3 Mon Sep 17 00:00:00 2001 From: Cyborus Date: Sat, 13 Jul 2024 20:28:42 -0400 Subject: [PATCH] fix: prioritize remote tracking branch in repo detection --- src/repo.rs | 53 ++++++++++++++++++++++++++++------------------------- 1 file changed, 28 insertions(+), 25 deletions(-) diff --git a/src/repo.rs b/src/repo.rs index 4a418c9..6ede776 100644 --- a/src/repo.rs +++ b/src/repo.rs @@ -73,19 +73,37 @@ impl RepoInfo { let (remote_url, remote_repo_name) = { let mut out = (None, None); if let Ok(local_repo) = git2::Repository::open(".") { - // help to escape scopes - let tmp; - let mut tmp2; - - let mut name = remote; + let mut name = remote.map(|s| s.to_owned()); // if there's only one remote, use that if name.is_none() { let all_remotes = local_repo.remotes()?; if all_remotes.len() == 1 { if let Some(remote_name) = all_remotes.get(0) { - tmp2 = Some(remote_name.to_owned()); - name = tmp2.as_deref(); + name = Some(remote_name.to_owned()); + } + } + } + + // if the current branch is tracking a remote branch, use that remote + if name.is_none() { + let head = local_repo.head()?; + let branch_name = head.name().ok_or_eyre("branch name not UTF-8")?; + + if let Ok(remote_name) = local_repo.branch_upstream_remote(branch_name) { + let remote_name_s = + remote_name.as_str().ok_or_eyre("remote name invalid")?; + + if let Some(host_url) = &host_url { + let remote = local_repo.find_remote(&remote_name_s)?; + let url_s = std::str::from_utf8(remote.url_bytes())?; + let url = Url::parse(url_s)?; + + if url.host_str() == host_url.host_str() { + name = Some(remote_name_s.to_owned()); + } + } else { + name = Some(remote_name_s.to_owned()); } } } @@ -109,31 +127,16 @@ impl RepoInfo { if url.host_str() == host_url.host_str() && url.path() == host_url.path() { - tmp2 = Some(remote_name.to_owned()); - name = tmp2.as_deref(); + name = Some(remote_name.to_owned()); + break; } } } } } - // if the current branch is tracking a remote branch, use that remote - if name.is_none() { - let head = local_repo.head()?; - let branch_name = head.name().ok_or_else(|| eyre!("branch name not UTF-8"))?; - tmp = local_repo.branch_upstream_remote(branch_name).ok(); - name = tmp - .as_ref() - .map(|remote| { - remote - .as_str() - .ok_or_else(|| eyre!("remote name not UTF-8")) - }) - .transpose()?; - } - if let Some(name) = name { - if let Ok(remote) = local_repo.find_remote(name) { + if let Ok(remote) = local_repo.find_remote(&name) { let url_s = std::str::from_utf8(remote.url_bytes())?; let url = Url::parse(url_s)?; let (url, name) = url_strip_repo_name(url)?;