| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416 | // SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>// SPDX-License-Identifier: MITpackage mdnsimport (	"context"	"errors"	"math/big"	"net"	"sync"	"time"	"golang.org/x/net/dns/dnsmessage"	"golang.org/x/net/ipv4")// Conn represents a mDNS Servertype Conn struct {	mu  sync.RWMutex	log Logger	socket  *ipv4.PacketConn	dstAddr *net.UDPAddr	queryInterval time.Duration	localNames    []string	queries       []query	interList     []net.Interface	closed chan interface{}}type query struct {	nameWithSuffix  string	queryResultChan chan queryResult}type queryResult struct {	answer dnsmessage.ResourceHeader	addr   net.Addr}const (	defaultQueryInterval = time.Second	maxMessageRecords    = 3	responseTTL          = 1)const (	NetType = "udp4")var (	Address = &net.UDPAddr{IP: net.IPv4(224, 0, 0, 251), Port: 5353})var errNoPositiveMTUFound = errors.New("no positive MTU found")// Server establishes a mDNS connection over an existing connfunc Server(conn *ipv4.PacketConn, config *Config) (*Conn, error) {	if config == nil {		return nil, errNilConfig	}	interfaces, err := net.Interfaces()	if err != nil {		return nil, err	}	inBufSize := 0	joinErrCount := 0	interList := make([]net.Interface, 0, len(interfaces))	for i, ifc := range interfaces {		if err = conn.JoinGroup(&interfaces[i], Address); err != nil {			joinErrCount++			continue		}		interList = append(interList, ifc)		if interfaces[i].MTU > inBufSize {			inBufSize = interfaces[i].MTU		}	}	if inBufSize == 0 {		return nil, errNoPositiveMTUFound	}	if joinErrCount >= len(interfaces) {		return nil, errJoiningMulticastGroup	}	logs := config.Logger	if logs == nil {		logs = &logger{}	}	var localNames []string	for _, name := range config.LocalNames {		localNames = append(localNames, Fqdn(name))	}	c := &Conn{		queryInterval: defaultQueryInterval,		queries:       []query{},		socket:        conn,		dstAddr:       Address,		localNames:    localNames,		interList:     interList,		log:           logs,		closed:        make(chan interface{}),	}	if config.QueryInterval != 0 {		c.queryInterval = config.QueryInterval	}	if err = conn.SetControlMessage(ipv4.FlagInterface, true); err != nil {		c.log.Println("Failed to SetControlMessage on PacketConn %v", err)	}	// https://www.rfc-editor.org/rfc/rfc6762.html#section-17	// Multicast DNS messages carried by UDP may be up to the IP MTU of the	// physical interface, less the space required for the IP header (20	// bytes for IPv4; 40 bytes for IPv6) and the UDP header (8 bytes).	go c.start(inBufSize-20-8, config)	return c, nil}// Close closes the mDNS Connfunc (c *Conn) Close() error {	select {	case <-c.closed:		return nil	default:	}	if err := c.socket.Close(); err != nil {		return err	}	<-c.closed	return nil}// Query sends mDNS Queries for the following name until// either the Context is canceled/expires or we get a resultfunc (c *Conn) Query(ctx context.Context, name string) (dnsmessage.ResourceHeader, net.Addr, error) {	select {	case <-c.closed:		return dnsmessage.ResourceHeader{}, nil, errConnectionClosed	default:	}	name = Fqdn(name)	queryChan := make(chan queryResult, 1)	c.mu.Lock()	c.queries = append(c.queries, query{name, queryChan})	ticker := time.NewTicker(c.queryInterval)	c.mu.Unlock()	defer ticker.Stop()	c.sendQuestion(name)	for {		select {		case <-ticker.C:			c.sendQuestion(name)		case <-c.closed:			return dnsmessage.ResourceHeader{}, nil, errConnectionClosed		case res := <-queryChan:			return res.answer, res.addr, nil		case <-ctx.Done():			return dnsmessage.ResourceHeader{}, nil, errContextElapsed		}	}}func ipToBytes(ip net.IP) (out [4]byte) {	rawIP := ip.To4()	if rawIP == nil {		return	}	ipInt := big.NewInt(0)	ipInt.SetBytes(rawIP)	copy(out[:], ipInt.Bytes())	return}func interfaceForRemote(remote string) (net.IP, error) {	conn, err := net.Dial(NetType, remote)	if err != nil {		return nil, err	}	localAddr, ok := conn.LocalAddr().(*net.UDPAddr)	if !ok {		return nil, errFailedCast	}	if err := conn.Close(); err != nil {		return nil, err	}	return localAddr.IP, nil}func (c *Conn) sendQuestion(name string) {	packedName, err := dnsmessage.NewName(name)	if err != nil {		c.log.Println("Failed to construct mDNS packet %v", err)		return	}	msg := dnsmessage.Message{		Header: dnsmessage.Header{},		Questions: []dnsmessage.Question{			{				Type:  dnsmessage.TypeA,				Class: dnsmessage.ClassINET,				Name:  packedName,			},		},	}	rawQuery, err := msg.Pack()	if err != nil {		c.log.Println("Failed to construct mDNS packet %v", err)		return	}	c.writeToSocket(0, rawQuery, false)}func (c *Conn) writeToSocket(ifIndex int, b []byte, isLoopBack bool) {	if ifIndex != 0 {		ifc, err := net.InterfaceByIndex(ifIndex)		if err != nil {			c.log.Println("Failed to get interface interface for %d: %v", ifIndex, err)			return		}		if isLoopBack && ifc.Flags&net.FlagLoopback == 0 {			// avoid accidentally tricking the destination that itself is the same as us			c.log.Println("Interface is not loopback %d", ifIndex)			return		}		if err = c.socket.SetMulticastInterface(ifc); err != nil {			c.log.Println("Failed to set multicast interface for %d: %v", ifIndex, err)		} else {			if _, err = c.socket.WriteTo(b, nil, c.dstAddr); err != nil {				c.log.Println("Failed to send mDNS packet on interface %d: %v", ifIndex, err)			}		}		return	}	for ifcIdx := range c.interList {		if isLoopBack && c.interList[ifcIdx].Flags&net.FlagLoopback == 0 {			// avoid accidentally tricking the destination that itself is the same as us			continue		}		if err := c.socket.SetMulticastInterface(&c.interList[ifcIdx]); err != nil {			c.log.Println("Failed to set multicast interface for %d: %v", c.interList[ifcIdx].Index, err)		} else {			if _, err = c.socket.WriteTo(b, nil, c.dstAddr); err != nil {				c.log.Println("Failed to send mDNS packet on interface %d: %v", c.interList[ifcIdx].Index, err)			}		}	}}func (c *Conn) sendAnswer(name string, ifIndex int, dst net.IP) {	packedName, err := dnsmessage.NewName(name)	if err != nil {		c.log.Println("Failed to construct mDNS packet %v", err)		return	}	msg := dnsmessage.Message{		Header: dnsmessage.Header{			Response:      true,			Authoritative: true,		},		Answers: []dnsmessage.Resource{			{				Header: dnsmessage.ResourceHeader{					Type:  dnsmessage.TypeA,					Class: dnsmessage.ClassINET,					Name:  packedName,					TTL:   responseTTL,				},				Body: &dnsmessage.AResource{					A: ipToBytes(dst),				},			},		},	}	rawAnswer, err := msg.Pack()	if err != nil {		c.log.Println("Failed to construct mDNS packet %v", err)		return	}	c.writeToSocket(ifIndex, rawAnswer, dst.IsLoopback())}func (c *Conn) start(inboundBufferSize int, config *Config) { // nolint gocognit	defer func() {		c.mu.Lock()		defer c.mu.Unlock()		close(c.closed)	}()	b := make([]byte, inboundBufferSize)	p := dnsmessage.Parser{}	for {		n, cm, src, err := c.socket.ReadFrom(b)		if err != nil {			if errors.Is(err, net.ErrClosed) {				return			}			c.log.Println("Failed to ReadFrom %q %v", src, err)			continue		}		var ifIndex int		if cm != nil {			ifIndex = cm.IfIndex		}		func() {			c.mu.RLock()			defer c.mu.RUnlock()			if _, err := p.Start(b[:n]); err != nil {				c.log.Println("Failed to parse mDNS packet %v", err)				return			}			for i := 0; i <= maxMessageRecords; i++ {				q, err := p.Question()				if errors.Is(err, dnsmessage.ErrSectionDone) {					break				} else if err != nil {					c.log.Println("Failed to parse mDNS packet %v", err)					return				}				for _, localName := range c.localNames {					if localName == q.Name.String() {						if config.LocalAddress != nil {							c.sendAnswer(q.Name.String(), ifIndex, config.LocalAddress)						} else {							localAddress, err := interfaceForRemote(src.String())							if err != nil {								c.log.Println("Failed to get local interface to communicate with %s: %v", src.String(), err)								continue							}							c.sendAnswer(q.Name.String(), ifIndex, localAddress)						}					}				}			}			for i := 0; i <= maxMessageRecords; i++ {				a, err := p.AnswerHeader()				if errors.Is(err, dnsmessage.ErrSectionDone) {					return				}				if err != nil {					c.log.Println("Failed to parse mDNS packet %v", err)					return				}				if a.Type != dnsmessage.TypeA && a.Type != dnsmessage.TypeAAAA {					continue				}				for j := len(c.queries) - 1; j >= 0; j-- {					if c.queries[j].nameWithSuffix == a.Name.String() {						ip, err := ipFromAnswerHeader(a, p)						if err != nil {							c.log.Println("Failed to parse mDNS answer %v", err)							return						}						c.queries[j].queryResultChan <- queryResult{a, &net.IPAddr{							IP: ip,						}}						c.queries = append(c.queries[:j], c.queries[j+1:]...)					}				}			}		}()	}}func ipFromAnswerHeader(a dnsmessage.ResourceHeader, p dnsmessage.Parser) (ip []byte, err error) {	if a.Type == dnsmessage.TypeA {		resource, err := p.AResource()		if err != nil {			return nil, err		}		ip = resource.A[:]	} else {		resource, err := p.AAAAResource()		if err != nil {			return nil, err		}		ip = resource.AAAA[:]	}	return}
 |