config.go 2.0 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394
  1. package config
  2. import (
  3. "encoding/json"
  4. "errors"
  5. "fmt"
  6. "log"
  7. "net"
  8. "os"
  9. "path/filepath"
  10. "strings"
  11. "SIMANC-WCS/app"
  12. "SIMANC-WCS/lib/server"
  13. )
  14. // 配置文件名称
  15. const configName = "config.json"
  16. // ReadFile 从数据目录中读取 name
  17. func ReadFile(name string) ([]byte, error) {
  18. fileName := filepath.Join(app.DataPath, name)
  19. _, err := os.Stat(fileName)
  20. if err != nil {
  21. return nil, err
  22. }
  23. b, err := os.ReadFile(fileName)
  24. if err != nil {
  25. log.Println(err)
  26. return nil, err
  27. }
  28. return b, nil
  29. }
  30. // Config 配置文件
  31. type Config struct {
  32. // Servers 服务器地址, 其中第一个元素作为主服务器, 剩余作为备用服务器
  33. Servers []string `json:"servers"`
  34. }
  35. func (c *Config) saveConfig(config Config) error {
  36. b, err := json.Marshal(config)
  37. if err != nil {
  38. return err
  39. }
  40. if err = os.MkdirAll(app.DataPath, os.ModePerm); err != nil {
  41. return err
  42. }
  43. dir := filepath.Join(app.DataPath, configName)
  44. return os.WriteFile(dir, b, os.ModePerm)
  45. }
  46. // GetConfig 获取配置文件
  47. func (c *Config) GetConfig() (Config, error) {
  48. b, err := ReadFile(configName)
  49. if err != nil {
  50. return Config{}, err
  51. }
  52. var config Config
  53. if err = json.Unmarshal(b, &config); err != nil {
  54. return Config{}, err
  55. }
  56. return config, nil
  57. }
  58. // SaveConfig 保存配置文件
  59. // 保存时需要校验所有配置项
  60. func (c *Config) SaveConfig(cfg Config) error {
  61. if len(cfg.Servers) < 1 {
  62. return errors.New("no servers")
  63. }
  64. for i, address := range cfg.Servers {
  65. if !strings.Contains(address, ":") {
  66. // 端口号不存在时添加默认端口
  67. address = net.JoinHostPort(address, "80")
  68. }
  69. _, _, err := net.SplitHostPort(address)
  70. if err != nil {
  71. return fmt.Errorf("invalid server address: %s", err)
  72. }
  73. cfg.Servers[i] = address
  74. }
  75. // 测试服务器是否联通
  76. if err := server.ConnectTest(app.Context, cfg.Servers); err != nil {
  77. return err
  78. }
  79. // 保存配置文件
  80. if err := c.saveConfig(cfg); err != nil {
  81. return err
  82. }
  83. // 更新内存数据
  84. c.Servers = cfg.Servers
  85. return nil
  86. }