diff --git a/src/prs.rs b/src/prs.rs index ed4fe2d..487f320 100644 --- a/src/prs.rs +++ b/src/prs.rs @@ -1285,24 +1285,165 @@ async fn guess_pr( api: &Forgejo, ) -> eyre::Result { let local_repo = git2::Repository::open(".")?; - let head_id = local_repo.head()?.peel_to_commit()?.id(); - let sha = oid_to_string(head_id); - let pr = api - .repo_get_commit_pull_request(repo.owner(), repo.name(), &sha) - .await?; - Ok(pr) + let head = local_repo.head()?; + eyre::ensure!(head.is_branch(), "head is not on branch"); + let local_branch = git2::Branch::wrap(head); + let remote_branch = local_branch.upstream()?; + let remote_head_name = remote_branch + .get() + .name() + .ok_or_eyre("remote branch does not have valid name")?; + let remote_head_short = remote_head_name + .rsplit_once("/") + .map(|(_, b)| b) + .unwrap_or(remote_head_name); + let this_repo = api.repo_get(repo.owner(), repo.name()).await?; + + // check for PRs on the main branch first + let base = this_repo + .default_branch + .as_deref() + .ok_or_eyre("repo does not have default branch")?; + if let Ok(pr) = api + .repo_get_pull_request_by_base_head(repo.owner(), repo.name(), base, remote_head_short) + .await + { + return Ok(pr); + } + + let this_full_name = this_repo + .full_name + .as_deref() + .ok_or_eyre("repo does not have full name")?; + let parent_remote_head_name = format!("{this_full_name}:{remote_head_short}"); + + if let Some(parent) = this_repo.parent.as_deref() { + let (parent_owner, parent_name) = repo_name_from_repo(parent)?; + let parent_base = this_repo + .default_branch + .as_deref() + .ok_or_eyre("repo does not have default branch")?; + if let Ok(pr) = api + .repo_get_pull_request_by_base_head( + parent_owner, + parent_name, + parent_base, + &parent_remote_head_name, + ) + .await + { + return Ok(pr); + } + } + + // then iterate all branches + if let Some(pr) = find_pr_from_branch(repo.owner(), repo.name(), api, remote_head_short).await? + { + return Ok(pr); + } + + if let Some(parent) = this_repo.parent.as_deref() { + let (parent_owner, parent_name) = repo_name_from_repo(parent)?; + + if let Some(pr) = + find_pr_from_branch(parent_owner, parent_name, api, &parent_remote_head_name).await? + { + return Ok(pr); + } + } + + eyre::bail!("could not find PR"); } -fn oid_to_string(oid: git2::Oid) -> String { - let mut s = String::with_capacity(40); - for byte in oid.as_bytes() { - s.push( - char::from_digit((byte & 0xF) as u32, 16).expect("every nibble is a valid hex digit"), - ); - s.push( - char::from_digit(((byte >> 4) & 0xF) as u32, 16) - .expect("every nibble is a valid hex digit"), - ); +async fn find_pr_from_branch( + repo_owner: &str, + repo_name: &str, + api: &Forgejo, + head: &str, +) -> eyre::Result> { + for page in 1.. { + let branch_query = forgejo_api::structs::RepoListBranchesQuery { + page: Some(page), + limit: Some(30), + }; + let remote_branches = match api + .repo_list_branches(repo_owner, repo_name, branch_query) + .await + { + Ok(x) if !x.is_empty() => x, + _ => break, + }; + + let prs = futures::future::try_join_all( + remote_branches + .into_iter() + .map(|branch| check_branch_pair(repo_owner, repo_name, api, branch, head)), + ) + .await?; + for pr in prs { + if pr.is_some() { + return Ok(pr); + } + } } - s + Ok(None) } + +async fn check_branch_pair( + repo_owner: &str, + repo_name: &str, + api: &Forgejo, + base: forgejo_api::structs::Branch, + head: &str, +) -> eyre::Result> { + let base_name = base + .name + .as_deref() + .ok_or_eyre("remote branch does not have name")?; + match api + .repo_get_pull_request_by_base_head(repo_owner, repo_name, base_name, head) + .await + { + Ok(pr) => Ok(Some(pr)), + Err(_) => Ok(None), + } +} + +fn repo_name_from_repo(repo: &forgejo_api::structs::Repository) -> eyre::Result<(&str, &str)> { + let owner = repo + .owner + .as_ref() + .ok_or_eyre("repo does not have owner")? + .login + .as_deref() + .ok_or_eyre("repo owner does not have name")?; + let name = repo.name.as_deref().ok_or_eyre("repo does not have name")?; + Ok((owner, name)) +} + +//async fn guess_pr( +// repo: &RepoName, +// api: &Forgejo, +//) -> eyre::Result { +// let local_repo = git2::Repository::open(".")?; +// let head_id = local_repo.head()?.peel_to_commit()?.id(); +// let sha = oid_to_string(head_id); +// let pr = api +// .repo_get_commit_pull_request(repo.owner(), repo.name(), &sha) +// .await?; +// Ok(pr) +//} +// +//fn oid_to_string(oid: git2::Oid) -> String { +// let mut s = String::with_capacity(40); +// for byte in oid.as_bytes() { +// s.push( +// char::from_digit((byte & 0xF) as u32, 16).expect("every nibble is a valid hex digit"), +// ); +// s.push( +// char::from_digit(((byte >> 4) & 0xF) as u32, 16) +// .expect("every nibble is a valid hex digit"), +// ); +// } +// s +//}