import numpy as np
import matplotlib.pyplot as plt

def _evaluate(f, x):
    """Evaluate ``f`` on the array ``x`` and return a float array shaped like ``x``.

    Broadcasting lets constant functions such as ``lambda x: 3`` work: they
    return a scalar, which would otherwise yield a single height instead of
    one height per sample point.
    """
    return np.broadcast_to(np.asarray(f(x), dtype=float), x.shape)


def rectangle(n, f, a, b, label=None):
    """
    Plot f(x) on [a, b] with its left- and right-endpoint rectangles.

    Whichever rectangle set lies below the curve is labelled "Inf." (lower sum,
    blue) and the other "Sup." (upper sum, orange).  For an increasing ``f``
    the left rectangles are the lower ones; for a decreasing ``f`` the right
    ones are.  The direction is judged only from f(a) versus f(b), so for a
    non-monotone ``f`` the Inf./Sup. labels are only indicative.

    Parameters
    ----------
    n : int
        Number of subintervals, >= 1.  Python and NumPy integers are accepted;
        ``bool`` is rejected.
    f : callable
        Function evaluated on NumPy arrays.  It may return a scalar (constant
        function); the scalar is broadcast to one value per point.
    a, b : float
        Bounds of the interval.
    label : str, optional
        Legend label for the curve; if None, 'f(x)' is shown.

    Returns
    -------
    tuple
        ``(area_left, area_right)``: the left- and right-endpoint rectangle sums.

    Raises
    ------
    ValueError
        If ``n`` is not an integer >= 1.
    """
    # Check the type before comparing, so a non-numeric n raises ValueError
    # rather than a TypeError from ``<``.  bool is an int subclass, not a count.
    if isinstance(n, bool) or not isinstance(n, (int, np.integer)) or n < 1:
        raise ValueError("n doit être un entier >= 1")
    n = int(n)

    dx = (b - a) / n
    left_edges = np.linspace(a, b - dx, n)
    right_edges = np.linspace(a + dx, b, n)

    heights_left = _evaluate(f, left_edges)
    heights_right = _evaluate(f, right_edges)

    area_left = np.sum(heights_left * dx)
    area_right = np.sum(heights_right * dx)

    # Compare f(a) (first left sample) with f(b) (last right sample) to decide
    # which rectangle set lies below the curve of a monotone f.
    is_increasing = heights_right[-1] >= heights_left[0]
    if is_increasing:
        lower_heights, lower_area = heights_left, area_left
        upper_heights, upper_area = heights_right, area_right
    else:
        lower_heights, lower_area = heights_right, area_right
        upper_heights, upper_area = heights_left, area_left

    x_plot = np.linspace(a, b, 400)
    y_plot = _evaluate(f, x_plot)

    plt.figure(figsize=(8, 5))
    plt.plot(x_plot, y_plot, 'k', label="f(x)" if label is None else label,
             linewidth=2)

    # Both sets share the same subintervals; each bar starts at its left edge.
    bar_style = dict(width=dx, align='edge', edgecolor='black', alpha=0.35)
    # Draw the taller (upper) set first so the lower set stays visible on top.
    plt.bar(left_edges, upper_heights, color='orange',
            label=f"Sup., Aire={upper_area:.6f}", **bar_style)
    plt.bar(left_edges, lower_heights, color='blue',
            label=f"Inf., Aire={lower_area:.6f}", **bar_style)

    # Include the bar baseline (0) and the whole curve so nothing is clipped,
    # e.g. an all-negative f or a peak between two sample points.
    y_values = np.concatenate(([0.0], heights_left, heights_right, y_plot))
    ymin, ymax = float(np.min(y_values)), float(np.max(y_values))
    pad = 0.05 * ((ymax - ymin) or 1.0)  # non-zero margin even for f == 0
    plt.xlim(a - 0.02 * (b - a), b + 0.02 * (b - a))
    plt.ylim(ymin - pad, ymax + pad)
    plt.xlabel('x')
    plt.ylabel('f(x)')
    plt.title(f"Méthodes des rectangles (n={n})")
    plt.legend()
    plt.grid(alpha=0.2)
    plt.show()

    return area_left, area_right

# Demo only when run as a script: importing the module must not open
# (and block on) a plot window.
if __name__ == "__main__":
    rectangle(8, lambda x: x**2, 1, 2, label=r"$f(x)=x^2$")
