// Copyright 2020 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package impl

import (
	"fmt"
	"reflect"

	"google.golang.org/protobuf/proto"
	"google.golang.org/protobuf/reflect/protoreflect"
	"google.golang.org/protobuf/runtime/protoiface"
)

// mergeOptions is the options value threaded through the merge functions in
// this file. It has no fields.
type mergeOptions struct{}

// Merge merges src into dst by delegating to proto.Merge.
func (mergeOptions) Merge(dst, src proto.Message) {
	proto.Merge(dst, src)
}

// merge is protoreflect.Methods.Merge.
//
// It resolves in.Destination and in.Source to pointers via mi.getPointer and
// merges the source into the destination with mergePointer. If either
// message cannot be resolved, the zero MergeOutput is returned (MergeComplete
// is not set) and nothing is merged.
func (mi *MessageInfo) merge(in protoiface.MergeInput) protoiface.MergeOutput {
	var out protoiface.MergeOutput

	dstPtr, dstOK := mi.getPointer(in.Destination)
	if !dstOK {
		return out
	}
	srcPtr, srcOK := mi.getPointer(in.Source)
	if !srcOK {
		return out
	}

	mi.mergePointer(dstPtr, srcPtr, mergeOptions{})
	out.Flags = protoiface.MergeComplete
	return out
}

// mergePointer merges the message at src into the message at dst, both of
// which are described by mi.
//
// It panics if dst is nil and does nothing if src is nil. Otherwise it merges,
// in this order:
//
//  1. every field in mi.orderedCoderFields that has a merge function,
//  2. extension fields, when mi.extensionOffset is valid,
//  3. unknown bytes, when mi.unknownOffset is valid (source bytes are
//     appended to the destination's).
func (mi *MessageInfo) mergePointer(dst, src pointer, opts mergeOptions) {
	mi.init()
	switch {
	case dst.IsNil():
		panic(fmt.Sprintf("invalid value: merging into nil message"))
	case src.IsNil():
		return
	}

	// Known fields.
	for _, field := range mi.orderedCoderFields {
		if field.funcs.merge == nil {
			continue
		}
		srcField := src.Apply(field.offset)
		// A pointer-typed field whose source pointer is nil has nothing to merge.
		if field.isPointer && srcField.Elem().IsNil() {
			continue
		}
		field.funcs.merge(dst.Apply(field.offset), srcField, field, opts)
	}

	// Extension fields.
	if mi.extensionOffset.IsValid() {
		srcExts := src.Apply(mi.extensionOffset).Extensions()
		dstExts := dst.Apply(mi.extensionOffset).Extensions()
		if *dstExts == nil {
			*dstExts = make(map[int32]ExtensionField)
		}
		for fieldNum, srcExt := range *srcExts {
			extType := srcExt.Type()
			extInfo := getExtensionFieldInfo(extType)
			if extInfo.funcs.merge == nil {
				continue
			}
			dstExt := (*dstExts)[fieldNum]
			// Start from the destination's current value only when it holds
			// the same extension type as the source.
			var merged protoreflect.Value
			if dstExt.Type() == srcExt.Type() {
				merged = dstExt.Value()
			}
			if !merged.IsValid() && extInfo.unmarshalNeedsValue {
				merged = extType.New()
			}
			merged = extInfo.funcs.merge(merged, srcExt.Value(), opts)
			dstExt.Set(srcExt.Type(), merged)
			(*dstExts)[fieldNum] = dstExt
		}
	}

	// Unknown fields.
	if mi.unknownOffset.IsValid() {
		if srcUnknown := mi.getUnknownBytes(src); srcUnknown != nil && len(*srcUnknown) > 0 {
			dstUnknown := mi.mutableUnknownBytes(dst)
			*dstUnknown = append(*dstUnknown, *srcUnknown...)
		}
	}
}

// mergeScalarValue returns src unchanged, so the source value replaces
// whatever the destination held. dst and opts are not used.
func mergeScalarValue(dst, src protoreflect.Value, opts mergeOptions) protoreflect.Value {
	return src
}

// mergeBytesValue returns a bytes value holding src's bytes appended onto
// emptyBuf[:]. dst and opts are not used.
//
// NOTE(review): emptyBuf is declared elsewhere; presumably it is a
// zero-length array so that the result never aliases src — confirm.
func mergeBytesValue(dst, src protoreflect.Value, opts mergeOptions) protoreflect.Value {
	copied := append(emptyBuf[:], src.Bytes()...)
	return protoreflect.ValueOfBytes(copied)
}

// mergeListValue appends every element of the src list to the dst list and
// returns dst. Elements are appended as returned by Get, without copying.
// opts is not used.
func mergeListValue(dst, src protoreflect.Value, opts mergeOptions) protoreflect.Value {
	target, source := dst.List(), src.List()
	// The source length is read once, before any element is appended.
	count := source.Len()
	for idx := 0; idx < count; idx++ {
		target.Append(source.Get(idx))
	}
	return dst
}

