import numpy as np

def write_vtu(points, cells, velocities, filename="layered_grid_with_flow.vtu", cell_types=None):
    """Write an unstructured grid with a per-point velocity field as an ASCII VTK XML (.vtu) file.

    Parameters
    ----------
    points : sequence of (x, y, z)
        Node coordinates; row ``i`` is node ``i`` referenced by ``cells``.
    cells : sequence of sequences of int
        Node indices of each cell, already in VTK node order for its cell type.
    velocities : sequence of (vx, vy, vz)
        One vector per point, written as the ``Velocity`` point-data array.
    filename : str or path-like
        Output path; an existing file is overwritten.
    cell_types : int or sequence of int, optional
        VTK cell type id for every cell (a single int) or one id per cell.
        Defaults to 12 (VTK_HEXAHEDRON) for all cells, the historical behaviour.

    Raises
    ------
    ValueError
        If ``velocities`` does not have one entry per point, or ``cell_types``
        does not have one entry per cell.
    """
    import numbers
    from itertools import accumulate

    vtk_hexahedron = 12  # VTK cell type id of the 8-node linear hexahedron

    n_points, n_cells = len(points), len(cells)

    # Validate before opening the file so a bad call cannot truncate an existing output.
    if len(velocities) != n_points:
        raise ValueError(
            f"velocities has {len(velocities)} entries but points has {n_points}"
        )
    if cell_types is None:
        cell_types = [vtk_hexahedron] * n_cells
    elif isinstance(cell_types, numbers.Integral):
        cell_types = [cell_types] * n_cells
    else:
        cell_types = list(cell_types)
        if len(cell_types) != n_cells:
            raise ValueError(
                f"cell_types has {len(cell_types)} entries but cells has {n_cells}"
            )

    with open(filename, "w", encoding="utf-8") as f:
        f.write('<?xml version="1.0"?>\n')
        f.write('<VTKFile type="UnstructuredGrid" version="0.1" byte_order="LittleEndian">\n')
        f.write('  <UnstructuredGrid>\n')
        f.write(f'    <Piece NumberOfPoints="{n_points}" NumberOfCells="{n_cells}">\n')

        # Point data (Velocity vectors)
        f.write('      <PointData Vectors="Velocity">\n')
        f.write('        <DataArray type="Float32" Name="Velocity" NumberOfComponents="3" format="ascii">\n')
        f.writelines(f"{v[0]} {v[1]} {v[2]}\n" for v in velocities)
        f.write('        </DataArray>\n')
        f.write('      </PointData>\n')

        # Points
        f.write('      <Points>\n')
        f.write('        <DataArray type="Float32" NumberOfComponents="3" format="ascii">\n')
        f.writelines(f"{p[0]} {p[1]} {p[2]}\n" for p in points)
        f.write('        </DataArray>\n')
        f.write('      </Points>\n')

        # Cells: flat connectivity, then the running end-offset of each cell into it.
        f.write('      <Cells>\n')
        f.write('        <DataArray type="Int32" Name="connectivity" format="ascii">\n')
        f.writelines(" ".join(map(str, c)) + "\n" for c in cells)
        f.write('        </DataArray>\n')

        f.write('        <DataArray type="Int32" Name="offsets" format="ascii">\n')
        f.writelines(f"{offset}\n" for offset in accumulate(len(c) for c in cells))
        f.write('        </DataArray>\n')

        f.write('        <DataArray type="UInt8" Name="types" format="ascii">\n')
        f.writelines(f"{t}\n" for t in cell_types)
        f.write('        </DataArray>\n')
        f.write('      </Cells>\n')

        f.write('    </Piece>\n')
        f.write('  </UnstructuredGrid>\n')
        f.write('</VTKFile>\n')

# Parameters
nx, ny = 25, 25  # base in-plane cell counts, multiplied by each layer's resolution factor
nz1 = 25  # number of cells through the thickness of each layer
resolutions = [4, 2, 1]  # in-plane refinement factor per layer; each next layer half resolution
layers = len(resolutions)  # derived so it can never disagree with `resolutions`
thicknesses = [nz1] * layers  # cells in z per layer (each layer spans one unit of z)

points = []
cells = []
velocities = []

z_start = 0
domain_x, domain_y = 1.0, 1.0
xc, yc = domain_x / 2, domain_y / 2
r0 = 0.2  # cylinder radius (in normalized units)
vmin, vmax = 1.0, 5.0

# Layers share no nodes: each layer generates its own points, so the interface plane
# is duplicated and, with differing in-plane resolution, non-conforming there.
for res, nz_layer in zip(resolutions, thicknesses):
    nx_layer, ny_layer = nx * res, ny * res

    x = np.linspace(0, domain_x, nx_layer + 1)
    y = np.linspace(0, domain_y, ny_layer + 1)
    z = np.linspace(z_start, z_start + 1, nz_layer + 1)

    # "ij" indexing + C-order ravel => local node id = (ix * (ny+1) + iy) * (nz+1) + iz.
    grid_x, grid_y, grid_z = np.meshgrid(x, y, z, indexing="ij")
    coords = np.column_stack([grid_x.ravel(), grid_y.ravel(), grid_z.ravel()])

    start_index = len(points)
    points.extend(coords.tolist())

    # Velocity: purely along z, vmax inside the cylinder r < r0 around (xc, yc), vmin outside.
    r = np.sqrt((coords[:, 0] - xc) ** 2 + (coords[:, 1] - yc) ** 2)
    vz = np.where(r < r0, vmax, vmin)
    zeros = np.zeros_like(vz)
    velocities.extend(np.column_stack([zeros, zeros, vz]).tolist())

    # Build hexahedral cells in VTK_HEXAHEDRON node order: nodes 0-3 are the bottom (iz)
    # face, counter-clockwise seen from +z, and nodes 4-7 the matching top (iz+1) face.
    # This gives (p1-p0) x (p3-p0) pointing toward p4, i.e. positive cell volume.
    node_ids = start_index + np.arange(len(coords)).reshape(
        nx_layer + 1, ny_layer + 1, nz_layer + 1
    )
    corners = (
        node_ids[:-1, :-1, :-1],  # 0: (ix,   iy,   iz)
        node_ids[1:, :-1, :-1],   # 1: (ix+1, iy,   iz)
        node_ids[1:, 1:, :-1],    # 2: (ix+1, iy+1, iz)
        node_ids[:-1, 1:, :-1],   # 3: (ix,   iy+1, iz)
        node_ids[:-1, :-1, 1:],   # 4: (ix,   iy,   iz+1)
        node_ids[1:, :-1, 1:],    # 5: (ix+1, iy,   iz+1)
        node_ids[1:, 1:, 1:],     # 6: (ix+1, iy+1, iz+1)
        node_ids[:-1, 1:, 1:],    # 7: (ix,   iy+1, iz+1)
    )
    # C-order ravel keeps the original cell order: ix outermost, iz innermost.
    cells.extend(np.stack([c.ravel() for c in corners], axis=1).tolist())

    z_start += 1

points = np.array(points)
write_vtu(points, cells, velocities, "layered_grid_with_flow.vtu")

