# -*- coding: utf-8 -*-
"""
Created on Fri May 17 18:13:53 2024

@author: Stéphane Pasquet
@url: https://www.mathweb.fr/euclide/2024/05/18/trouver-lequation-dune-parabole-passant-par-3-points/
"""

from sympy import Matrix, Rational, latex
import matplotlib.pyplot as plt
import numpy as np
import os

# Return the coefficients a, b, c of y = ax² + bx + c, where A, B, C lie on the parabola

def parabola(A,B,C):
    """Solve for the coefficients of the parabola through three points.

    Each of ``A``, ``B``, ``C`` is an ``(x, y)`` pair.  The function builds
    the 3x3 linear system whose rows are ``[x**2, x, 1]`` with right-hand
    side ``y`` (all entries converted to ``Rational``) and returns the
    result of ``Matrix.solve``: a column matrix holding the coefficients
    of x², x and the constant term, in that order.
    """
    points = [A, B, C]

    # One row [x², x, 1] per point; the ordinate is not needed here.
    system_rows = [
        [Rational(abscissa**2), Rational(abscissa), Rational(1)]
        for abscissa, _ in points
    ]
    # Matching right-hand side: one single-entry row per ordinate.
    rhs_rows = [[Rational(ordinate)] for _, ordinate in points]

    coefficients = Matrix(system_rows).solve(Matrix(rhs_rows))
    return coefficients
    
def f(P,x):
    """Evaluate the quadratic ``P[0]*x² + P[1]*x + P[2]`` at ``x``.

    ``P`` is any indexable holding the three coefficients (x², x, constant);
    ``x`` is whatever supports ``*`` and ``+`` with those coefficients.
    """
    quadratic_term = P[0] * x * x
    linear_term = P[1] * x
    return quadratic_term + linear_term + P[2]
    

def draw_parabola(A, B, C, x_range=(-10, 10), num_points=400):
    """Plot the parabola through A, B and C with matplotlib and show it.

    Parameters
    ----------
    A, B, C : tuple
        ``(x, y)`` pairs the parabola must pass through.
    x_range : tuple, optional
        ``(x_min, x_max)`` window over which the curve is sampled.
        Defaults to ``(-10, 10)``, the window that used to be hard-coded,
        so existing calls behave as before.  Widen it when one of the
        points lies outside that window.
    num_points : int, optional
        Number of samples passed to ``np.linspace`` (default 400).

    The three points are drawn in red and annotated with their
    coordinates; the legend shows the equation of the curve.  The figure
    is displayed with ``plt.show()``; nothing is returned.
    """
    P = parabola(A, B, C)
    x = np.linspace(x_range[0], x_range[1], num_points)
    y = f(P, x)

    # LaTeX form of each coefficient, used to build the legend label.
    P_latex = [latex(P[i]) for i in range(3)]
    chaine = rf'$f(x) = {P_latex[0]}x^2'
    # A negative coefficient is expected to carry its own minus sign in its
    # LaTeX string, so a "+" is only inserted for non-negative ones.
    chaine += rf' {P_latex[1]}x' if P[1] < 0 else rf' + {P_latex[1]}x'
    chaine += rf' {P_latex[2]}$' if P[2] < 0 else rf' + {P_latex[2]}$'

    plt.plot(x, y, label=chaine)

    # Mark and label the three defining points.
    for a, b in [A, B, C]:
        plt.scatter(a, b, color='red', zorder=5)
        plt.annotate(
            f'({a},{b})',
            (a, b),
            textcoords="offset points",
            xytext=(-15, -10),
            ha='center',
            color='red',
        )

    plt.title('Graphique de la fonction f(x)')
    plt.xlabel('x')
    plt.ylabel('f(x)')
    plt.legend()
    plt.grid(True)

    plt.show()

