from vtkmodules.vtkCommonDataModel import vtkDataObject
from vtkmodules.vtkCommonDataModel import vtkPartitionedDataSet
from vtkmodules.vtkCommonDataModel import vtkPartitionedDataSetCollection
from vtkmodules.vtkCommonExecutionModel import vtkAlgorithm
from vtkmodules.util.vtkAlgorithm import VTKPythonAlgorithmBase

from paraview.util.vtkAlgorithm import smproxy, smproperty

def _convert_partitions(src: vtkPartitionedDataSetCollection,
                        dest: vtkPartitionedDataSetCollection):
  """Place every partition of ``src`` into its own block of ``dest``.

  Each partition of each partitioned dataset in ``src`` is wrapped in a new
  single-partition ``vtkPartitionedDataSet`` and stored in ``dest``. Order is
  preserved: all partitions of block 0 come first, then those of block 1, etc.

  Args:
    src: collection whose partitions are read. It is not modified, but its
      partition objects are referenced (not copied) by ``dest``.
    dest: collection that receives one partitioned dataset per partition of
      ``src``. Its number of partitioned datasets is set to that count.
  """
  # Gather the partitions in one pass so the output size is known up front
  # and the input does not have to be walked twice.
  partitions = []
  for block_i in range(src.GetNumberOfPartitionedDataSets()):
    block = src.GetPartitionedDataSet(block_i)
    if block is None:
      # An unset slot in the collection contributes no partitions.
      continue
    for block_part_i in range(block.GetNumberOfPartitions()):
      partitions.append(block.GetPartition(block_part_i))

  dest.SetNumberOfPartitionedDataSets(len(partitions))
  for partition_i, data in enumerate(partitions):
    partition = vtkPartitionedDataSet()
    partition.SetNumberOfPartitions(1)
    # The wrapper holds exactly one partition, so it lives at index 0.
    # (Writing index 1 would grow it to two partitions with slot 0 empty.)
    partition.SetPartition(0, data)
    dest.SetPartitionedDataSet(partition_i, partition)

@smproxy.filter(label="Partitions To Blocks")
@smproperty.input(name="Input")
class PartitionsToBlocks(VTKPythonAlgorithmBase):
  """Move every partition of a partitioned dataset collection into its own block.

  Both the input and the output are vtkPartitionedDataSetCollection objects.
  """

  def __init__(self):
    # Base defaults: one input port and one output port.
    super().__init__()

  def FillInputPortInformation(self, port, info):
    """Declare that the input port requires a vtkPartitionedDataSetCollection."""
    required_type_key = vtkAlgorithm.INPUT_REQUIRED_DATA_TYPE()
    info.Set(required_type_key, "vtkPartitionedDataSetCollection")
    return 1

  def FillOutputPortInformation(self, port, info):
    """Declare that the output port produces a vtkPartitionedDataSetCollection."""
    type_name_key = vtkDataObject.DATA_TYPE_NAME()
    info.Set(type_name_key, "vtkPartitionedDataSetCollection")
    return 1

  def RequestData(self, request, inInfo, outInfo):
    """Fill the output collection from the first input connection."""
    source_collection = vtkPartitionedDataSetCollection.GetData(inInfo[0])
    target_collection = vtkPartitionedDataSetCollection.GetData(outInfo)
    _convert_partitions(source_collection, target_collection)
    return 1
