use ruff_python_ast::{
    self as ast, Arguments, CmpOp, Comprehension, Constant, Expr, ExprContext, Ranged, Stmt,
    UnaryOp,
};
use ruff_text_size::TextRange;

use ruff_diagnostics::{AutofixKind, Diagnostic, Edit, Fix, Violation};
use ruff_macros::{derive_message_formats, violation};
use ruff_python_ast::helpers::any_over_expr;
use ruff_python_ast::traversal;
use ruff_python_codegen::Generator;

use crate::checkers::ast::Checker;
use crate::line_width::LineWidth;
use crate::registry::AsRule;

/// ## What it does
/// Checks for `for` loops that can be replaced with a builtin function, like
/// `any` or `all`.
///
/// ## Why is this bad?
/// Using a builtin function is more concise and readable. Builtins are also
/// more efficient than `for` loops.
///
/// ## Example
/// ```python
/// for item in iterable:
///     if predicate(item):
///         return True
/// return False
/// ```
///
/// Use instead:
/// ```python
/// return any(predicate(item) for item in iterable)
/// ```
///
/// ## References
/// - [Python documentation: `any`](https://docs.python.org/3/library/functions.html#any)
/// - [Python documentation: `all`](https://docs.python.org/3/library/functions.html#all)
#[violation]
pub struct ReimplementedBuiltin {
    /// The generated replacement statement, as source text (e.g.
    /// `return any(x for x in y)`), shown in the message and fix title.
    replacement: String,
}

impl Violation for ReimplementedBuiltin {
    // A fix is not always attached: it is withheld when the builtin name is
    // shadowed or when fixes are disabled for this rule.
    const AUTOFIX: AutofixKind = AutofixKind::Sometimes;

    #[derive_message_formats]
    fn message(&self) -> String {
        let replacement = &self.replacement;
        format!("Use `{replacement}` instead of `for` loop")
    }

    fn autofix_title(&self) -> Option<String> {
        let replacement = &self.replacement;
        let title = format!("Replace with `{replacement}`");
        Some(title)
    }
}

