/*
 * Decompiled with CFR 0.152.
 */
package io.activej.dataflow.dataset.impl;

import io.activej.dataflow.dataset.Dataset;
import io.activej.dataflow.dataset.DatasetUtils;
import io.activej.dataflow.graph.DataflowContext;
import io.activej.dataflow.graph.DataflowGraph;
import io.activej.dataflow.graph.Partition;
import io.activej.dataflow.graph.StreamId;
import io.activej.dataflow.node.NodeReduce;
import io.activej.dataflow.node.NodeReduceSimple;
import io.activej.dataflow.node.NodeShard;
import io.activej.dataflow.node.NodeSort;
import io.activej.datastream.processor.StreamReducers;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.function.Function;

/**
 * A dataset that groups and reduces its input by key in two phases.
 *
 * <p>On every partition holding an input stream, the input is split with a {@link NodeShard}
 * keyed by {@code inputKeyFunction}. For each partition returned by
 * {@code graph.getAvailablePartitions()}, every shard output is sorted ({@link NodeSort}) and
 * pre-reduced into accumulators ({@link NodeReduceSimple} with
 * {@code reducer.inputToAccumulator()}) on the shard's own partition. The accumulators are then
 * forwarded to the target partition and merged there by a {@link NodeReduce} using
 * {@code reducer.accumulatorToOutput()}. One output stream is produced per available partition.
 *
 * @param <K> key type
 * @param <I> input item type
 * @param <O> output (result) item type
 * @param <A> accumulator type
 */
