# https://stackoverflow.com/questions/52202014/how-can-i-plot-2d-fem-results-using-matplotlib

import matplotlib.pyplot as plt
import matplotlib.tri as tri


# converts quad elements into tri elements
def quads_to_tris(quads):
    """Split each quadrilateral element into two triangles.

    A quad ``[n0, n1, n2, n3]`` is cut along its ``n0``-``n2`` diagonal into
    ``[n0, n1, n2]`` followed by ``[n2, n3, n0]``. The two triangles of the
    i-th quad therefore sit at positions ``2*i`` and ``2*i + 1`` of the
    result, and both visit the nodes in the same cyclic order as the quad.

    Args:
        quads: Iterable of quad elements. Each element must be indexable;
            only its first four entries are read, any further entries are
            ignored.

    Returns:
        A new list holding two three-item lists per input quad (an empty
        list when ``quads`` is empty).

    Raises:
        IndexError: If a list or tuple element has fewer than four entries.
    """
    tris = []
    for quad in quads:
        # Index explicitly (rather than unpack) so elements carrying more
        # than four entries keep working as before.
        n0, n1, n2, n3 = quad[0], quad[1], quad[2], quad[3]
        tris.append([n0, n1, n2])
        tris.append([n2, n3, n0])
    return tris


# plots a finite element mesh
def plot_fem_mesh(nodes_x, nodes_y, elements, ax=None):
    """Draw the outline of every element of a finite element mesh.

    Each element is drawn as an unfilled polygon with a black edge that
    passes through the element's nodes in the order they are listed.

    Args:
        nodes_x: Indexable collection of node x-coordinates.
        nodes_y: Indexable collection of node y-coordinates, indexed the
            same way as ``nodes_x``.
        elements: Iterable of elements; each element is an iterable of node
            indices into ``nodes_x`` / ``nodes_y``.
        ax: Optional object providing a matplotlib-style ``fill`` method
            (for example an ``Axes``) to draw on. When omitted, drawing goes
            through ``plt.fill`` as before.
    """
    canvas = plt if ax is None else ax
    for element in elements:
        x = [nodes_x[node] for node in element]
        y = [nodes_y[node] for node in element]
        # fill=False leaves the polygon interior empty so only the edges show.
        canvas.fill(x, y, edgecolor='black', fill=False)


# FEM data: eight nodes (x/y coordinates plus one scalar value per node) and
# two quadrilateral elements, each listing four indices into the node lists
# NOTE(review): the seventh x entry is written as the int -1 rather than -1.
nodes_x = [-1., 1., -1., 1., -1., 1., -1, 1.]
nodes_y = [-4., -4., -2., -2., 2., 2., 4., 4.]
nodal_values = [0., 1., 1., 0., 1., 0., 0., 1.]
elements_quads = [[0, 1, 3, 2], [4, 5, 7, 6]]
# the mesh outline is drawn from the original quads, not the split triangles
elements = elements_quads

# convert all elements into triangles (three node indices per element),
# which is the form passed to tri.Triangulation below
elements_all_tris = quads_to_tris(elements_quads)

# create an unstructured triangular grid instance
triangulation = tri.Triangulation(nodes_x, nodes_y, elements_all_tris)

# plot the finite element mesh
plot_fem_mesh(nodes_x, nodes_y, elements)

cmap = 'jet'
# plot filled contours of the nodal values on the triangulation; the
# positional 200 is the levels argument (requested number of contour levels)
plt.tricontourf(triangulation, nodal_values, 200, cmap=cmap)

# add a colorbar for the contours and use equal scaling on both axes
plt.colorbar()
plt.axis('equal')

# show the figure
plt.show()