/// SIM110, SIM111
pub(crate) fn convert_for_loop_to_any_all(checker: &mut Checker, stmt: &Stmt) {
    if !checker.semantic().current_scope().kind.is_function() {
        return;
    }

    // The `for` loop itself must consist of an `if` with a `return`.
    let Some(loop_) = match_loop(stmt) else {
        return;
    };

    // Afterwards, there are two cases to consider:
    // - `for` loop with an `else: return True` or `else: return False`.
    // - `for` loop followed by `return True` or `return False`.
    let Some(terminal) = match_else_return(stmt).or_else(|| {
        let parent = checker.semantic().current_statement_parent()?;
        let suite = traversal::suite(stmt, parent)?;
        let sibling = traversal::next_sibling(stmt, suite)?;
        match_sibling_return(stmt, sibling)
    }) else {
        return;
    };

    // Check if any of the expressions contain an `await` expression.
    if contains_await(loop_.target) || contains_await(loop_.iter) || contains_await(loop_.test) {
        return;
    }

    match (loop_.return_value, terminal.return_value) {
        // Replace with `any`.
        (true, false) => {
            let contents = return_stmt(
                "any",
                loop_.test,
                loop_.target,
                loop_.iter,
                checker.generator(),
            );

            // Don't flag if the resulting expression would exceed the maximum line length.
            let line_start = checker.locator().line_start(stmt.start());
            if LineWidth::new(checker.settings.tab_size)
                .add_str(&checker.locator().contents()[TextRange::new(line_start, stmt.start())])
                .add_str(&contents)
                > checker.settings.line_length
            {
                return;
            }

            let mut diagnostic = Diagnostic::new(
                ReimplementedBuiltin {
                    replacement: contents.to_string(),
                },
                TextRange::new(stmt.start(), terminal.stmt.end()),
            );
            if checker.patch(diagnostic.kind.rule()) && checker.semantic().is_builtin("any") {
                diagnostic.set_fix(Fix::suggested(Edit::replacement(
                    contents,
                    stmt.start(),
                    terminal.stmt.end(),
                )));
            }
            checker.diagnostics.push(diagnostic);
        }
        // Replace with `all`.
        (false, true) => {
            // Invert the condition.
            let test = {
                if let Expr::UnaryOp(ast::ExprUnaryOp {
                    op: UnaryOp::Not,
                    operand,
                    range: _,
                }) = &loop_.test
                {
                    *operand.clone()
                } else if let Expr::Compare(ast::ExprCompare {
                    left,
                    ops,
                    comparators,
                    range: _,
                }) = &loop_.test
                {
                    if let ([op], [comparator]) = (ops.as_slice(), comparators.as_slice()) {
                        let op = match op {
                            CmpOp::Eq => CmpOp::NotEq,
                            CmpOp::NotEq => CmpOp::Eq,
                            CmpOp::Lt => CmpOp::GtE,
                            CmpOp::LtE => CmpOp::Gt,
                            CmpOp::Gt => CmpOp::LtE,
                            CmpOp::GtE => CmpOp::Lt,
                            CmpOp::Is => CmpOp::IsNot,
                            CmpOp::IsNot => CmpOp::Is,
                            CmpOp::In => CmpOp::NotIn,
                            CmpOp::NotIn => CmpOp::In,
                        };
                        let node = ast::ExprCompare {
                            left: left.clone(),
                            ops: vec![op],
                            comparators: vec![comparator.clone()],
                            range: TextRange::default(),
                        };
                        node.into()
                    } else {
                        let node = ast::ExprUnaryOp {
                            op: UnaryOp::Not,
                            operand: Box::new(loop_.test.clone()),
                            range: TextRange::default(),
                        };
                        node.into()
                    }
                } else {
                    let node = ast::ExprUnaryOp {
                        op: UnaryOp::Not,
                        operand: Box::new(loop_.test.clone()),
                        range: TextRange::default(),
                    };
                    node.into()
                }
            };
            let contents = return_stmt("all", &test, loop_.target, loop_.iter, checker.generator());

            // Don't flag if the resulting expression would exceed the maximum line length.
            let line_start = checker.locator().line_start(stmt.start());
            if LineWidth::new(checker.settings.tab_size)
                .add_str(&checker.locator().contents()[TextRange::new(line_start, stmt.start())])
                .add_str(&contents)
                > checker.settings.line_length
            {
                return;
            }

            let mut diagnostic = Diagnostic::new(
                ReimplementedBuiltin {
                    replacement: contents.to_string(),
                },
                TextRange::new(stmt.start(), terminal.stmt.end()),
            );
            if checker.patch(diagnostic.kind.rule()) && checker.semantic().is_builtin("all") {
                diagnostic.set_fix(Fix::suggested(Edit::replacement(
                    contents,
                    stmt.start(),
                    terminal.stmt.end(),
                )));
            }
            checker.diagnostics.push(diagnostic);
        }
        _ => {}
    }
}

/// Represents a `for` loop whose body is a single conditional `return` of a boolean literal,
/// like:
/// ```python
/// for x in y:
///     if x == 0:
///         return True
/// ```
#[derive(Debug)]
struct Loop<'a> {
    /// The boolean literal returned from inside the `if` (`True` in the example above).
    return_value: bool,
    /// The condition of the nested `if` statement (`x == 0` above).
    test: &'a Expr,
    /// The target of the loop (the `x` in `for x in y`).
    target: &'a Expr,
    /// The iterable expression being looped over (the `y` in `for x in y`).
    iter: &'a Expr,
}

/// Represents the boolean `return` reached when a `for` loop finishes without returning
/// early. It is either a statement following the loop, like:
/// ```python
/// for x in y:
///     if x == 0:
///         return True
/// return False
/// ```
///
/// Or a `return` in the loop's `else` clause:
/// ```python
/// for x in y:
///     if x == 0:
///         return True
/// else:
///     return False
/// ```
#[derive(Debug)]
struct Terminal<'a> {
    /// The boolean literal returned after the loop (`False` in both examples above).
    return_value: bool,
    /// The statement whose end closes the span to replace. For the `else` form this is the
    /// `for` statement itself; for the following-statement form it is the sibling `return`.
    stmt: &'a Stmt,
}