public final class DatasetSplitSortReduceRepartitionReduce<K, I, O, A>
extends Dataset<O> {
    /** Dataset whose streams are sharded, sorted and reduced. */
    private final Dataset<I> input;
    /** Extracts the key from an input item (used for sharding, sorting and pre-reduction). */
    private final Function<I, K> inputKeyFunction;
    /** Extracts the key from an accumulator (used by the final reduce). */
    private final Function<A, K> accumulatorKeyFunction;
    /** Ordering of keys used by the sort and both reduce stages. */
    private final Comparator<K> keyComparator;
    /** Supplies both the input-to-accumulator and accumulator-to-output reducers. */
    private final StreamReducers.ReducerToResult<K, I, O, A> reducer;
    /** Item type of the forwarded accumulator streams. */
    private final Class<A> accumulatorType;
    /** Buffer size handed to each {@link NodeSort}. */
    private final int sortBufferSize;

    /**
     * Creates the dataset.
     *
     * @param input                  dataset to reduce
     * @param inputKeyFunction       key extractor for input items
     * @param accumulatorKeyFunction key extractor for accumulators
     * @param keyComparator          key ordering
     * @param reducer                reducer providing the input-to-accumulator and
     *                               accumulator-to-output stages
     * @param resultType             item type of this dataset (passed to {@link Dataset})
     * @param accumulatorType        item type of the forwarded accumulator streams
     * @param sortBufferSize         buffer size for each sort node
     */
    public DatasetSplitSortReduceRepartitionReduce(Dataset<I> input, Function<I, K> inputKeyFunction, Function<A, K> accumulatorKeyFunction, Comparator<K> keyComparator, StreamReducers.ReducerToResult<K, I, O, A> reducer, Class<O> resultType, Class<A> accumulatorType, int sortBufferSize) {
        super(resultType);
        this.input = input;
        this.inputKeyFunction = inputKeyFunction;
        this.accumulatorKeyFunction = accumulatorKeyFunction;
        this.keyComparator = keyComparator;
        this.reducer = reducer;
        this.accumulatorType = accumulatorType;
        this.sortBufferSize = sortBufferSize;
    }

    /**
     * Adds the shard / sort / pre-reduce / forward / reduce nodes to the context's graph.
     *
     * @param context dataflow context whose graph receives the new nodes
     * @return one output stream per partition from {@code graph.getAvailablePartitions()},
     *         in that order
     */
    @Override
    public List<StreamId> channels(DataflowContext context) {
        DataflowGraph graph = context.getGraph();
        int nonce = context.getNonce();

        // Stage 1: one sharder per input stream, placed on that stream's partition.
        List<NodeShard<K, I>> sharders = new ArrayList<>();
        int shardIndex = context.generateNodeIndex();
        for (StreamId inputStreamId : this.input.channels(context.withoutFixedNonce())) {
            Partition partition = graph.getPartition(inputStreamId);
            NodeShard<K, I> sharder = new NodeShard<>(shardIndex, this.inputKeyFunction, inputStreamId, nonce);
            graph.addNode(partition, sharder);
            sharders.add(sharder);
        }

        int reduceIndex = context.generateNodeIndex();
        List<Partition> partitions = graph.getAvailablePartitions();
        // The upload index is looked up by the target-partition index i, and the download index
        // by the source-sharder index j, so each array must be sized by the range of the loop
        // variable that indexes it. (The previous sizing was swapped and overflowed whenever the
        // partition count and sharder count differed.) Varying the upload index with i mirrors
        // sortIndex/simpleReduceIndex below, which are also allocated once per target partition.
        int[] uploadIndexes = DatasetUtils.generateIndexes(context, partitions.size());
        int[] downloadIndexes = DatasetUtils.generateIndexes(context, sharders.size());

        List<StreamId> outputStreamIds = new ArrayList<>(partitions.size());
        for (int i = 0; i < partitions.size(); ++i) {
            Partition partition = partitions.get(i);
            // NOTE(review): raw type kept from the decompiled source; NodeReduce's type
            // parameters are not visible in this file.
            NodeReduce nodeReduce = new NodeReduce(reduceIndex, this.keyComparator);
            graph.addNode(partition, nodeReduce);

            // Each source partition gets one sort + pre-reduce pair per target partition,
            // so these indexes are fresh for every i.
            int sortIndex = context.generateNodeIndex();
            int simpleReduceIndex = context.generateNodeIndex();
            for (int j = 0; j < sharders.size(); ++j) {
                NodeShard<K, I> sharder = sharders.get(j);
                StreamId sharderOutput = sharder.newPartition();
                graph.addNodeStream(sharder, sharderOutput);
                StreamId reducerInput = this.sortReduceForward(context, sharderOutput, partition,
                        sortIndex, simpleReduceIndex, uploadIndexes[i], downloadIndexes[j]);
                nodeReduce.addInput(reducerInput, this.accumulatorKeyFunction, this.reducer.accumulatorToOutput());
            }
            outputStreamIds.add(nodeReduce.getOutput());
        }
        return outputStreamIds;
    }

    /**
     * Sorts one shard output and pre-reduces it into accumulators on the shard's own partition,
     * then forwards the accumulator stream to {@code targetPartition}.
     *
     * @param context         dataflow context
     * @param sourceStreamId  shard output stream to process
     * @param targetPartition partition that will consume the accumulators
     * @param sortIndex       node index for the {@link NodeSort}
     * @param simpleReduceIndex node index for the {@link NodeReduceSimple}
     * @param uploadIndex     upload index passed to {@link DatasetUtils#forwardChannel}
     * @param downloadIndex   download index passed to {@link DatasetUtils#forwardChannel}
     * @return the stream id returned by {@link DatasetUtils#forwardChannel}
     */
    private StreamId sortReduceForward(DataflowContext context, StreamId sourceStreamId, Partition targetPartition, int sortIndex, int simpleReduceIndex, int uploadIndex, int downloadIndex) {
        DataflowGraph graph = context.getGraph();
        Partition sourcePartition = graph.getPartition(sourceStreamId);
        // NOTE(review): the meaning of the boolean `false` argument is defined by NodeSort — confirm.
        NodeSort<K, I> nodeSort = new NodeSort<>(sortIndex, this.input.valueType(), this.inputKeyFunction, this.keyComparator, false, this.sortBufferSize, sourceStreamId);
        graph.addNode(sourcePartition, nodeSort);
        // NOTE(review): raw type kept from the decompiled source; NodeReduceSimple's type
        // parameters are not visible in this file.
        NodeReduceSimple nodeReduce = new NodeReduceSimple(simpleReduceIndex, this.inputKeyFunction, this.keyComparator, this.reducer.inputToAccumulator());
        nodeReduce.addInput(nodeSort.getOutput());
        graph.addNode(sourcePartition, nodeReduce);
        return DatasetUtils.forwardChannel(context, this.accumulatorType, nodeReduce.getOutput(), targetPartition, uploadIndex, downloadIndex);
    }

    /**
     * Returns the datasets this one is derived from.
     *
     * @return a singleton list containing the input dataset
     */
    @Override
    public Collection<Dataset<?>> getBases() {
        return Collections.singletonList(this.input);
    }
}