// mergeBytesListValue appends every element of the src bytes list to the dst
// list and returns dst. Each element's bytes are first appended onto
// emptyBuf[:], and the resulting slice is what gets stored. opts is not used.
func mergeBytesListValue(dst, src protoreflect.Value, opts mergeOptions) protoreflect.Value {
	target, source := dst.List(), src.List()
	// The source length is read once, before any element is appended.
	count := source.Len()
	for idx := 0; idx < count; idx++ {
		copied := append(emptyBuf[:], source.Get(idx).Bytes()...)
		target.Append(protoreflect.ValueOfBytes(copied))
	}
	return dst
}

// mergeMessageListValue appends a clone of every message in the src list to
// the dst list and returns dst. Clones are made with proto.Clone; opts is
// not used.
func mergeMessageListValue(dst, src protoreflect.Value, opts mergeOptions) protoreflect.Value {
	target, source := dst.List(), src.List()
	// The source length is read once, before any element is appended.
	count := source.Len()
	for idx := 0; idx < count; idx++ {
		original := source.Get(idx).Message()
		clone := proto.Clone(original.Interface()).ProtoReflect()
		target.Append(protoreflect.ValueOfMessage(clone))
	}
	return dst
}

// mergeMessageValue merges the message held by src into the message held by
// dst using opts.Merge, and returns dst.
func mergeMessageValue(dst, src protoreflect.Value, opts mergeOptions) protoreflect.Value {
	dstMsg := dst.Message().Interface()
	srcMsg := src.Message().Interface()
	opts.Merge(dstMsg, srcMsg)
	return dst
}

// mergeMessage merges the message field at src into the message field at dst,
// allocating the destination message first if its pointer is nil.
//
// When f.mi is nil the merge goes through reflection on f.ft and opts.Merge;
// otherwise it goes through f.mi.mergePointer.
func mergeMessage(dst, src pointer, f *coderFieldInfo, opts mergeOptions) {
	if f.mi == nil {
		dstVal := dst.AsValueOf(f.ft).Elem()
		srcVal := src.AsValueOf(f.ft).Elem()
		if dstVal.IsNil() {
			dstVal.Set(reflect.New(f.ft.Elem()))
		}
		opts.Merge(asMessage(dstVal), asMessage(srcVal))
		return
	}

	if dst.Elem().IsNil() {
		fresh := reflect.New(f.mi.GoReflectType.Elem())
		dst.SetPointer(pointerOfValue(fresh))
	}
	f.mi.mergePointer(dst.Elem(), src.Elem(), opts)
}

// mergeMessageSlice appends to dst one newly allocated message per element of
// src, each populated by merging the corresponding source element into it.
//
// The merge uses opts.Merge when f.mi is nil and f.mi.mergePointer otherwise.
func mergeMessageSlice(dst, src pointer, f *coderFieldInfo, opts mergeOptions) {
	for _, srcElem := range src.PointerSlice() {
		// f.ft.Elem().Elem() is the type the new message is allocated as.
		elemType := f.ft.Elem().Elem()
		fresh := reflect.New(elemType)
		if f.mi == nil {
			opts.Merge(asMessage(fresh), asMessage(srcElem.AsValueOf(elemType)))
		} else {
			f.mi.mergePointer(pointerOfValue(fresh), srcElem, opts)
		}
		dst.AppendPointerSlice(pointerOfValue(fresh))
	}
}

// mergeBytes overwrites the bytes field at dst with the source bytes appended
// onto emptyBuf[:]. The assignment happens even when the source is empty.
func mergeBytes(dst, src pointer, _ *coderFieldInfo, _ mergeOptions) {
	out := dst.Bytes()
	*out = append(emptyBuf[:], *src.Bytes()...)
}

// mergeBytesNoZero overwrites the bytes field at dst with the source bytes
// appended onto emptyBuf[:], but only when the source has at least one byte.
// An empty source leaves dst untouched.
func mergeBytesNoZero(dst, src pointer, _ *coderFieldInfo, _ mergeOptions) {
	srcBytes := *src.Bytes()
	if len(srcBytes) == 0 {
		return
	}
	*dst.Bytes() = append(emptyBuf[:], srcBytes...)
}

// mergeBytesSlice appends every element of the source bytes slice to the
// destination bytes slice. Each element's bytes are first appended onto
// emptyBuf[:], and the resulting slice is what gets stored.
func mergeBytesSlice(dst, src pointer, _ *coderFieldInfo, _ mergeOptions) {
	out := dst.BytesSlice()
	for _, elem := range *src.BytesSlice() {
		copied := append(emptyBuf[:], elem...)
		*out = append(*out, copied)
	}
}