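'''
GEA Computing Ltd.
June 2026

Export an asymmetric diverging colormap, built from two Matplotlib colormaps
and centred on a chosen value (zero by default), as a ParaView JSON colormap
preset.

ParaView forum discussion:
     https://discourse.paraview.org/t/center-colourmap-to-zero/17587/8


'''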


import json
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np


def _sample_rgb_points(cmap, values, ts):
    """Return flat ``[x, r, g, b, ...]`` points sampling *cmap* at each *t*."""
    points = []
    for value, t in zip(values, ts):
        # The colormap returns RGBA; alpha is dropped because each RGBPoints
        # entry holds exactly x, r, g, b.
        r, g, b, _ = cmap(float(t))
        points.extend([float(value), float(r), float(g), float(b)])
    return points


def export_asymmetric_matplotlib_cmap_to_paraview(
    output_file="asymmetric_mpl_colormap.json",
    name="Asymmetric_Matplotlib_Diverging",
    vmin=-2.0,
    vzero=0.0,
    vmax=8.0,
    negative_cmap="Blues_r",
    positive_cmap="Reds",
    n_negative=6,
    n_positive=12,
    zero_color=(1.0, 1.0, 1.0),
    color_space="Lab",
    nan_color=(0.0, 1.0, 0.0),
):
    """Write an asymmetric diverging colormap as a ParaView JSON preset.

    The range ``[vmin, vmax]`` is split at ``vzero``:

    * ``[vmin, vzero)`` is sampled from *negative_cmap*;
    * ``vzero`` itself gets *zero_color*;
    * ``(vzero, vmax]`` is sampled from *positive_cmap*.

    Each side is mapped linearly onto its own colormap's ``[0, 1]`` range.
    A control point at value ``v`` uses ``t = (v - vmin) / (vzero - vmin)``
    on the negative side and ``t = (v - vzero) / (vmax - vzero)`` on the
    positive side. The colour scale therefore stays centred on ``vzero``
    even when ``vmin`` and ``vmax`` are not symmetric about it.

    Parameters:
        output_file: Destination path (str or path-like); overwritten if it
            exists.
        name: Value written to the preset's ``"Name"`` field.
        vmin, vzero, vmax: Range bounds and centre; must satisfy
            ``vmin < vzero < vmax``.
        negative_cmap, positive_cmap: Either a Matplotlib colormap name,
            looked up with ``plt.get_cmap``, or any callable mapping a float
            ``t`` in ``[0, 1]`` to an ``(r, g, b, a)`` sequence. A
            Matplotlib ``Colormap`` instance is such a callable.
        n_negative: Number of control points below ``vzero`` (>= 1). The
            first one is at ``vmin``.
        n_positive: Number of control points above ``vzero`` (>= 1). The
            last one is at ``vmax``.
        zero_color: RGB triple used at ``vzero``.
        color_space: Value written to the preset's ``"ColorSpace"`` field.
        nan_color: RGB triple written to ``"NanColor"``. It defaults to the
            previously hard-coded green.

    Returns:
        The ``pathlib.Path`` of the written file.

    Raises:
        ValueError: If the range ordering is wrong, a point count is below
            1, or a colour is not an RGB triple.
    """
    if not (vmin < vzero < vmax):
        raise ValueError("Require vmin < vzero < vmax")
    if n_negative < 1 or n_positive < 1:
        # With zero points on a side, vmin or vmax would silently disappear
        # from the colormap and ParaView's range would shrink.
        raise ValueError("Require n_negative >= 1 and n_positive >= 1")

    zero_rgb = [float(c) for c in zero_color]
    nan_rgb = [float(c) for c in nan_color]
    if len(zero_rgb) != 3 or len(nan_rgb) != 3:
        # An RGBA colour here would shift every later RGBPoints entry.
        raise ValueError("zero_color and nan_color must be RGB triples")

    # Names are resolved through Matplotlib; ready-made callables are used
    # as-is.
    neg_cmap = (
        negative_cmap if callable(negative_cmap) else plt.get_cmap(negative_cmap)
    )
    pos_cmap = (
        positive_cmap if callable(positive_cmap) else plt.get_cmap(positive_cmap)
    )

    # Negative side: n_negative points from vmin up to (excluding) vzero.
    neg_values = np.linspace(vmin, vzero, n_negative, endpoint=False)
    neg_ts = np.linspace(0.0, 1.0, n_negative, endpoint=False)

    # Positive side: n_positive points after vzero, up to and including vmax.
    # Values and t come from the same (n+1)-point grid, with the vzero sample
    # dropped. This keeps t proportional to the position within
    # [vzero, vmax], mirroring the negative side. As a result vmax always
    # maps to t=1, even when n_positive == 1.
    pos_values = np.linspace(vzero, vmax, n_positive + 1, endpoint=True)[1:]
    pos_ts = np.linspace(0.0, 1.0, n_positive + 1, endpoint=True)[1:]

    rgb_points = (
        _sample_rgb_points(neg_cmap, neg_values, neg_ts)
        + [float(vzero), *zero_rgb]
        + _sample_rgb_points(pos_cmap, pos_values, pos_ts)
    )

    colormap = [
        {
            "ColorSpace": color_space,
            "Name": name,
            "NanColor": nan_rgb,
            "RGBPoints": rgb_points,
        }
    ]

    output_file = Path(output_file)
    with output_file.open("w", encoding="utf-8") as f:
        json.dump(colormap, f, indent=4)

    print(f"Written: {output_file.resolve()}")
    return output_file


if __name__ == "__main__":
    # Example run: a range skewed towards negative values (-60 .. 20), split
    # at zero. The longer negative side gets more control points than the
    # positive side.
    example_settings = dict(
        output_file="asymmetric_mpl_colormap.json",
        vmin=-60.0,
        vzero=0.0,
        vmax=20.0,
        negative_cmap="Blues_r",
        positive_cmap="YlOrRd",
        n_negative=20,
        n_positive=5,
    )
    export_asymmetric_matplotlib_cmap_to_paraview(**example_settings)