package sharin.unlinq.iterable;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Queue;

import sharin.unlinq.Func;
import sharin.unlinq.Func2;
import sharin.unlinq.Pair;
import sharin.unlinq.QueuedIterator;

/**
 * An inner equi-join of two sequences, in the style of LINQ's {@code Join}.
 * <p>
 * For each element of {@code outer}, in iteration order, the result contains
 * one element for every element of {@code inner} whose key is equal
 * ({@code equals}/{@code hashCode}) to the outer element's key, in
 * {@code inner}'s iteration order. Outer elements with no matching inner
 * element contribute nothing.
 *
 * @param <O> outer element type
 * @param <I> inner element type
 * @param <K> join key type
 * @param <R> result element type
 */
public class JoinIterable<O, I, K, R> implements Iterable<R> {

    private final Iterable<O> outer;

    private final Iterable<I> inner;

    private final Func<O, K> outerKeySelector;

    private final Func<I, K> innerKeySelector;

    private final Func2<O, I, R> resultSelector;

    /**
     * Creates a join of {@code outer} and {@code inner}.
     *
     * @param outer the sequence that drives the result order
     * @param inner the sequence searched for matches; it is iterated in full
     *            each time {@link #iterator()} is called
     * @param outerKeySelector extracts the join key from an outer element
     * @param innerKeySelector extracts the join key from an inner element
     * @param resultSelector combines a matching outer/inner pair into a result
     */
    public JoinIterable(Iterable<O> outer, Iterable<I> inner,
            Func<O, K> outerKeySelector, Func<I, K> innerKeySelector,
            Func2<O, I, R> resultSelector) {

        this.outer = outer;
        this.inner = inner;
        this.outerKeySelector = outerKeySelector;
        this.innerKeySelector = innerKeySelector;
        this.resultSelector = resultSelector;
    }

    /**
     * Returns a new iterator over the joined results.
     * <p>
     * {@code inner} is iterated eagerly, once per call, to build a key lookup.
     * {@code outer} is handed to {@code QueuedIterator}, which presumably pulls
     * from it as the returned iterator advances.
     */
    public Iterator<R> iterator() {
        // Built here rather than in an instance initializer of the anonymous
        // class: initializers run only after the superclass constructor, so a
        // QueuedIterator constructor that called addElement would see null.
        final Map<K, List<I>> innerLookup = buildInnerLookup();

        return new QueuedIterator<O, Pair<O, I>, R>(outer.iterator()) {

            @Override
            protected void addElement(Queue<Pair<O, I>> queue, O t) {
                K outerKey = outerKeySelector.call(t);
                List<I> matches = innerLookup.get(outerKey);

                if (matches != null) {
                    // Emit every inner match, not just one, so duplicate inner
                    // keys each produce their own result.
                    for (I i : matches) {
                        queue.add(new Pair<O, I>(t, i));
                    }
                }
            }

            @Override
            protected R toResult(Pair<O, I> e) {
                return resultSelector.call(e.first, e.second);
            }
        };
    }

    /**
     * Groups the elements of {@code inner} by their key.
     *
     * @return a map from key to every inner element with that key, each list
     *         in {@code inner}'s iteration order
     */
    private Map<K, List<I>> buildInnerLookup() {
        Map<K, List<I>> lookup = new HashMap<K, List<I>>();

        for (I i : inner) {
            K innerKey = innerKeySelector.call(i);
            List<I> group = lookup.get(innerKey);

            if (group == null) {
                group = new ArrayList<I>();
                lookup.put(innerKey, group);
            }

            group.add(i);
        }

        return lookup;
    }
}