def _open_with_default_viewer(path):
    """Open *path* with the platform's default application (best effort)."""
    import subprocess
    import sys

    try:
        if sys.platform.startswith('win'):
            os.startfile(path)
        elif sys.platform == 'darwin':
            subprocess.run(['open', path], check=False)
        else:
            subprocess.run(['xdg-open', path], check=False)
    except OSError as err:
        # Viewing the result is a convenience only; do not fail the export.
        print(f"Could not open {path}: {err}", file=sys.stderr)


def export_latex(A, B, C, x_range=(-10, 10), num_points=100):
    """Export the parabola through A, B and C as a TikZ figure and compile it.

    Writes ``parabola_plot.tex`` in the current directory, runs
    ``pdflatex`` on it and, if ``parabola_plot.pdf`` exists afterwards,
    opens it with the system's default viewer.

    Parameters
    ----------
    A, B, C : tuple
        ``(x, y)`` pairs the parabola must pass through.
    x_range : tuple, optional
        ``(x_min, x_max)`` horizontal extent of the figure.
    num_points : int, optional
        Number of sampling intervals; used both to estimate the vertical
        extent of the figure and as TikZ's ``samples`` value.

    Raises
    ------
    ValueError
        If ``num_points`` is smaller than 1.
    """
    if num_points < 1:
        raise ValueError("num_points must be at least 1")

    P = parabola(A, B, C)
    # Coefficients as "(numerator/denominator)" so TikZ evaluates the exact
    # fractions rather than rounded decimals.
    P_latex = ['({}/{})'.format(P[i].numerator, P[i].denominator) for i in range(3)]

    # Estimate the vertical extent with plain floats: one evaluation per
    # sample, and index-based abscissas so the step is never accumulated
    # (an accumulated float step can miss the right endpoint).
    a, b, c = (float(P[i]) for i in range(3))
    x_start, x_end = x_range
    span = x_end - x_start
    values = [
        a * x * x + b * x + c
        for x in (x_start + span * i / num_points for i in range(num_points + 1))
    ]
    # 0 is always included so the x-axis stays inside the picture; one unit
    # of margin is added on both sides.
    y_min = min(0.0, *values) - 1
    y_max = max(0.0, *values) + 1

    points_latex = "\n".join(
        f"    \\fill[red] ({px},{py}) circle (4pt);" for px, py in (A, B, C)
    )

    tikz_code = f"""
\\documentclass{{standalone}}
\\usepackage{{tikz}}
\\begin{{document}}
\\begin{{tikzpicture}}
    \\draw[gray,dashed] ({x_range[0]},{y_min:.4f}) grid ({x_range[1]},{y_max:.4f});
    \\draw[thick,->,>=latex] ({x_range[0]},0) -- ({x_range[1]},0);
    \\draw[thick,->,>=latex] (0,{y_min:.4f}) -- (0,{y_max:.4f});
    \\draw[thick] (1,0.1) -- (1,-0.1) node[below] {{1}};
    \\draw[thick] (0.1,1) -- (-0.1,1) node[left] {{1}};
    \\draw[thick, smooth, domain={x_range[0]}:{x_range[1]}, samples={num_points}] plot (\\x, {{ {P_latex[0]}*\\x*\\x + {P_latex[1]}*\\x + {P_latex[2]} }});
{points_latex}
\\end{{tikzpicture}}
\\end{{document}}
"""
    # The handle is not called "f": that name is the module-level polynomial
    # evaluator.
    with open('parabola_plot.tex', 'w', encoding='utf-8') as tex_file:
        tex_file.write(tikz_code)

    cmd = "pdflatex  --shell-escape -synctex=1 -interaction=nonstopmode parabola_plot.tex"
    os.system(cmd)

    # Only try to display the result if a PDF is actually there.
    if os.path.exists('parabola_plot.pdf'):
        _open_with_default_viewer('parabola_plot.pdf')


if __name__ == '__main__':
    # Demo: three sample points, shown with matplotlib and then exported
    # to LaTeX/TikZ.
    A = (-5, 4)
    B = (0, -3)
    C = (7, 5)

    draw_parabola(A, B, C)
    export_latex(A, B, C)


