Parcourir la source

gnet: 增加通过 Context 读写 net.Conn 的方法

Matt Evan il y a 1 jour
Parent
commit
41dcc90dea
3 fichiers modifiés avec 128 ajouts et 22 suppressions
  1. 15 22
      gnet/modbus/conn.go
  2. 54 0
      gnet/net.go
  3. 59 0
      gnet/net_test.go

+ 15 - 22
gnet/modbus/conn.go

@@ -10,7 +10,7 @@ import (
 	"strings"
 	"sync"
 	"time"
-
+	
 	"git.simanc.com/software/golib/v4/gio"
 	"git.simanc.com/software/golib/v4/gnet"
 	"git.simanc.com/software/golib/v4/log"
@@ -75,7 +75,7 @@ func (w *modbusConn) ReadData(ctx context.Context, blockId, address, count int)
 	}
 	w.mu.Lock()
 	defer w.mu.Unlock()
-
+	
 	switch blockId {
 	case Code3:
 		if !w.checkCode3(address, count) {
@@ -85,23 +85,19 @@ func (w *modbusConn) ReadData(ctx context.Context, blockId, address, count int)
 		// TODO 目前仅支持 4x(Code03) 地址
 		return nil, fmt.Errorf("modbus: ReadData: unsupported funCode: %d", blockId)
 	}
-
+	
 	pduGroup := gnet.SplitNumber(count, maxReadRegister)
-
+	
 	aduList := make([]ADU, len(pduGroup))
 	for i, length := range pduGroup { //
 		curAddr := address + i*maxReadRegister
 		pdu := NewPDUReadRegisters(byte(blockId), uint16(curAddr), uint16(length))
 		aduList[i] = NewADU(uint16(i), Protocol, 0, pdu)
 	}
-
+	
 	buf := make([]byte, count*2)
 	for i, adu := range aduList {
-		deadline, ok := ctx.Deadline()
-		if !ok {
-			deadline = time.Now().Add(gnet.ClientReadTimeout)
-		}
-		b, err := w.call(deadline, adu.Serialize())
+		b, err := w.call(ctx, adu.Serialize())
 		if err != nil {
 			return nil, fmt.Errorf("modbus: ReadData: %s", err)
 		}
@@ -114,7 +110,7 @@ func (w *modbusConn) ReadData(ctx context.Context, blockId, address, count int)
 		}
 		copy(buf[maxReadRegister*2*i:], resp.PDU.Data)
 	}
-
+	
 	return buf, nil
 }
 
@@ -145,11 +141,7 @@ func (w *modbusConn) WriteData(ctx context.Context, blockId, address, count int,
 		return errors.Join(ErrParamError, err)
 	}
 	adu := NewADU(uint16(address), Protocol, 0, pdu)
-	deadline, ok := ctx.Deadline()
-	if !ok {
-		deadline = time.Now().Add(gnet.ClientReadTimeout)
-	}
-	b, err := w.call(deadline, adu.Serialize())
+	b, err := w.call(ctx, adu.Serialize())
 	if err != nil {
 		return fmt.Errorf("modbus: WriteData: : %s", err)
 	}
@@ -175,12 +167,13 @@ func (w *modbusConn) checkCode6(address, count int, buf []byte) bool {
 	return (address >= 0 && address <= math.MaxUint16) && (count > 0 && count <= math.MaxUint16) && (len(buf)/2 == count)
 }
 
