CapturedSqlPlanAuditTemplate.java

package io.github.databaseaudits.audit.runtime.plan;

import java.util.ArrayList;
import java.util.Collection;
import java.util.HashSet;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Set;
import java.util.TreeMap;

import com.fasterxml.jackson.databind.JsonNode;

import io.github.databaseaudits.capture.SqlCapturingStatementInspector;
import io.github.databaseaudits.plan.QueryPlanExplainer;
import lombok.AccessLevel;
import lombok.AllArgsConstructor;
import lombok.extern.slf4j.Slf4j;

/**
 * Template-method base for the EXPLAIN-driven runtime audits
 * ({@link WhereClauseIndexAudit}, {@link OrderByIndexAudit},
 * {@link JoinIndexAudit}).
 *
 * <p>
 * {@link #audit(Set, Collection)} is the fixed algorithm: read the captured
 * SQL, de-duplicate by statement shape, plan each candidate via
 * {@link QueryPlanExplainer} with the audit's planner settings, and return one
 * readable finding per offending plan node. It also owns the two guards every
 * plan audit shares — an empty capture, and a wholly-unexplainable (vacuous)
 * run — which throw {@link IllegalStateException} rather than returning no
 * findings, so a misconfigured run never looks clean. That logic lives in
 * exactly one place.
 *
 * <p>
 * Subclasses supply only the variation points: which statements to look at
 * ({@link #isCandidate(String)}), which planner GUCs to penalize
 * ({@link #plannerSettings()}), how to recognize an offending node
 * ({@link #collectFindings(JsonNode, List, Set)}), and the statement noun for
 * the vacuous-run guard ({@link #statementNoun()}). Dependencies are
 * constructor-injected through the protected all-args constructor (parameters
 * in field declaration order: {@code queryPlanExplainer}, then
 * {@code sqlCapturer}) and passed up via {@code super(...)}.
 */
@AllArgsConstructor(access = AccessLevel.PROTECTED)
@Slf4j
abstract class CapturedSqlPlanAuditTemplate {
    // Every line of this block shares the same 12-space indent, so no line of
    // the rendered message starts with a stray space; a trailing "\ " joins
    // physical lines with exactly one space.
    private static final String FAIL_NO_EXPLAINS_MSG = """
            %d %s statement shape(s) were captured but none could be \
            EXPLAINed, so this audit verified nothing — these plan-based \
            audits are PostgreSQL 16+ only.
            On PostgreSQL, the most likely cause is a missing \
            preferQueryMode=simple on the test datasource JDBC URL.
            See: https://database-audits.github.io/spring-boot-integration/usage.html#postgresql-jdbc-requirement""";

    private static final String SKIP_UNCHECKABLE_MSG = """
            Un-checkable (parameter type inference, jsonb `?`, unparsable).\
             The subsequent all-skipped guard\
             still catches a wholly vacuous run.""";

    /**
     * Available to subclasses for
     * {@link QueryPlanExplainer#textOf(JsonNode, String)} while walking the
     * plan.
     */
    protected final QueryPlanExplainer queryPlanExplainer;

    // NOTE: field order is the Lombok constructor's parameter order — do not
    // reorder or insert instance fields without updating every super(...) call.
    private final SqlCapturingStatementInspector sqlCapturer;

    /**
     * The template method — fixed across audits. Returns one finding per
     * offending plan node; an empty list when every candidate statement is
     * served by an index (or when no captured statement is a candidate).
     *
     * @param excludedRelations
     *            The table/relation names to skip; handed to
     *            {@link #collectFindings(JsonNode, List, Set)}.
     * @param excludedSqlFragments
     *            The SQL fragments whose containing statements to skip,
     *            matched case-insensitively.
     * @return One finding per offending statement, ordered by SQL text: the
     *         joined node findings, then the statement on an indented line.
     * @throws UnsupportedOperationException
     *             On any non-PostgreSQL platform.
     * @throws IllegalStateException
     *             If nothing was captured, or if statements were captured but
     *             none could be EXPLAINed.
     */
    public final List<String> audit(final Set<String> excludedRelations,
            final Collection<String> excludedSqlFragments) {
        queryPlanExplainer.requirePlanAuditSupport(getClass().getSimpleName());

        final Set<String> capturedSql = sqlCapturer.capturedSql();
        if (capturedSql.isEmpty()) {
            throw new IllegalStateException(
                    SqlCapturingStatementInspector.EMPTY_CAPTURE_MESSAGE);
        }

        // Keyed by statement text; TreeMap keeps the findings in a stable order.
        final Map<String, String> violations = new TreeMap<>();
        // Upper-cased normalized shapes that were handed to EXPLAIN.
        final Set<String> checkedShapes = new HashSet<>();

        final int explainedCount = collectViolations(capturedSql,
                excludedRelations, excludedSqlFragments, violations,
                checkedShapes);

        requireSomethingExplained(checkedShapes, explainedCount);

        final int skippedCount = checkedShapes.size() - explainedCount;
        log.debug("audit: counts: captured={},"
                + " checked={}, explained={}, skipped={}, violations={}",
                capturedSql.size(), checkedShapes.size(), explainedCount,
                skippedCount, violations.size());

        return findingsOf(violations);
    }

