package mergeset

import (
	"errors"
	"fmt"
	"math/rand"
	"reflect"
	"sort"
	"sync/atomic"
	"testing"
	"time"
)

func TestMergeBlockStreams(t *testing.T) {
	for _, blocksToMerge := range []int{1, 2, 3, 4, 5, 10, 20} {
		t.Run(fmt.Sprintf("blocks-%d", blocksToMerge), func(t *testing.T) {
			for _, maxItemsPerBlock := range []int{1, 2, 10, 100, 1000, 10000} {
				t.Run(fmt.Sprintf("maxItemsPerBlock-%d", maxItemsPerBlock), func(t *testing.T) {
					testMergeBlockStreams(t, blocksToMerge, maxItemsPerBlock)
				})
			}
		})
	}
}

func TestMultilevelMerge(t *testing.T) {
	r := rand.New(rand.NewSource(1))

	// Prepare blocks to merge.
	bsrs, items := newTestInmemoryBlockStreamReaders(r, 10, 4000)
	var itemsMerged atomic.Uint64

	// First level merge
	var dstIP1 inmemoryPart
	var bsw1 blockStreamWriter
	bsw1.MustInitFromInmemoryPart(&dstIP1, -5)
	if err := mergeBlockStreams(&dstIP1.ph, &bsw1, bsrs[:5], nil, nil, &itemsMerged); err != nil {
		t.Fatalf("cannot merge first level part 1: %s", err)
	}

	var dstIP2 inmemoryPart
	var bsw2 blockStreamWriter
	bsw2.MustInitFromInmemoryPart(&dstIP2, -5)
	if err := mergeBlockStreams(&dstIP2.ph, &bsw2, bsrs[5:], nil, nil, &itemsMerged); err != nil {
		t.Fatalf("cannot merge first level part 2: %s", err)
	}

	if n := itemsMerged.Load(); n != uint64(len(items)) {
		t.Fatalf("unexpected itemsMerged; got %d; want %d", n, len(items))
	}

	// Second level merge (aka final merge)
	itemsMerged.Store(0)
	var dstIP inmemoryPart
	var bsw blockStreamWriter
	bsrsTop := []*blockStreamReader{
		newTestBlockStreamReader(&dstIP1),
		newTestBlockStreamReader(&dstIP2),
	}
	bsw.MustInitFromInmemoryPart(&dstIP, 1)
	if err := mergeBlockStreams(&dstIP.ph, &bsw, bsrsTop, nil, nil, &itemsMerged); err != nil {
		t.Fatalf("cannot merge second level: %s", err)
	}
	if n := itemsMerged.Load(); n != uint64(len(items)) {
		t.Fatalf("unexpected itemsMerged after final merge; got %d; want %d", n, len(items))
	}

	// Verify the resulting part (dstIP) contains all the items
	// in the correct order.
	if err := testCheckItems(&dstIP, items); err != nil {
		t.Fatalf("error checking items: %s", err)
	}
}

func TestMergeForciblyStop(t *testing.T) {
	r := rand.New(rand.NewSource(1))
	bsrs, _ := newTestInmemoryBlockStreamReaders(r, 20, 4000)
	var dstIP inmemoryPart
	var bsw blockStreamWriter
	bsw.MustInitFromInmemoryPart(&dstIP, 1)
	ch := make(chan struct{})
	var itemsMerged atomic.Uint64
	close(ch)
	if err := mergeBlockStreams(&dstIP.ph, &bsw, bsrs, nil, ch, &itemsMerged); !errors.Is(err, errForciblyStopped) {
		t.Fatalf("unexpected error during merge: got %v; want %v", err, errForciblyStopped)
	}
	if n := itemsMerged.Load(); n != 0 {
		t.Fatalf("unexpected itemsMerged; got %d; want %d", n, 0)
	}
}

func testMergeBlockStreams(t *testing.T, blocksToMerge, maxItemsPerBlock int) {
	t.Helper()

	r := rand.New(rand.NewSource(1))
	if err := testMergeBlockStreamsSerial(r, blocksToMerge, maxItemsPerBlock); err != nil {
		t.Fatalf("unexpected error in serial test: %s", err)
	}

	const concurrency = 3
	ch := make(chan error, concurrency)
	for i := 0; i < concurrency; i++ {
		go func(n int) {
			rLocal := rand.New(rand.NewSource(int64(n)))
			ch <- testMergeBlockStreamsSerial(rLocal, blocksToMerge, maxItemsPerBlock)
		}(i)
	}

	for i := 0; i < concurrency; i++ {
		select {
		case err := <-ch:
			if err != nil {
				t.Fatalf("unexpected error in concurrent test: %s", err)
			}
		case <-time.After(10 * time.Second):
			t.Fatalf("timeout in concurrent test")
		}
	}
}

