diff --git a/src/commands/git/revert.ts b/src/commands/git/revert.ts index 8615d96478bf6..9cd8725a5469f 100644 --- a/src/commands/git/revert.ts +++ b/src/commands/git/revert.ts @@ -1,13 +1,16 @@ +import { Commands } from '../../constants.commands'; import type { Container } from '../../container'; +import { RevertError, RevertErrorReason } from '../../git/errors'; import type { GitBranch } from '../../git/models/branch'; import type { GitLog } from '../../git/models/log'; import type { GitRevisionReference } from '../../git/models/reference'; import { getReferenceLabel } from '../../git/models/reference'; import type { Repository } from '../../git/models/repository'; -import { showGenericErrorMessage } from '../../messages'; +import { showGenericErrorMessage, showShouldCommitOrStashPrompt } from '../../messages'; import type { FlagsQuickPickItem } from '../../quickpicks/items/flags'; import { createFlagsQuickPickItem } from '../../quickpicks/items/flags'; import { Logger } from '../../system/logger'; +import { executeCommand, executeCoreCommand } from '../../system/vscode/command'; import type { ViewsWithRepositoryFolders } from '../../views/viewBase'; import type { PartialStepState, @@ -74,11 +77,37 @@ export class RevertGitCommand extends QuickCommand { } async execute(state: RevertStepState>) { - const references = state.references.map(c => c.ref).reverse(); - for (const ref of references) { + for (const ref of state.references.reverse()) { try { - await state.repo.git.revert(ref, state.flags); + await state.repo.git.revert(ref.ref, state.flags); } catch (ex) { + if (ex instanceof RevertError) { + let shouldRetry = false; + if (ex.reason === RevertErrorReason.LocalChangesWouldBeOverwritten) { + const response = await showShouldCommitOrStashPrompt(); + if (response === 'Stash') { + await executeCommand(Commands.GitCommandsStashPush); + shouldRetry = true; + } else if (response === 'Commit') { + await executeCoreCommand('workbench.view.scm'); + shouldRetry = true; + } else { + continue; + } + } + + if (shouldRetry) { + try { + await state.repo.git.revert(ref.ref, state.flags); + } catch (ex) { + Logger.error(ex, this.title); + void showGenericErrorMessage(ex.message); + } + } + + continue; + } + Logger.error(ex, this.title); void showGenericErrorMessage(ex.message); } diff --git a/src/env/node/git/git.ts b/src/env/node/git/git.ts index 2048bc6c5b26b..dbae91222c5ab 100644 --- a/src/env/node/git/git.ts +++ b/src/env/node/git/git.ts @@ -180,6 +180,7 @@ const revertErrorAndReason = [ [GitErrors.badRevision, RevertErrorReason.BadRevision], [GitErrors.invalidObjectName, RevertErrorReason.InvalidObjectName], [GitErrors.revertHasConflicts, RevertErrorReason.Conflict], + [GitErrors.changesWouldBeOverwritten, RevertErrorReason.LocalChangesWouldBeOverwritten], ]; export class Git { @@ -1597,13 +1598,13 @@ export class Git { return this.git({ cwd: repoPath }, 'reset', '-q', '--', ...pathspecs); } - revert(repoPath: string, ...args: string[]) { + async revert(repoPath: string, ...args: string[]) { try { - return this.git({ cwd: repoPath }, 'revert', ...args); + await this.git({ cwd: repoPath }, 'revert', ...args); } catch (ex) { const msg: string = ex?.toString() ?? ''; for (const [error, reason] of revertErrorAndReason) { - if (error.test(msg)) { + if (error.test(msg) || error.test(ex.stderr ?? '')) { throw new RevertError(reason, ex); } } diff --git a/src/git/errors.ts b/src/git/errors.ts index 5c80a81009a85..1117093e0bc6c 100644 --- a/src/git/errors.ts +++ b/src/git/errors.ts @@ -572,6 +572,7 @@ export const enum RevertErrorReason { BadRevision, InvalidObjectName, Conflict, + LocalChangesWouldBeOverwritten, Other, } @@ -621,6 +622,8 @@ export class RevertError extends Error { return `${baseMessage} because it is not a valid object name.`; case RevertErrorReason.Conflict: return `${baseMessage} it has unresolved conflicts. Resolve the conflicts and try again.`; + case RevertErrorReason.LocalChangesWouldBeOverwritten: + return `${baseMessage} because local changes would be overwritten. Commit or stash your changes first.`; default: return `${baseMessage}.`; } diff --git a/src/messages.ts b/src/messages.ts index d0370344c4995..81a6040c73f0e 100644 --- a/src/messages.ts +++ b/src/messages.ts @@ -230,6 +230,22 @@ export function showIntegrationRequestTimedOutWarningMessage(providerName: strin ); } +export async function showShouldCommitOrStashPrompt(): Promise { + const stash = { title: 'Stash' }; + const commit = { title: 'Commit' }; + const cancel = { title: 'Cancel', isCloseAffordance: true }; + const result = await showMessage( + 'warn', + 'You have changes in your working tree. Commit or stash them before reverting', + undefined, + null, + stash, + commit, + cancel, + ); + return result?.title; +} + export async function showWhatsNewMessage(majorVersion: string) { const confirm = { title: 'OK', isCloseAffordance: true }; const releaseNotes = { title: 'View Release Notes' };