| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196 | 
							- package hha
 
import (
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"math"
	"math/rand/v2"
	"net/http"
	"net/url"
	"sync"
	"time"
)
 
// Logger is the logging dependency of HighAvailability. The background
// loops report failures through it using printf-style format strings.
type Logger interface {
	// Debug records a message built from a printf-style format and its
	// arguments.
	Debug(format string, args ...any)
}
 
// Body is the JSON payload exchanged between nodes. Heartbeats POSTed to
// peers (doRequest) carry the sender's Address, and checkAlive decodes a
// peer's GET response into a Body to read its Alive flag. The fields have no
// json tags, so their wire names are "Alive" and "Address"; renaming,
// retagging or reordering them changes the wire format seen by peers.
type Body struct {
	Alive   bool
	Address string
}
 
// HighAvailability is one node of a small cluster. The nodes poll each other
// over HTTP, and a node that finds no alive peer marks itself Alive. Build it
// with New: the zero value has a nil server, so Start would panic on it.
type HighAvailability struct {
	// Body holds this node's Address (set by New) and its Alive flag.
	Body

	// Timeout bounds each outgoing HTTP request to a peer (New sets 1.5s).
	Timeout time.Duration

	// Logger receives debug messages from the background loops; New installs
	// a default implementation.
	Logger Logger

	// serverList holds the base URLs of all cluster nodes; the entry equal to
	// Address is skipped when contacting peers.
	// path is the URL path of the state endpoint, both served and requested.
	// mu guards Alive, which ServeHTTP and checkServers write concurrently.
	// server is the HTTP server created by New and run by Start.
	serverList []string
	path       string
	mu         sync.Mutex
	server     *http.Server
}
 
- // uri: http://192.168.0.1 or https://192.168.0.1
 
- func New(address, path string, serverAddr []string) *HighAvailability {
 
- 	s := &HighAvailability{
 
- 		Timeout:    1500 * time.Millisecond,
 
- 		Logger:     &defaultLogger{},
 
- 		serverList: serverAddr,
 
- 		path:       path,
 
- 	}
 
- 	s.Address = address
 
- 	mux := http.NewServeMux()
 
- 	mux.Handle(path, s)
 
- 	uri, err := url.Parse(address)
 
- 	if err != nil {
 
- 		panic(err)
 
- 	}
 
- 	s.server = &http.Server{
 
- 		Addr:    uri.Host,
 
- 		Handler: mux,
 
- 	}
 
- 	return s
 
- }
 
- func (s *HighAvailability) Close() error {
 
- 	return s.server.Close()
 
- }
 
- func (s *HighAvailability) ServeHTTP(w http.ResponseWriter, r *http.Request) {
 
- 	s.mu.Lock()
 
- 	defer s.mu.Unlock()
 
- 	switch r.Method {
 
- 	case http.MethodGet:
 
- 		if err := json.NewEncoder(w).Encode(s); err != nil {
 
- 			http.Error(w, err.Error(), http.StatusBadRequest)
 
- 			return
 
- 		}
 
- 	case http.MethodPost:
 
- 		var body Body
 
- 		if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
 
- 			http.Error(w, err.Error(), http.StatusBadRequest)
 
- 			return
 
- 		}
 
- 		if body.Address == s.Address {
 
- 			s.Alive = true
 
- 		}
 
- 	default:
 
- 		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
 
- 	}
 
- }
 
- func (s *HighAvailability) Start(ctx context.Context) error {
 
- 	go s.checkServers(ctx)
 
- 	go s.sendHeartbeat(ctx)
 
- 	return s.server.ListenAndServe()
 
- }
 
- func (s *HighAvailability) checkServers(ctx context.Context) {
 
- 	timer := time.NewTimer(time.Duration(rand.IntN(math.MaxUint8)) * time.Millisecond)
 
- 	defer timer.Stop()
 
- 	for {
 
- 		select {
 
- 		case <-ctx.Done():
 
- 			return
 
- 		case <-timer.C:
 
- 			timer.Reset(time.Duration(rand.IntN(5)) * time.Second)
 
- 			allDead := true
 
- 			for _, server := range s.serverList {
 
- 				if server == s.Address {
 
- 					continue
 
- 				}
 
- 				alive, err := s.checkAlive(server)
 
- 				if err != nil {
 
- 					s.Logger.Debug("checkAlive err: %s", err)
 
- 					continue
 
- 				}
 
- 				if alive {
 
- 					allDead = false
 
- 					break
 
- 				}
 
- 			}
 
- 			if allDead && !s.Alive {
 
- 				s.mu.Lock()
 
- 				s.Alive = true
 
- 				s.mu.Unlock()
 
- 				s.Logger.Debug("checkAlive: No other server alive. setting alive now: %s", s.Address)
 
- 			}
 
- 		}
 
- 	}
 
- }
 
- func (s *HighAvailability) checkAlive(addr string) (bool, error) {
 
- 	client := http.Client{
 
- 		Timeout: s.Timeout,
 
- 	}
 
- 	resp, err := client.Get(addr + s.path)
 
- 	if err != nil {
 
- 		return false, err
 
- 	}
 
- 	defer func() {
 
- 		_ = resp.Body.Close()
 
- 	}()
 
- 	var other Body
 
- 	if err = json.NewDecoder(resp.Body).Decode(&other); err != nil {
 
- 		return false, err
 
- 	}
 
- 	return other.Alive, nil
 
- }
 
- func (s *HighAvailability) doRequest(ctx context.Context, address string) error {
 
- 	client := http.Client{
 
- 		Timeout: s.Timeout,
 
- 	}
 
- 	body := Body{
 
- 		Address: s.Address,
 
- 	}
 
- 	reqBody, err := json.Marshal(body)
 
- 	if err != nil {
 
- 		return err
 
- 	}
 
- 	req, err := http.NewRequestWithContext(ctx, http.MethodPost, address+s.path, bytes.NewReader(reqBody))
 
- 	if err != nil {
 
- 		return err
 
- 	}
 
- 	req.Header.Set("Content-Type", "application/json")
 
- 	_, err = client.Do(req)
 
- 	if err != nil {
 
- 		return err
 
- 	}
 
- 	return err
 
- }
 
- func (s *HighAvailability) sendHeartbeat(ctx context.Context) {
 
- 	for {
 
- 		select {
 
- 		case <-ctx.Done():
 
- 			return
 
- 		case <-time.After(1 * time.Second):
 
- 			s.mu.Lock()
 
- 			if !s.Alive {
 
- 				s.mu.Unlock()
 
- 				continue
 
- 			}
 
- 			s.mu.Unlock()
 
- 			for _, address := range s.serverList {
 
- 				if address == s.Address {
 
- 					continue
 
- 				}
 
- 				if err := s.doRequest(ctx, address); err != nil {
 
- 					s.Logger.Debug("sendHeartbeat: %s -> %s", err, address)
 
- 				}
 
- 			}
 
- 		}
 
- 	}
 
- }
 
 
  |