# https://stackoverflow.com/questions/52202014/how-can-i-plot-2d-fem-results-using-matplotlib

import matplotlib.pyplot as plt
import matplotlib.tri as tri


# converts quad elements into tri elements
def quads_to_four_tris(quads, nodes_x, nodes_y, nodal_vals):
    """Split every quadrilateral element into four triangles.

    A new node is added at the centroid of each quad (the mean of its four
    corner coordinates) and given the mean of the four corner values.  The
    quad ``[n0, n1, n2, n3]`` is replaced by the triangles
    ``[n0, n1, c]``, ``[n1, n2, c]``, ``[n2, n3, c]`` and ``[n3, n0, c]``,
    where ``c`` is the index of that new centroid node.

    Only the first four node indices of each element are used as corners.

    Args:
        quads: Sequence of elements, each a sequence of at least four node
            indices listed in order around the quad.
        nodes_x: x-coordinate of each existing node.
        nodes_y: y-coordinate of each existing node.
        nodal_vals: Field value at each existing node.

    Returns:
        A tuple ``(tris, nodes_x, nodes_y, nodal_vals)``.  ``tris`` is a list
        of ``4 * len(quads)`` three-index lists.  The other three are new
        lists holding the original entries followed by one centroid entry per
        quad, in quad order.  The input sequences are not modified.
    """
    # Copy the node data so the caller's sequences are not extended in
    # place; this also allows tuples or NumPy arrays to be passed in.
    out_x = list(nodes_x)
    out_y = list(nodes_y)
    out_vals = list(nodal_vals)

    tris = []
    for quad in quads:
        n0, n1, n2, n3 = quad[0], quad[1], quad[2], quad[3]

        # Index the centroid node will have once it is appended.
        centre = len(out_x)
        out_x.append((out_x[n0] + out_x[n1] + out_x[n2] + out_x[n3]) / 4.)
        out_y.append((out_y[n0] + out_y[n1] + out_y[n2] + out_y[n3]) / 4.)
        out_vals.append(
            (out_vals[n0] + out_vals[n1] + out_vals[n2] + out_vals[n3]) / 4.)

        # One triangle per quad edge, each closed off at the centroid.
        tris.extend([
            [n0, n1, centre],
            [n1, n2, centre],
            [n2, n3, centre],
            [n3, n0, centre],
        ])

    return tris, out_x, out_y, out_vals


# plots a finite element mesh
def plot_fem_mesh(nodes_x, nodes_y, elements):
    """Outline every element of a 2-D finite element mesh.

    Each element is drawn with ``plt.fill`` as an unfilled polygon with
    black edges.

    Args:
        nodes_x: x-coordinate of every node.
        nodes_y: y-coordinate of every node.
        elements: Iterable of elements, each a sequence of node indices in
            the order the polygon's vertices should be joined.
    """
    for connectivity in elements:
        poly_x = [nodes_x[node] for node in connectivity]
        poly_y = [nodes_y[node] for node in connectivity]
        # fill=False draws only the outline, so contours underneath stay visible.
        plt.fill(poly_x, poly_y, edgecolor='black', fill=False)


# FEM data: eight nodes forming two separate 2x2 quads (one spanning
# y in [-4, -2], the other y in [2, 4]) with a scalar value at each node.
nodes_x = [-1., 1., -1., 1., -1., 1., -1., 1.]
nodes_y = [-4., -4., -2., -2., 2., 2., 4., 4.]
nodal_vals = [0., 1., 1., 0., 1., 0., 0., 1.]
elements_quads = [[0, 1, 3, 2], [4, 5, 7, 6]]


def main():
    """Triangulate the sample quad mesh and show a filled contour plot."""
    # matplotlib contours triangles, so split each quad into four triangles
    # around an added centroid node.  Pass copies so the module-level data
    # is never extended, and use the returned lists, which include the
    # centroid nodes.
    tris, xs, ys, vals = quads_to_four_tris(
        elements_quads, list(nodes_x), list(nodes_y), list(nodal_vals))

    # create an unstructured triangular grid instance
    triangulation = tri.Triangulation(xs, ys, tris)

    # outline the original quads rather than the helper triangles
    plot_fem_mesh(xs, ys, elements_quads)

    # request ~200 contour levels for a near-smooth colour gradient
    plt.tricontourf(triangulation, vals, 200, cmap='jet')

    plt.colorbar()
    plt.axis('equal')
    plt.show()


if __name__ == "__main__":
    main()