-func (w *modbusConn) call(deadline time.Time, b gnet.Bytes) ([]byte, error) {
-	if err := w.conn.SetDeadline(deadline); err != nil {
-		w.logger.Error("modbus: call: failed to set deadline: %s", err)
-		return nil, errors.Join(ErrConnError, err)
+func (w *modbusConn) call(ctx context.Context, b gnet.Bytes) ([]byte, error) {
+	ctx, cancel := context.WithCancel(ctx)
+	if _, ok := ctx.Deadline(); !ok {
+		ctx, cancel = context.WithTimeout(ctx, gnet.ClientReadTimeout)
 	}
-	if _, err := w.conn.Write(b); err != nil {
+	defer cancel()
+	if _, err := gnet.WriteWithContext(ctx, w.conn, b); err != nil {
 		w.logger.Error("modbus: call: failed to write response: %s", err)
 		if isNetTimeout(err) {
 			return nil, errors.Join(ErrWriteTimeout, err)
@@ -189,7 +182,7 @@ func (w *modbusConn) call(deadline time.Time, b gnet.Bytes) ([]byte, error) {
 	}
 	w.logger.Debug("modbus: Write: %s", b.HexTo())
 	clear(w.buf)
-	n, err := w.conn.Read(w.buf)
+	n, err := gnet.ReadWithContext(ctx, w.conn, w.buf)
 	if err != nil {
 		w.logger.Error("modbus: call: failed to read response: %s", err)
 		if isNetTimeout(err) {

+ 54 - 0
gnet/net.go

@@ -1,6 +1,7 @@
 package gnet
 
 import (
+	"context"
 	"errors"
 	"math/rand/v2"
 	"net"
@@ -328,3 +329,56 @@ func DialTCPConfig(address string, config *Config) (net.Conn, error) {
 	}
 	return conn, nil
 }
+
+func ReadWithContext(ctx context.Context, conn net.Conn, b []byte) (n int, err error) {
+	done := make(chan struct{})
+	stop := context.AfterFunc(ctx, func() {
+		_ = conn.SetReadDeadline(time.Now())
+		close(done)
+	})
+	n, err = conn.Read(b)
+	if !stop() {
+		<-done
+		_ = conn.SetReadDeadline(time.Time{})
+		if err == nil {
+			err = ctx.Err()
+		}
+		return n, err
+	}
+	return n, err
+}
+
+func WriteWithContext(ctx context.Context, conn net.Conn, b []byte) (n int, err error) {
+	done := make(chan struct{})
+	stop := context.AfterFunc(ctx, func() {
+		_ = conn.SetWriteDeadline(time.Now())
+		close(done)
+	})
+	n, err = conn.Write(b)
+	if !stop() {
+		<-done
+		_ = conn.SetWriteDeadline(time.Time{})
+		if err == nil {
+			err = ctx.Err()
+		}
+		return n, err
+	}
+	return n, err
+}
+
+type connWithContext struct {
+	ctx context.Context
+	net.Conn
+}
+
+func (c *connWithContext) Read(b []byte) (n int, err error) {
+	return ReadWithContext(c.ctx, c.Conn, b)
+}
+
+func (c *connWithContext) Write(b []byte) (n int, err error) {
+	return WriteWithContext(c.ctx, c.Conn, b)
+}
+
+func NewConnWithContext(ctx context.Context, conn net.Conn) net.Conn {
+	return &connWithContext{ctx: ctx, Conn: conn}
+}

+ 59 - 0
gnet/net_test.go

@@ -1,6 +1,7 @@
 package gnet
 
 import (
+	"context"
 	"errors"
 	"fmt"
 	"log"
@@ -289,3 +290,61 @@ func TestGetAvailableInterfaces(t *testing.T) {
 		}
 	}
 }
+
+func TestReadWithContext(t *testing.T) {
+	listener, err := net.Listen("tcp", "localhost:0")
+	if err != nil {
+		fmt.Println(err)
+		return
+	}
+	defer func() {
+		_ = listener.Close()
+	}()
+	conn, err := net.Dial(listener.Addr().Network(), listener.Addr().String())
+	if err != nil {
+		t.Error(err)
+		return
+	}
+	defer func() {
+		_ = conn.Close()
+	}()
+
+	ctx, cancel := context.WithCancel(context.Background())
+	go func() {
+		time.Sleep(2 * time.Second)
+		cancel()
+	}()
+
+	b := make([]byte, 1024)
+	n, err := ReadWithContext(ctx, conn, b)
+	t.Logf("ReadWithContext: %v, %v", n, err)
+}
+
+func TestNewConnWithContext(t *testing.T) {
+	listener, err := net.Listen("tcp", "localhost:0")
+	if err != nil {
+		fmt.Println(err)
+		return
+	}
+	defer func() {
+		_ = listener.Close()
+	}()
+	conn, err := net.Dial(listener.Addr().Network(), listener.Addr().String())
+	if err != nil {
+		t.Error(err)
+		return
+	}
+	defer func() {
+		_ = conn.Close()
+	}()
+
+	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
+	defer cancel()
+
+	conn = NewConnWithContext(ctx, conn)
+
+	b := make([]byte, 1024)
+
+	n, err := conn.Read(b)
+	t.Logf("Read %d bytes, err: %v\n", n, err)
+}