    /**
     * Plans every distinct, non-excluded candidate shape once, recording
     * findings in {@code violations} and every attempted shape in
     * {@code checkedShapes}.
     *
     * @return How many shapes were successfully EXPLAINed.
     */
    private int collectViolations(final Set<String> capturedSql,
            final Set<String> excludedRelations,
            final Collection<String> excludedSqlFragments,
            final Map<String, String> violations,
            final Set<String> checkedShapes) {
        int explainedCount = 0;
        for (final String rawSql : capturedSql) {
            final String trimmedSql = rawSql.strip();
            final String normalizedSql = sqlCapturer.normalize(trimmedSql);
            // Locale.ROOT: a default-locale fold is not keyword-safe — in a
            // Turkish locale "join" upper-cases to "JOİN".
            final String upperCasedSql =
                    normalizedSql.toUpperCase(Locale.ROOT);

            // The shape is added last so only surviving candidates are counted
            // as checked; a false add() means this shape was already planned.
            if (!isCandidate(upperCasedSql)
                    || isExcluded(normalizedSql, excludedSqlFragments)
                    || !checkedShapes.add(upperCasedSql)) {
                continue;
            }
            explainedCount +=
                    explain(trimmedSql, excludedRelations, violations);
        }
        return explainedCount;
    }

    /**
     * EXPLAINs one statement under {@link #plannerSettings()} and records its
     * findings, if any, keyed by the statement text.
     *
     * @return 1 if the statement was planned, 0 if planning failed and the
     *         statement was skipped.
     */
    private int explain(final String sql, final Set<String> excludedRelations,
            final Map<String, String> violations) {
        final JsonNode plan;
        try {
            plan = queryPlanExplainer.planWith(sql, plannerSettings());
        } catch (final Exception e) {
            // Deliberate best-effort: some captured statements cannot be
            // planned in isolation. Only planning is guarded — a failure while
            // walking the plan is an audit bug and must propagate instead of
            // being miscounted as "un-explainable".
            log.debug("Skipping un-explainable statement [{}]: {}", sql,
                    SKIP_UNCHECKABLE_MSG, e);
            return 0;
        }
        final List<String> findings = new ArrayList<>();
        collectFindings(plan, findings, excludedRelations);
        if (!findings.isEmpty()) {
            violations.put(sql, String.join("; ", findings));
        }
        return 1;
    }

    /**
     * Whether the statement contains any excluded fragment, compared
     * case-insensitively with locale-independent folding.
     */
    private boolean isExcluded(final String normalizedSql,
            final Collection<String> excludedSqlFragments) {
        final String lower = normalizedSql.toLowerCase(Locale.ROOT);
        return excludedSqlFragments.stream()
                .anyMatch(f -> lower.contains(f.toLowerCase(Locale.ROOT)));
    }

    /**
     * Vacuous-run guard: candidates were found but not one could be planned.
     */
    private void requireSomethingExplained(final Set<String> checkedShapes,
            final int explainedCount) {
        if (!checkedShapes.isEmpty() && explainedCount == 0) {
            throw new IllegalStateException(FAIL_NO_EXPLAINS_MSG
                    .formatted(checkedShapes.size(), statementNoun()));
        }
    }

    /** Renders each violation as its findings followed by the indented SQL. */
    private List<String> findingsOf(final Map<String, String> violations) {
        return violations.entrySet().stream()
                .map(violation -> violation.getValue() + "\n      "
                        + violation.getKey())
                .toList();
    }

    /**
     * Whether this normalized, upper-cased statement is one this audit should
     * EXPLAIN.
     *
     * @param upperCasedSql
     *            The normalized statement, upper-cased with {@link Locale#ROOT}.
     * @return {@code true} to plan the statement.
     */
    protected abstract boolean isCandidate(String upperCasedSql);

    /**
     * Planner GUCs to penalize so a surviving node proves a missing index, e.g.
     * {@code "enable_seqscan = off"}.
     *
     * @return The settings applied while planning each candidate.
     */
    protected abstract String[] plannerSettings();

    /**
     * Walk the plan tree from {@code plan} and add a human-readable finding for
     * each offending node.
     *
     * @param plan
     *            The EXPLAIN output for one candidate statement.
     * @param findings
     *            Sink for findings; an empty sink after the walk means the
     *            statement is clean.
     * @param excludedRelations
     *            The relation names the caller asked to skip.
     */
    protected abstract void collectFindings(JsonNode plan,
            List<String> findings, Set<String> excludedRelations);

    /**
     * Noun for the vacuous-run guard message, e.g. {@code "WHERE-clause"} or
     * {@code "ORDER BY"}.
     *
     * @return The noun inserted into the "none could be EXPLAINed" message.
     */
    protected abstract String statementNoun();
}