import code

import sympy as sp
from sympy.parsing.sympy_parser import parse_expr
from sympy.utilities.lambdify import lambdastr, lambdify
from sympy.solvers import solve

import numpy

import matplotlib.pyplot as plt


def calc_conjugate(str, varnames='x'):
    """Symbolically compute the convex (Fenchel) conjugate of a 1-D function.

    The conjugate is f*(y) = sup_x (x*y - f(x)). It is obtained by solving the
    stationarity condition d/dx (x*y - f(x)) = 0 for x and substituting the
    first solution SymPy reports back into x*y - f(x). Whether that point is
    actually the maximiser (or lies in the domain of f) is not checked.

    Args:
        str: SymPy-parsable expression for f. The name shadows the builtin
            but is kept so keyword callers keep working.
            NOTE: ``parse_expr`` evaluates its input with ``eval``; never pass
            untrusted text here.
        varnames: Symbol names as accepted by ``sympy.symbols`` (e.g. ``'x'``
            or ``'x a c'``). The first name is the primal variable; any
            further names are parameters of f. The name ``y`` is used for the
            conjugate variable and should not be listed here.

    Returns:
        Tuple ``(conj_str, f_lambda, conj_lambda)``:
        ``conj_str`` is ``sympy.sstr`` of the simplified conjugate;
        ``f_lambda`` is a NumPy function of the symbols in ``varnames`` (in
        order); ``conj_lambda`` is a NumPy function of ``y`` followed by the
        extra parameters (or of ``y`` alone when only one name was given).

    Raises:
        ValueError: if ``varnames`` names no symbol, or SymPy finds no
            stationary point of x*y - f(x) (e.g. f is linear in x).
    """
    # `str` is part of the public signature, so only a local alias is renamed.
    expr_text = str

    syms = sp.symbols(varnames)
    # sympy.symbols returns a bare Symbol for a single name, and a tuple/list
    # (mirroring the input container) for several names.
    is_multi = isinstance(syms, (tuple, list))
    if is_multi and not syms:
        raise ValueError('varnames must name at least one symbol')
    x = syms[0] if is_multi else syms
    params = list(syms[1:]) if is_multi else []
    y = sp.symbols('y', real=True)

    # Resolve the declared names to our own Symbol objects, so names that
    # clash with SymPy's namespace (E, I, N, S, beta, ...) denote the user's
    # symbols rather than SymPy constants/functions.
    fun = parse_expr(expr_text, local_dict={s.name: s for s in [x] + params})
    obj = x*y - fun

    # Stationary points of the objective with respect to x.
    sol = solve(sp.diff(obj, x), x)
    if not sol:
        raise ValueError(
            'no stationary point of x*y - f(x) found for f = {0}'.format(expr_text))

    # Substitute the first reported solution into the objective.
    solfun = sp.simplify(obj.subs(x, sol[0]))

    # The conjugate takes y plus any extra parameters as arguments.
    conj_args = [y] + params if is_multi else y

    return (sp.sstr(solfun),
            lambdify(syms, fun, 'numpy'),
            lambdify(conj_args, solfun, 'numpy'))


def _finite_samples(fn, lo=-5.0, hi=5.0, num=1000):
    """Evaluate ``fn`` on ``num`` evenly spaced points in [lo, hi]; keep finite results.

    Returns:
        ``(grid, values)`` restricted to the points where ``fn`` is finite.
    """
    grid = numpy.linspace(lo, hi, num)
    # Out-of-domain points (e.g. log of a negative number) yield nan/inf and
    # are filtered out below, so NumPy's floating-point warnings are noise.
    with numpy.errstate(divide='ignore', invalid='ignore', over='ignore'):
        # broadcast_to handles a lambdified expression that is constant and
        # therefore returns a scalar instead of an array.
        values = numpy.broadcast_to(numpy.asarray(fn(grid)), grid.shape)
    keep = numpy.isfinite(values)
    return grid[keep], values[keep]


def _plot_function_and_conjugate(funstr, fname, flam, fconjlam):
    """Plot f (left panel) next to its conjugate f* (right panel) and save to ``fname``."""
    x, fx = _finite_samples(flam)
    y, fconjy = _finite_samples(fconjlam)

    fig = plt.figure(figsize=(12, 4))
    try:
        plt.subplot(121)
        plt.plot(x, fx, color='red', linewidth=2.0)
        plt.xlabel('x')
        plt.ylabel('f(x)')
        if x.size:  # xlim of an empty range would raise
            plt.xlim([x.min(), x.max()])
        plt.title(funstr)

        plt.subplot(122)
        plt.plot(y, fconjy, color='black', linewidth=2.0)
        plt.xlabel('y')
        plt.ylabel(r'f*(y)')
        if y.size:
            plt.xlim([y.min(), y.max()])
        fig.subplots_adjust(bottom=0.2)

        plt.savefig(fname)
    finally:
        # pyplot keeps every open figure alive; close it so the loop in
        # example_plot does not accumulate figures.
        plt.close(fig)


def example_plot(plot=False):
    """Print the conjugates of several standard functions, optionally plotting them.

    For each example function f, ``calc_conjugate`` is called and the line
    ``function: <f> conjugate <f*>`` is printed.

    Args:
        plot: When true, also draw f on x in [-5, 5] beside f* on y in [-5, 5]
            and save each figure with ``plt.savefig`` under the example's
            short name (e.g. ``quad``). Defaults to False, matching the
            earlier behaviour where the plotting code was disabled.

    Returns:
        List of the conjugate expression strings, in example order.
    """
    funstr_list = ['x**2', 'exp(x)', '-log(x)', 'log(1+exp(x))', 'x*log(x)']
    fname_list = ['quad', 'exp', 'log', 'log-exp', 'xlogx']

    fconj_str_list = []

    for funstr, fname in zip(funstr_list, fname_list):
        # calculate the conjugate and get the NumPy lambdas
        fconj_str, flam, fconjlam = calc_conjugate(funstr)
        fconj_str_list.append(fconj_str)

        print('function: {0} conjugate {1}'.format(funstr, fconj_str))

        if plot:
            _plot_function_and_conjugate(funstr, fname, flam, fconjlam)

    return fconj_str_list



if __name__ == '__main__':

    # Example with extra parameters: f(x) = a*(x - c)**2. Both `a` and `c`
    # are declared in varnames so they become arguments of the returned NumPy
    # functions; an undeclared `c` would be an unbound name inside them.
    res = calc_conjugate('a*(x-c)**2', varnames='x a c')[0]
    print(res)

    # Compute and print the conjugates of the built-in example functions.
    example_plot()

# Debugging hook: uncomment to open an interactive console with this module's names in scope.
# code.interact(local={**locals(), **globals()})