/// If `stmt` is a `for` loop whose entire body is `if <test>: return <True|False>` (with no
/// `elif`/`else`), return the corresponding [`Loop`]; otherwise return `None`.
fn match_loop(stmt: &Stmt) -> Option<Loop> {
    let Stmt::For(ast::StmtFor {
        body, target, iter, ..
    }) = stmt
    else {
        return None;
    };

    // The body must be exactly one `if`, and that `if` must have no `elif`/`else` branches.
    let if_stmt = match body.as_slice() {
        [Stmt::If(if_stmt)] if if_stmt.elif_else_clauses.is_empty() => if_stmt,
        _ => return None,
    };

    // The `if` body must be exactly one `return` that carries a value...
    let [Stmt::Return(ast::StmtReturn {
        value: Some(returned),
        range: _,
    })] = if_stmt.body.as_slice()
    else {
        return None;
    };

    // ...and that value must be a boolean literal.
    let Expr::Constant(ast::ExprConstant {
        value: Constant::Bool(return_value),
        ..
    }) = returned.as_ref()
    else {
        return None;
    };

    Some(Loop {
        return_value: *return_value,
        test: &if_stmt.test,
        target,
        iter,
    })
}

/// If a `Stmt::For` contains an `else` with a single boolean `return`, return the [`Terminal`]
/// representing that `return`.
///
/// For example, matches the `return` in:
/// ```python
/// for x in y:
///     if x == 0:
///         return True
/// else:
///     return False
/// ```
///
/// The returned [`Terminal`]'s `stmt` is the `for` statement itself, since the `else` clause
/// ends where the loop ends.
fn match_else_return(stmt: &Stmt) -> Option<Terminal> {
    let Stmt::For(ast::StmtFor { orelse, .. }) = stmt else {
        return None;
    };

    // The `else` block has to contain a single `return True` or `return False`.
    let [Stmt::Return(ast::StmtReturn {
        value: Some(next_value),
        range: _,
    })] = orelse.as_slice()
    else {
        return None;
    };
    let Expr::Constant(ast::ExprConstant {
        value: Constant::Bool(next_value),
        ..
    }) = next_value.as_ref()
    else {
        return None;
    };

    Some(Terminal {
        return_value: *next_value,
        stmt,
    })
}

/// If a `Stmt::For` is followed by a boolean `return`, return the [`Terminal`] representing that
/// `return`.
///
/// For example, matches the `return` in:
/// ```python
/// for x in y:
///     if x == 0:
///         return True
/// else:
///     return False
/// ```
fn match_sibling_return<'a>(stmt: &'a Stmt, sibling: &'a Stmt) -> Option<Terminal<'a>> {
    let Stmt::For(ast::StmtFor { orelse, .. }) = stmt else {
        return None;
    };

    // The loop itself shouldn't have an `else` block.
    if !orelse.is_empty() {
        return None;
    }

    // The next statement has to be a `return True` or `return False`.
    let Stmt::Return(ast::StmtReturn {
        value: Some(next_value),
        range: _,
    }) = &sibling
    else {
        return None;
    };
    let Expr::Constant(ast::ExprConstant {
        value: Constant::Bool(next_value),
        ..
    }) = next_value.as_ref()
    else {
        return None;
    };

    Some(Terminal {
        return_value: *next_value,
        stmt: sibling,
    })
}

/// Generate a return statement for an `any` or `all` builtin comprehension.
fn return_stmt(id: &str, test: &Expr, target: &Expr, iter: &Expr, generator: Generator) -> String {
    let node = ast::ExprGeneratorExp {
        elt: Box::new(test.clone()),
        generators: vec![Comprehension {
            target: target.clone(),
            iter: iter.clone(),
            ifs: vec![],
            is_async: false,
            range: TextRange::default(),
        }],
        range: TextRange::default(),
    };
    let node1 = ast::ExprName {
        id: id.into(),
        ctx: ExprContext::Load,
        range: TextRange::default(),
    };
    let node2 = ast::ExprCall {
        func: Box::new(node1.into()),
        arguments: Arguments {
            args: vec![node.into()],
            keywords: vec![],
            range: TextRange::default(),
        },
        range: TextRange::default(),
    };
    let node3 = ast::StmtReturn {
        value: Some(Box::new(node2.into())),
        range: TextRange::default(),
    };
    generator.stmt(&node3.into())
}

/// Return `true` if an `await` expression appears anywhere within `expr`.
fn contains_await(expr: &Expr) -> bool {
    any_over_expr(expr, &|node: &Expr| node.is_await_expr())
}