func testMergeBlockStreamsSerial(r *rand.Rand, blocksToMerge, maxItemsPerBlock int) error {
	// Prepare blocks to merge.
	bsrs, items := newTestInmemoryBlockStreamReaders(r, blocksToMerge, maxItemsPerBlock)

	// Merge blocks.
	var itemsMerged atomic.Uint64
	var dstIP inmemoryPart
	var bsw blockStreamWriter
	bsw.MustInitFromInmemoryPart(&dstIP, -4)
	if err := mergeBlockStreams(&dstIP.ph, &bsw, bsrs, nil, nil, &itemsMerged); err != nil {
		return fmt.Errorf("cannot merge block streams: %w", err)
	}
	if n := itemsMerged.Load(); n != uint64(len(items)) {
		return fmt.Errorf("unexpected itemsMerged; got %d; want %d", n, len(items))
	}

	// Verify the resulting part (dstIP) contains all the items
	// in the correct order.
	if err := testCheckItems(&dstIP, items); err != nil {
		return fmt.Errorf("error checking items: %w", err)
	}
	return nil
}

func testCheckItems(dstIP *inmemoryPart, items []string) error {
	if int(dstIP.ph.itemsCount) != len(items) {
		return fmt.Errorf("unexpected number of items in the part; got %d; want %d", dstIP.ph.itemsCount, len(items))
	}
	if string(dstIP.ph.firstItem) != string(items[0]) {
		return fmt.Errorf("unexpected first item; got %q; want %q", dstIP.ph.firstItem, items[0])
	}
	if string(dstIP.ph.lastItem) != string(items[len(items)-1]) {
		return fmt.Errorf("unexpected last item; got %q; want %q", dstIP.ph.lastItem, items[len(items)-1])
	}

	var dstItems []string
	dstBsr := newTestBlockStreamReader(dstIP)
	for dstBsr.Next() {
		bh := dstBsr.bh
		if int(bh.itemsCount) != len(dstBsr.Block.items) {
			return fmt.Errorf("unexpected number of items in the block; got %d; want %d", len(dstBsr.Block.items), bh.itemsCount)
		}
		if bh.itemsCount <= 0 {
			return fmt.Errorf("unexpected empty block")
		}
		item := dstBsr.Block.items[0].Bytes(dstBsr.Block.data)
		if string(bh.firstItem) != string(item) {
			return fmt.Errorf("unexpected blockHeader.firstItem; got %q; want %q", bh.firstItem, item)
		}
		for _, it := range dstBsr.Block.items {
			item := it.Bytes(dstBsr.Block.data)
			dstItems = append(dstItems, string(item))
		}
	}
	if err := dstBsr.Error(); err != nil {
		return fmt.Errorf("unexpected error in dstBsr: %w", err)
	}
	if !reflect.DeepEqual(items, dstItems) {
		return fmt.Errorf("unequal items\ngot\n%q\nwant\n%q", dstItems, items)
	}
	return nil
}

func newTestInmemoryBlockStreamReaders(r *rand.Rand, blocksCount, maxItemsPerBlock int) ([]*blockStreamReader, []string) {
	var items []string
	var bsrs []*blockStreamReader
	for i := 0; i < blocksCount; i++ {
		var ib inmemoryBlock
		itemsPerBlock := r.Intn(maxItemsPerBlock) + 1
		for j := 0; j < itemsPerBlock; j++ {
			item := getRandomBytes(r)
			if !ib.Add(item) {
				break
			}
			items = append(items, string(item))
		}
		var ip inmemoryPart
		ip.Init(&ib)
		bsr := newTestBlockStreamReader(&ip)
		bsrs = append(bsrs, bsr)
	}
	sort.Strings(items)
	return bsrs, items
}

func newTestBlockStreamReader(ip *inmemoryPart) *blockStreamReader {
	var bsr blockStreamReader
	bsr.MustInitFromInmemoryPart(ip)
	return &bsr
}