// Copyright 2012 Google Inc. All rights reserved. // Use of this source code is governed by the Apache 2.0 // license that can be found in the LICENSE file. //go:build appengine // +build appengine package socket import ( "context" "fmt" "io" "net" "strconv" "time" "github.com/golang/protobuf/proto" "google.golang.org/appengine/internal" pb "google.golang.org/appengine/internal/socket" ) // Dial connects to the address addr on the network protocol. // The address format is host:port, where host may be a hostname or an IP address. // Known protocols are "tcp" and "udp". // The returned connection satisfies net.Conn, and is valid while ctx is valid; // if the connection is to be used after ctx becomes invalid, invoke SetContext // with the new context. func Dial(ctx context.Context, protocol, addr string) (*Conn, error) { return DialTimeout(ctx, protocol, addr, 0) } var ipFamilies = []pb.CreateSocketRequest_SocketFamily{ pb.CreateSocketRequest_IPv4, pb.CreateSocketRequest_IPv6, } // DialTimeout is like Dial but takes a timeout. // The timeout includes name resolution, if required. func DialTimeout(ctx context.Context, protocol, addr string, timeout time.Duration) (*Conn, error) { dialCtx := ctx // Used for dialing and name resolution, but not stored in the *Conn. if timeout > 0 { var cancel context.CancelFunc dialCtx, cancel = context.WithTimeout(ctx, timeout) defer cancel() } host, portStr, err := net.SplitHostPort(addr) if err != nil { return nil, err } port, err := strconv.Atoi(portStr) if err != nil { return nil, fmt.Errorf("socket: bad port %q: %v", portStr, err) } var prot pb.CreateSocketRequest_SocketProtocol switch protocol { case "tcp": prot = pb.CreateSocketRequest_TCP case "udp": prot = pb.CreateSocketRequest_UDP default: return nil, fmt.Errorf("socket: unknown protocol %q", protocol) } packedAddrs, resolved, err := resolve(dialCtx, ipFamilies, host) if err != nil { return nil, fmt.Errorf("socket: failed resolving %q: %v", host, err) } if len(packedAddrs) == 0 { return nil, fmt.Errorf("no addresses for %q", host) } packedAddr := packedAddrs[0] // use first address fam := pb.CreateSocketRequest_IPv4 if len(packedAddr) == net.IPv6len { fam = pb.CreateSocketRequest_IPv6 } req := &pb.CreateSocketRequest{ Family: fam.Enum(), Protocol: prot.Enum(), RemoteIp: &pb.AddressPort{ Port: proto.Int32(int32(port)), PackedAddress: packedAddr, }, } if resolved { req.RemoteIp.HostnameHint = &host } res := &pb.CreateSocketReply{} if err := internal.Call(dialCtx, "remote_socket", "CreateSocket", req, res); err != nil { return nil, err } return &Conn{ ctx: ctx, desc: res.GetSocketDescriptor(), prot: prot, local: res.ProxyExternalIp, remote: req.RemoteIp, }, nil } // LookupIP returns the given host's IP addresses. func LookupIP(ctx context.Context, host string) (addrs []net.IP, err error) { packedAddrs, _, err := resolve(ctx, ipFamilies, host) if err != nil { return nil, fmt.Errorf("socket: failed resolving %q: %v", host, err) } addrs = make([]net.IP, len(packedAddrs)) for i, pa := range packedAddrs { addrs[i] = net.IP(pa) } return addrs, nil } func resolve(ctx context.Context, fams []pb.CreateSocketRequest_SocketFamily, host string) ([][]byte, bool, error) { // Check if it's an IP address. if ip := net.ParseIP(host); ip != nil { if ip := ip.To4(); ip != nil { return [][]byte{ip}, false, nil } return [][]byte{ip}, false, nil } req := &pb.ResolveRequest{ Name: &host, AddressFamilies: fams, } res := &pb.ResolveReply{} if err := internal.Call(ctx, "remote_socket", "Resolve", req, res); err != nil { // XXX: need to map to pb.ResolveReply_ErrorCode? return nil, false, err } return res.PackedAddress, true, nil } // withDeadline is like context.WithDeadline, except it ignores the zero deadline. func withDeadline(parent context.Context, deadline time.Time) (context.Context, context.CancelFunc) { if deadline.IsZero() { return parent, func() {} } return context.WithDeadline(parent, deadline) } // Conn represents a socket connection. // It implements net.Conn. type Conn struct { ctx context.Context desc string offset int64 prot pb.CreateSocketRequest_SocketProtocol local, remote *pb.AddressPort readDeadline, writeDeadline time.Time // optional } // SetContext sets the context that is used by this Conn. // It is usually used only when using a Conn that was created in a different context, // such as when a connection is created during a warmup request but used while // servicing a user request. func (cn *Conn) SetContext(ctx context.Context) { cn.ctx = ctx } func (cn *Conn) Read(b []byte) (n int, err error) { const maxRead = 1 << 20 if len(b) > maxRead { b = b[:maxRead] } req := &pb.ReceiveRequest{ SocketDescriptor: &cn.desc, DataSize: proto.Int32(int32(len(b))), } res := &pb.ReceiveReply{} if !cn.readDeadline.IsZero() { req.TimeoutSeconds = proto.Float64(cn.readDeadline.Sub(time.Now()).Seconds()) } ctx, cancel := withDeadline(cn.ctx, cn.readDeadline) defer cancel() if err := internal.Call(ctx, "remote_socket", "Receive", req, res); err != nil { return 0, err } if len(res.Data) == 0 { return 0, io.EOF } if len(res.Data) > len(b) { return 0, fmt.Errorf("socket: internal error: read too much data: %d > %d", len(res.Data), len(b)) } return copy(b, res.Data), nil } func (cn *Conn) Write(b []byte) (n int, err error) { const lim = 1 << 20 // max per chunk for n < len(b) { chunk := b[n:] if len(chunk) > lim { chunk = chunk[:lim] } req := &pb.SendRequest{ SocketDescriptor: &cn.desc, Data: chunk, StreamOffset: &cn.offset, } res := &pb.SendReply{} if !cn.writeDeadline.IsZero() { req.TimeoutSeconds = proto.Float64(cn.writeDeadline.Sub(time.Now()).Seconds()) } ctx, cancel := withDeadline(cn.ctx, cn.writeDeadline) defer cancel() if err = internal.Call(ctx, "remote_socket", "Send", req, res); err != nil { // assume zero bytes were sent in this RPC break } n += int(res.GetDataSent()) cn.offset += int64(res.GetDataSent()) } return } func (cn *Conn) Close() error { req := &pb.CloseRequest{ SocketDescriptor: &cn.desc, } res := &pb.CloseReply{} if err := internal.Call(cn.ctx, "remote_socket", "Close", req, res); err != nil { return err } cn.desc = "CLOSED" return nil } func addr(prot pb.CreateSocketRequest_SocketProtocol, ap *pb.AddressPort) net.Addr { if ap == nil { return nil } switch prot { case pb.CreateSocketRequest_TCP: return &net.TCPAddr{ IP: net.IP(ap.PackedAddress), Port: int(*ap.Port), } case pb.CreateSocketRequest_UDP: return &net.UDPAddr{ IP: net.IP(ap.PackedAddress), Port: int(*ap.Port), } } panic("unknown protocol " + prot.String()) } func (cn *Conn) LocalAddr() net.Addr { return addr(cn.prot, cn.local) } func (cn *Conn) RemoteAddr() net.Addr { return addr(cn.prot, cn.remote) } func (cn *Conn) SetDeadline(t time.Time) error { cn.readDeadline = t cn.writeDeadline = t return nil } func (cn *Conn) SetReadDeadline(t time.Time) error { cn.readDeadline = t return nil } func (cn *Conn) SetWriteDeadline(t time.Time) error { cn.writeDeadline = t return nil } // KeepAlive signals that the connection is still in use. // It may be called to prevent the socket being closed due to inactivity. func (cn *Conn) KeepAlive() error { req := &pb.GetSocketNameRequest{ SocketDescriptor: &cn.desc, } res := &pb.GetSocketNameReply{} return internal.Call(cn.ctx, "remote_socket", "GetSocketName", req, res) } func init() { internal.RegisterErrorCodeMap("remote_socket", pb.RemoteSocketServiceError_ErrorCode_name) }