From 61a73ba606bc827c80739b8ecafb37dc8dde72df Mon Sep 17 00:00:00 2001 From: Nicolas JUHEL Date: Mon, 6 Jan 2025 11:32:54 +0100 Subject: [PATCH] Package Certificates: - fix bug with cert type marshall/unmarshall - add old config to allow retro compatibility - add new type function to retrieve a tls root ca cert instead of a slice of string to get root ca Package HTTPCli: - fix default DNS Mapper - optimze global DNS Mapper - fix non closing sub goroutine Package HTTPCli/DNS-Mapper: - change request function of Root CA with function of root ca cert instance - add function to return a root ca cert from a function that return a slice of root ca string Package Config/Components: - httpcli: bump sub package of certificate, httpcli - httpcli: adjust code following bump - httpcli: change request function of Root CA with function of root ca cert instance - httpcli: add function to return a root ca cert from a function that return a slice of root ca string - tls: change request function of Root CA with function of root ca cert instance - tls: add function to return a root ca cert from a function that return a slice of root ca string Package IOUtils/mapCloser: - fix bug with mapcloser not stopped - optimize code & goroutine Package Logger: - rework mapCloser call - optimize mapClaoser managment Package Request: - rework error managment - using []byte instead of buffer to read response body - add free capability - optimize memory consumption Package Socket / Server: - add filtering error capability - add params to specify a function called on each new connection and before using the connection - the new function param allow to update the network incomming connection (like buffer, deadline...) - rework some useless atomic to direct value to optimize code Package Socket/Delim: - rework to optimize memory & variable use - remove capabilities of update the instance when running, prefert recreate new one if necessary Other: - bump dependencies - minor bug / fix --- certificates/ca/interface.go | 3 + certificates/ca/models.go | 24 ++++ certificates/certificates_suite_test.go | 28 ++--- certificates/certs/config.go | 107 ++++++++++++++++-- certificates/certs/encode.go | 97 +++++++++++++--- certificates/certs/format.go | 15 ++- certificates/certs/interface.go | 8 +- certificates/certs/models.go | 32 ++++++ certificates/cipher/format.go | 38 +++++++ certificates/cipher/interface.go | 65 +++++++++++ certificates/cipher/models.go | 29 +++++ certificates/config_old.go | 143 ++++++++++++++++++++++++ certificates/curves/models.go | 15 +++ certificates/interface.go | 2 + certificates/rootca.go | 9 ++ certificates/tlsversion/models.go | 15 +++ config/components/httpcli/dns.go | 8 ++ config/components/httpcli/interface.go | 32 ++++-- config/components/httpcli/model.go | 42 +++---- config/components/tls/client.go | 4 +- config/components/tls/interface.go | 19 +++- config/components/tls/model.go | 2 +- go.mod | 106 +++++++++--------- httpcli/cli.go | 36 ++---- httpcli/dns-mapper/config.go | 2 +- httpcli/dns-mapper/interface.go | 30 ++++- httpcli/dns-mapper/model.go | 56 +++++++--- httpcli/dns-mapper/transport.go | 8 +- ioutils/mapCloser/interface.go | 28 ++++- ioutils/mapCloser/model.go | 17 ++- logger/interface.go | 3 - logger/iowritecloser.go | 16 +-- logger/manage.go | 63 ++++++++--- pprof/tools.go | 2 + request/DoRequest.go | 52 +++++++++ request/error.go | 42 +++++-- request/interface.go | 9 ++ request/model.go | 9 ++ request/monitor.go | 31 ++--- request/request.go | 67 ++++++----- retro/retro_test.go | 8 +- socket/config/server.go | 4 +- socket/delim/interface.go | 32 +++--- socket/delim/io.go | 61 ++++++---- socket/delim/model.go | 76 +------------ socket/interface.go | 20 ++++ socket/server/interface_linux.go | 11 +- socket/server/interface_other.go | 7 +- socket/server/tcp/interface.go | 8 +- socket/server/tcp/listener.go | 20 ++-- socket/server/tcp/model.go | 28 ++--- socket/server/tcp/tcp_test.go | 4 +- socket/server/udp/interface.go | 8 +- socket/server/udp/listener.go | 11 +- socket/server/udp/model.go | 22 +--- socket/server/udp/udp_test.go | 2 +- socket/server/unix/interface.go | 8 +- socket/server/unix/listener.go | 20 ++-- socket/server/unix/model.go | 26 ++--- socket/server/unix/unix_test.go | 4 +- socket/server/unixgram/interface.go | 8 +- socket/server/unixgram/listener.go | 11 +- socket/server/unixgram/model.go | 22 +--- socket/server/unixgram/unixgram_test.go | 2 +- test/test-socket-server-tcp/main.go | 2 +- test/test-socket-server-udp/main.go | 2 +- test/test-socket-server-unix/main.go | 2 +- 67 files changed, 1218 insertions(+), 525 deletions(-) create mode 100644 certificates/config_old.go create mode 100644 request/DoRequest.go diff --git a/certificates/ca/interface.go b/certificates/ca/interface.go index 8b4d8a62..abeaed18 100644 --- a/certificates/ca/interface.go +++ b/certificates/ca/interface.go @@ -58,7 +58,10 @@ type Cert interface { cbor.Unmarshaler fmt.Stringer + Len() int AppendPool(p *x509.CertPool) + AppendBytes(p []byte) error + AppendString(str string) error } func Parse(str string) (Cert, error) { diff --git a/certificates/ca/models.go b/certificates/ca/models.go index eb09850e..5fcfd30d 100644 --- a/certificates/ca/models.go +++ b/certificates/ca/models.go @@ -37,6 +37,30 @@ type mod struct { c []*x509.Certificate } +func (m *mod) Len() int { + return len(m.c) +} + +func (m *mod) AppendBytes(p []byte) error { + c := &mod{ + c: make([]*x509.Certificate, 0), + } + + if e := c.unMarshall(p); e != nil { + return e + } + + for _, i := range c.c { + m.c = append(m.c, i) + } + + return nil +} + +func (m *mod) AppendString(str string) error { + return m.AppendBytes([]byte(str)) +} + func ViperDecoderHook() libmap.DecodeHookFuncType { return func(from reflect.Type, to reflect.Type, data interface{}) (interface{}, error) { var ( diff --git a/certificates/certificates_suite_test.go b/certificates/certificates_suite_test.go index 8c383d44..5bf277bd 100644 --- a/certificates/certificates_suite_test.go +++ b/certificates/certificates_suite_test.go @@ -27,9 +27,6 @@ package certificates_test import ( "os" - "path/filepath" - "reflect" - "strings" "testing" "time" @@ -44,32 +41,23 @@ import ( type EmptyStruct struct{} -var ( - keyFile string - pubFile string +const ( + keyFile = "test_ed25519.key" + pubFile = "test_ed25519.pub" ) // TestGolibEncodingAESHelper tests the Golib AES Encoding Helper function. -func TestGolibArchiveHelper(t *testing.T) { +func TestGolibCertificatesHelper(t *testing.T) { time.Sleep(500 * time.Millisecond) // Adding delay for better testing synchronization RegisterFailHandler(Fail) // Registering fail handler for better test failure reporting RunSpecs(t, "Certificates Helper Suite") // Running the test suite for Encoding AES Helper } -var _ = BeforeSuite(func() { - keyFile = filepath.Join(os.Getenv("GOPATH"), "src", strings.Replace(reflect.TypeOf(EmptyStruct{}).PkgPath(), "_test", "", -1), "test_ed25519.key") - pubFile = filepath.Join(os.Getenv("GOPATH"), "src", strings.Replace(reflect.TypeOf(EmptyStruct{}).PkgPath(), "_test", "", -1), "test_ed25519.pub") -}) - var _ = AfterSuite(func() { - if keyFile != "" { - if _, e := os.Stat(keyFile); e == nil { - Expect(os.Remove(keyFile)).ToNot(HaveOccurred()) - } + if _, e := os.Stat(keyFile); e == nil { + Expect(os.Remove(keyFile)).ToNot(HaveOccurred()) } - if pubFile != "" { - if _, e := os.Stat(pubFile); e == nil { - Expect(os.Remove(pubFile)).ToNot(HaveOccurred()) - } + if _, e := os.Stat(pubFile); e == nil { + Expect(os.Remove(pubFile)).ToNot(HaveOccurred()) } }) diff --git a/certificates/certs/config.go b/certificates/certs/config.go index fba64095..3f5877ca 100644 --- a/certificates/certs/config.go +++ b/certificates/certs/config.go @@ -27,6 +27,7 @@ package certs import ( + "bytes" "crypto" "crypto/ecdsa" "crypto/rsa" @@ -58,8 +59,28 @@ func cleanPem(s string) string { return strings.TrimSpace(s) } +func cleanPemByte(s []byte) []byte { + s = bytes.TrimSpace(s) + + // remove \n\r + s = bytes.Trim(s, "\n") + s = bytes.Trim(s, "\r") + + // do again if \r\n + s = bytes.Trim(s, "\n") + s = bytes.Trim(s, "\r") + + return bytes.TrimSpace(s) +} + type Config interface { Cert() (*tls.Certificate, error) + + IsChain() bool + IsPair() bool + + IsFile() bool + GetCerts() []string } type ConfigPair struct { @@ -68,34 +89,75 @@ type ConfigPair struct { } func (c *ConfigPair) Cert() (*tls.Certificate, error) { - c.Key = cleanPem(c.Key) - c.Pub = cleanPem(c.Pub) - if c == nil { return nil, ErrInvalidPairCertificate - } else if len(c.Key) < 1 || len(c.Pub) < 1 { + } + + var ( + k = cleanPemByte([]byte(c.Key)) + p = cleanPemByte([]byte(c.Pub)) + ) + + if len(k) < 1 || len(p) < 1 { return nil, ErrInvalidPairCertificate } - if _, e := os.Stat(c.Key); e == nil { - if b, e := os.ReadFile(c.Key); e == nil { - c.Key = cleanPem(string(b)) + if _, e := os.Stat(string(k)); e == nil { + if b, e := os.ReadFile(string(k)); e == nil { + k = cleanPemByte(b) } } - if _, e := os.Stat(c.Pub); e == nil { - if b, e := os.ReadFile(c.Pub); e == nil { - c.Pub = cleanPem(string(b)) + if _, e := os.Stat(string(p)); e == nil { + if b, e := os.ReadFile(string(p)); e == nil { + p = cleanPemByte(b) } } - if crt, err := tls.X509KeyPair([]byte(c.Pub), []byte(c.Key)); err != nil { + if crt, err := tls.X509KeyPair(p, k); err != nil { return nil, err } else { return &crt, nil } } +func (c *ConfigPair) IsChain() bool { + return false +} + +func (c *ConfigPair) IsPair() bool { + return true +} + +func (c *ConfigPair) IsFile() bool { + if c == nil { + return false + } + + var ( + k = cleanPemByte([]byte(c.Key)) + p = cleanPemByte([]byte(c.Pub)) + ) + + if len(k) < 1 || len(p) < 1 { + return false + } + + if _, e := os.Stat(string(k)); e == nil { + return true + } + + if _, e := os.Stat(string(p)); e == nil { + return true + } + + return false +} + +func (c *ConfigPair) GetCerts() []string { + return []string{c.Key, c.Pub} +} + type ConfigChain string func (c *ConfigChain) Cert() (*tls.Certificate, error) { @@ -163,3 +225,26 @@ func (c *ConfigChain) getPrivateKey(der []byte) (crypto.PrivateKey, error) { } return nil, ErrInvalidPrivateKey } +func (c *ConfigChain) IsChain() bool { + return true +} + +func (c *ConfigChain) IsPair() bool { + return false +} + +func (c *ConfigChain) IsFile() bool { + if c == nil { + return false + } + + if _, e := os.Stat(string(*c)); e == nil { + return true + } + + return false +} + +func (c *ConfigChain) GetCerts() []string { + return []string{string(*c)} +} diff --git a/certificates/certs/encode.go b/certificates/certs/encode.go index dc4ddcef..260a3f67 100644 --- a/certificates/certs/encode.go +++ b/certificates/certs/encode.go @@ -67,6 +67,7 @@ func (o *Certif) UnmarshalText(text []byte) error { } else if crt == nil || len(crt.Certificate) == 0 { return ErrInvalidPairCertificate } else { + o.g = &chn o.c = *crt return nil } @@ -81,8 +82,22 @@ func (o *Certif) UnmarshalBinary(data []byte) error { } func (o *Certif) MarshalJSON() ([]byte, error) { - t := o.String() - return json.Marshal(t) + var cfg any + + if o == nil || o.g == nil { + return []byte(""), nil + } else if p := o.g.GetCerts(); len(p) == 1 { + cfg = ConfigChain(o.g.GetCerts()[0]) + } else if len(p) == 2 { + cfg = ConfigPair{ + Key: p[0], + Pub: p[1], + } + } else { + cfg = o.g + } + + return json.Marshal(cfg) } func (o *Certif) UnmarshalJSON(bytes []byte) error { @@ -93,21 +108,23 @@ func (o *Certif) UnmarshalJSON(bytes []byte) error { err error ) - if err = json.Unmarshal(bytes, &cfg); err == nil { + if err = json.Unmarshal(bytes, &cfg); err == nil && len(cfg.Key) > 0 && len(cfg.Pub) > 0 { if crt, err = cfg.Cert(); err != nil { return err } else if crt == nil || len(crt.Certificate) == 0 { return ErrInvalidPairCertificate } else { + o.g = &cfg o.c = *crt return nil } - } else if err = json.Unmarshal(bytes, &chn); err == nil { + } else if err = json.Unmarshal(bytes, &chn); err == nil && len(chn) > 0 { if crt, err = chn.Cert(); err != nil { return err } else if crt == nil || len(crt.Certificate) == 0 { return ErrInvalidPairCertificate } else { + o.g = &chn o.c = *crt return nil } @@ -117,8 +134,22 @@ func (o *Certif) UnmarshalJSON(bytes []byte) error { } func (o *Certif) MarshalYAML() (interface{}, error) { - t := o.String() - return yaml.Marshal(t) + var cfg any + + if o == nil || o.g == nil { + return []byte(""), nil + } else if p := o.g.GetCerts(); len(p) == 1 { + cfg = ConfigChain(o.g.GetCerts()[0]) + } else if len(p) == 2 { + cfg = ConfigPair{ + Key: p[0], + Pub: p[1], + } + } else { + cfg = o.g + } + + return yaml.Marshal(cfg) } func (o *Certif) UnmarshalYAML(value *yaml.Node) error { @@ -130,21 +161,23 @@ func (o *Certif) UnmarshalYAML(value *yaml.Node) error { err error ) - if err = yaml.Unmarshal(src, &cfg); err == nil { + if err = yaml.Unmarshal(src, &cfg); err == nil && len(cfg.Key) > 0 && len(cfg.Pub) > 0 { if crt, err = cfg.Cert(); err != nil { return err } else if crt == nil || len(crt.Certificate) == 0 { return ErrInvalidPairCertificate } else { + o.g = &cfg o.c = *crt return nil } - } else if err = yaml.Unmarshal(src, &chn); err == nil { + } else if err = yaml.Unmarshal(src, &chn); err == nil && len(chn) > 0 { if crt, err = chn.Cert(); err != nil { return err } else if crt == nil || len(crt.Certificate) == 0 { return ErrInvalidPairCertificate } else { + o.g = &chn o.c = *crt return nil } @@ -154,8 +187,22 @@ func (o *Certif) UnmarshalYAML(value *yaml.Node) error { } func (o *Certif) MarshalTOML() ([]byte, error) { - t := o.String() - return toml.Marshal(t) + var cfg any + + if o == nil || o.g == nil { + return []byte(""), nil + } else if p := o.g.GetCerts(); len(p) == 1 { + cfg = ConfigChain(o.g.GetCerts()[0]) + } else if len(p) == 2 { + cfg = ConfigPair{ + Key: p[0], + Pub: p[1], + } + } else { + cfg = o.g + } + + return toml.Marshal(cfg) } func (o *Certif) UnmarshalTOML(i interface{}) error { @@ -184,21 +231,23 @@ func (o *Certif) UnmarshalTOML(i interface{}) error { err error ) - if err = toml.Unmarshal(p, &cfg); err == nil { + if err = toml.Unmarshal(p, &cfg); err == nil && len(cfg.Key) > 0 && len(cfg.Pub) > 0 { if crt, err = cfg.Cert(); err != nil { return err } else if crt == nil || len(crt.Certificate) == 0 { return ErrInvalidPairCertificate } else { + o.g = &cfg o.c = *crt return nil } - } else if err = toml.Unmarshal(p, &chn); err == nil { + } else if err = toml.Unmarshal(p, &chn); err == nil && len(chn) > 0 { if crt, err = chn.Cert(); err != nil { return err } else if crt == nil || len(crt.Certificate) == 0 { return ErrInvalidPairCertificate } else { + o.g = &chn o.c = *crt return nil } @@ -208,8 +257,22 @@ func (o *Certif) UnmarshalTOML(i interface{}) error { } func (o *Certif) MarshalCBOR() ([]byte, error) { - t := o.String() - return cbor.Marshal(t) + var cfg any + + if o == nil || o.g == nil { + return []byte(""), nil + } else if p := o.g.GetCerts(); len(p) == 1 { + cfg = ConfigChain(o.g.GetCerts()[0]) + } else if len(p) == 2 { + cfg = ConfigPair{ + Key: p[0], + Pub: p[1], + } + } else { + cfg = o.g + } + + return cbor.Marshal(cfg) } func (o *Certif) UnmarshalCBOR(bytes []byte) error { @@ -220,21 +283,23 @@ func (o *Certif) UnmarshalCBOR(bytes []byte) error { err error ) - if err = cbor.Unmarshal(bytes, &cfg); err == nil { + if err = cbor.Unmarshal(bytes, &cfg); err == nil && len(cfg.Key) > 0 && len(cfg.Pub) > 0 { if crt, err = cfg.Cert(); err != nil { return err } else if crt == nil || len(crt.Certificate) == 0 { return ErrInvalidPairCertificate } else { + o.g = &cfg o.c = *crt return nil } - } else if err = cbor.Unmarshal(bytes, &chn); err == nil { + } else if err = cbor.Unmarshal(bytes, &chn); err == nil && len(chn) > 0 { if crt, err = chn.Cert(); err != nil { return err } else if crt == nil || len(crt.Certificate) == 0 { return ErrInvalidPairCertificate } else { + o.g = &chn o.c = *crt return nil } diff --git a/certificates/certs/format.go b/certificates/certs/format.go index 031acb3f..e707fd54 100644 --- a/certificates/certs/format.go +++ b/certificates/certs/format.go @@ -34,11 +34,19 @@ import ( ) func (o *Certif) String() string { - str, _ := o.Chain() - return str + if o == nil { + return "" + } + + s, _ := o.Chain() + return cleanPem(s) } func (o *Certif) Pair() (pub string, key string, err error) { + if o == nil { + return "", "", ErrInvalidPairCertificate + } + var ( bufPub = bytes.NewBuffer(make([]byte, 0)) bufKey = bytes.NewBuffer(make([]byte, 0)) @@ -82,5 +90,8 @@ func (o *Certif) Chain() (string, error) { } func (o *Certif) TLS() tls.Certificate { + if o == nil { + return tls.Certificate{} + } return o.c } diff --git a/certificates/certs/interface.go b/certificates/certs/interface.go index d887e37c..a7bb7648 100644 --- a/certificates/certs/interface.go +++ b/certificates/certs/interface.go @@ -54,6 +54,12 @@ type Cert interface { TLS() tls.Certificate Model() Certif + + IsChain() bool + IsPair() bool + + IsFile() bool + GetCerts() []string } func Parse(chain string) (Cert, error) { @@ -71,6 +77,6 @@ func parseCert(cfg Config) (Cert, error) { } else if c == nil { return nil, ErrInvalidPairCertificate } else { - return &Certif{c: *c}, nil + return &Certif{g: cfg, c: *c}, nil } } diff --git a/certificates/certs/models.go b/certificates/certs/models.go index 357f65a0..9ac22592 100644 --- a/certificates/certs/models.go +++ b/certificates/certs/models.go @@ -34,6 +34,7 @@ import ( ) type Certif struct { + g Config c tls.Certificate } @@ -42,9 +43,40 @@ func (o *Certif) Cert() Cert { } func (o *Certif) Model() Certif { + if o == nil { + return Certif{} + } return *o } +func (o *Certif) IsChain() bool { + if o == nil { + return false + } + return o.g.IsChain() +} + +func (o *Certif) IsPair() bool { + if o == nil { + return false + } + return o.g.IsPair() +} + +func (o *Certif) IsFile() bool { + if o == nil { + return false + } + return o.g.IsFile() +} + +func (o *Certif) GetCerts() []string { + if o == nil { + return make([]string, 0) + } + return o.g.GetCerts() +} + func ViperDecoderHook() libmap.DecodeHookFuncType { return func(from reflect.Type, to reflect.Type, data interface{}) (interface{}, error) { var ( diff --git a/certificates/cipher/format.go b/certificates/cipher/format.go index a13dac8c..426bcf59 100644 --- a/certificates/cipher/format.go +++ b/certificates/cipher/format.go @@ -38,26 +38,64 @@ func (v Cipher) Code() []string { switch v { case TLS_RSA_WITH_AES_128_GCM_SHA256: return []string{"rsa", "aes", "128", "gcm", "sha256"} + case TLS_RSA_WITH_AES_128_GCM: + return []string{"rsa", "aes", "128", "gcm"} + case TLS_RSA_WITH_AES128_GCM: + return []string{"rsa", "aes128", "gcm"} case TLS_RSA_WITH_AES_256_GCM_SHA384: return []string{"rsa", "aes", "256", "gcm", "sha384"} + case TLS_RSA_WITH_AES_256_GCM: + return []string{"rsa", "aes", "256", "gcm"} + case TLS_RSA_WITH_AES256_GCM: + return []string{"rsa", "aes256", "gcm"} case TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256: return []string{"ecdhe", "rsa", "aes", "128", "gcm", "sha256"} + case TLS_ECDHE_RSA_WITH_AES_128_GCM: + return []string{"ecdhe", "rsa", "aes", "128", "gcm"} + case TLS_ECDHE_RSA_WITH_AES128_GCM: + return []string{"ecdhe", "rsa", "aes128", "gcm"} case TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256: return []string{"ecdhe", "ecdsa", "aes", "128", "gcm", "sha256"} + case TLS_ECDHE_ECDSA_WITH_AES_128_GCM: + return []string{"ecdhe", "ecdsa", "aes", "128", "gcm"} + case TLS_ECDHE_ECDSA_WITH_AES128_GCM: + return []string{"ecdhe", "ecdsa", "aes128", "gcm"} case TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384: return []string{"ecdhe", "rsa", "aes", "256", "gcm", "sha384"} + case TLS_ECDHE_RSA_WITH_AES_256_GCM: + return []string{"ecdhe", "rsa", "aes", "256", "gcm"} + case TLS_ECDHE_RSA_WITH_AES256_GCM: + return []string{"ecdhe", "rsa", "aes256", "gcm"} case TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384: return []string{"ecdhe", "ecdsa", "aes", "256", "gcm", "sha384"} + case TLS_ECDHE_ECDSA_WITH_AES_256_GCM: + return []string{"ecdhe", "ecdsa", "aes", "256", "gcm"} + case TLS_ECDHE_ECDSA_WITH_AES256_GCM: + return []string{"ecdhe", "ecdsa", "aes256", "gcm"} case TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256: return []string{"ecdhe", "rsa", "chacha20", "poly1305", "sha256"} + case TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305: + return []string{"ecdhe", "rsa", "chacha20", "poly1305"} case TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256: return []string{"ecdhe", "ecdsa", "chacha20", "poly1305", "sha256"} + case TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305: + return []string{"ecdhe", "ecdsa", "chacha20", "poly1305"} case TLS_AES_128_GCM_SHA256: return []string{"aes", "128", "gcm", "sha256"} + case TLS_AES_128_GCM: + return []string{"aes", "128", "gcm"} + case TLS_AES128_GCM: + return []string{"aes128", "gcm"} case TLS_AES_256_GCM_SHA384: return []string{"aes", "256", "gcm", "sha384"} + case TLS_AES_256_GCM: + return []string{"aes", "256", "gcm"} + case TLS_AES256_GCM: + return []string{"aes256", "gcm"} case TLS_CHACHA20_POLY1305_SHA256: return []string{"chacha20", "poly1305", "sha256"} + case TLS_CHACHA20_POLY1305: + return []string{"chacha20", "poly1305"} default: return []string{} } diff --git a/certificates/cipher/interface.go b/certificates/cipher/interface.go index 584ab313..e1d925c5 100644 --- a/certificates/cipher/interface.go +++ b/certificates/cipher/interface.go @@ -53,6 +53,30 @@ const ( TLS_AES_256_GCM_SHA384 = Cipher(tls.TLS_AES_256_GCM_SHA384) TLS_CHACHA20_POLY1305_SHA256 = Cipher(tls.TLS_CHACHA20_POLY1305_SHA256) ) +const ( + // TLS 1.0 - 1.2 cipher suites no sha for retro compt + TLS_RSA_WITH_AES_128_GCM Cipher = iota + 1 + TLS_RSA_WITH_AES_256_GCM + TLS_ECDHE_RSA_WITH_AES_128_GCM + TLS_ECDHE_ECDSA_WITH_AES_128_GCM + TLS_ECDHE_RSA_WITH_AES_256_GCM + TLS_ECDHE_ECDSA_WITH_AES_256_GCM + TLS_RSA_WITH_AES128_GCM Cipher = iota + 1 + TLS_RSA_WITH_AES256_GCM + TLS_ECDHE_RSA_WITH_AES128_GCM + TLS_ECDHE_ECDSA_WITH_AES128_GCM + TLS_ECDHE_RSA_WITH_AES256_GCM + TLS_ECDHE_ECDSA_WITH_AES256_GCM + TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305 + TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305 + + // TLS 1.3 cipher suites retro compat + TLS_AES_128_GCM + TLS_AES_256_GCM + TLS_AES128_GCM + TLS_AES256_GCM + TLS_CHACHA20_POLY1305 +) func List() []Cipher { return []Cipher{ @@ -113,6 +137,47 @@ func Parse(s string) Cipher { return TLS_AES_128_GCM_SHA256 case containString(p, TLS_AES_256_GCM_SHA384.Code()): return TLS_AES_256_GCM_SHA384 + // retro compat + case containString(p, TLS_ECDHE_RSA_WITH_AES_128_GCM.Code()): + return TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 + case containString(p, TLS_ECDHE_ECDSA_WITH_AES_128_GCM.Code()): + return TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256 + case containString(p, TLS_ECDHE_RSA_WITH_AES_256_GCM.Code()): + return TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384 + case containString(p, TLS_ECDHE_ECDSA_WITH_AES_256_GCM.Code()): + return TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384 + case containString(p, TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305.Code()): + return TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256 + case containString(p, TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305.Code()): + return TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256 + case containString(p, TLS_CHACHA20_POLY1305.Code()): + return TLS_CHACHA20_POLY1305_SHA256 + case containString(p, TLS_RSA_WITH_AES_128_GCM.Code()): + return TLS_RSA_WITH_AES_128_GCM_SHA256 + case containString(p, TLS_RSA_WITH_AES_256_GCM.Code()): + return TLS_RSA_WITH_AES_256_GCM_SHA384 + case containString(p, TLS_AES_128_GCM.Code()): + return TLS_AES_128_GCM_SHA256 + case containString(p, TLS_AES_256_GCM.Code()): + return TLS_AES_256_GCM_SHA384 + + case containString(p, TLS_ECDHE_RSA_WITH_AES128_GCM.Code()): + return TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 + case containString(p, TLS_ECDHE_ECDSA_WITH_AES128_GCM.Code()): + return TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256 + case containString(p, TLS_ECDHE_RSA_WITH_AES256_GCM.Code()): + return TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384 + case containString(p, TLS_ECDHE_ECDSA_WITH_AES256_GCM.Code()): + return TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384 + case containString(p, TLS_RSA_WITH_AES128_GCM.Code()): + return TLS_RSA_WITH_AES_128_GCM_SHA256 + case containString(p, TLS_RSA_WITH_AES256_GCM.Code()): + return TLS_RSA_WITH_AES_256_GCM_SHA384 + case containString(p, TLS_AES128_GCM.Code()): + return TLS_AES_128_GCM_SHA256 + case containString(p, TLS_AES256_GCM.Code()): + return TLS_AES_256_GCM_SHA384 + // not found default: return Unknown } diff --git a/certificates/cipher/models.go b/certificates/cipher/models.go index efc3cf55..e2e1f22e 100644 --- a/certificates/cipher/models.go +++ b/certificates/cipher/models.go @@ -32,6 +32,35 @@ import ( libmap "github.com/mitchellh/mapstructure" ) +func (v Cipher) Check() bool { + switch v { + case TLS_RSA_WITH_AES_128_GCM_SHA256: + return true + case TLS_RSA_WITH_AES_256_GCM_SHA384: + return true + case TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256: + return true + case TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256: + return true + case TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384: + return true + case TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384: + return true + case TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256: + return true + case TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256: + return true + case TLS_AES_128_GCM_SHA256: + return true + case TLS_AES_256_GCM_SHA384: + return true + case TLS_CHACHA20_POLY1305_SHA256: + return true + default: + return false + } +} + func ViperDecoderHook() libmap.DecodeHookFuncType { return func(from reflect.Type, to reflect.Type, data interface{}) (interface{}, error) { var ( diff --git a/certificates/config_old.go b/certificates/config_old.go new file mode 100644 index 00000000..f7003c37 --- /dev/null +++ b/certificates/config_old.go @@ -0,0 +1,143 @@ +/* + * MIT License + * + * Copyright (c) 2020 Nicolas JUHEL + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + * + */ + +package certificates + +import ( + tlsaut "github.com/nabbar/golib/certificates/auth" + tlscas "github.com/nabbar/golib/certificates/ca" + tlscrt "github.com/nabbar/golib/certificates/certs" + tlscpr "github.com/nabbar/golib/certificates/cipher" + tlscrv "github.com/nabbar/golib/certificates/curves" + tlsvrs "github.com/nabbar/golib/certificates/tlsversion" +) + +type CertifOld struct { + Key string `mapstructure:"key" json:"key" yaml:"key" toml:"key"` + Pem string `mapstructure:"pem" json:"pem" yaml:"pem" toml:"pem"` +} + +type ConfigOld struct { + CurveList []string `mapstructure:"curveList" json:"curveList" yaml:"curveList" toml:"curveList"` + CipherList []string `mapstructure:"cipherList" json:"cipherList" yaml:"cipherList" toml:"cipherList"` + RootCAString []string `mapstructure:"rootCA" json:"rootCA" yaml:"rootCA" toml:"rootCA"` + RootCAFile []string `mapstructure:"rootCAFiles" json:"rootCAFiles" yaml:"rootCAFiles" toml:"rootCAFiles"` + ClientCAString []string `mapstructure:"clientCA" json:"clientCA" yaml:"clientCA" toml:"clientCA"` + ClientCAFiles []string `mapstructure:"clientCAFiles" json:"clientCAFiles" yaml:"clientCAFiles" toml:"clientCAFiles"` + CertPairString []CertifOld `mapstructure:"certPair" json:"certPair" yaml:"certPair" toml:"certPair"` + CertPairFile []CertifOld `mapstructure:"certPairFiles" json:"certPairFiles" yaml:"certPairFiles" toml:"certPairFiles"` + VersionMin string `mapstructure:"versionMin" json:"versionMin" yaml:"versionMin" toml:"versionMin"` + VersionMax string `mapstructure:"versionMax" json:"versionMax" yaml:"versionMax" toml:"versionMax"` + AuthClient string `mapstructure:"authClient" json:"authClient" yaml:"authClient" toml:"authClient"` + InheritDefault bool `mapstructure:"inheritDefault" json:"inheritDefault" yaml:"inheritDefault" toml:"inheritDefault"` + DynamicSizingDisable bool `mapstructure:"dynamicSizingDisable" json:"dynamicSizingDisable" yaml:"dynamicSizingDisable" toml:"dynamicSizingDisable"` + SessionTicketDisable bool `mapstructure:"sessionTicketDisable" json:"sessionTicketDisable" yaml:"sessionTicketDisable" toml:"sessionTicketDisable"` +} + +func (c *ConfigOld) ToConfig() Config { + var car tlscas.Cert + for _, v := range c.RootCAString { + if car == nil { + if i, e := tlscas.Parse(v); e == nil { + car = i + } + } else { + _ = car.AppendString(v) + } + } + + for _, v := range c.RootCAFile { + if car == nil { + if i, e := tlscas.Parse(v); e == nil { + car = i + } + } else { + _ = car.AppendString(v) + } + } + + var cac tlscas.Cert + for _, v := range c.ClientCAFiles { + if cac == nil { + if i, e := tlscas.Parse(v); e == nil { + cac = i + } + } else { + _ = cac.AppendString(v) + } + } + + for _, v := range c.ClientCAString { + if cac == nil { + if i, e := tlscas.Parse(v); e == nil { + cac = i + } + } else { + _ = cac.AppendString(v) + } + } + + var crt = make([]tlscrt.Certif, 0) + for _, v := range c.CertPairFile { + if i, e := tlscrt.ParsePair(v.Key, v.Pem); e == nil { + crt = append(crt, i.Model()) + } + } + + for _, v := range c.CertPairString { + if i, e := tlscrt.ParsePair(v.Key, v.Pem); e == nil { + crt = append(crt, i.Model()) + } + } + + cip := make([]tlscpr.Cipher, 0) + for _, v := range c.CipherList { + if i := tlscpr.Parse(v); i.Check() { + cip = append(cip, i) + } + } + + crv := make([]tlscrv.Curves, 0) + for _, v := range c.CurveList { + if i := tlscrv.Parse(v); i.Check() { + crv = append(crv, i) + } + } + + return Config{ + CurveList: crv, + CipherList: cip, + RootCA: append(make([]tlscas.Cert, 0), car), + ClientCA: append(make([]tlscas.Cert, 0), cac), + Certs: crt, + VersionMin: tlsvrs.Parse(c.VersionMin), + VersionMax: tlsvrs.Parse(c.VersionMax), + AuthClient: tlsaut.Parse(c.AuthClient), + InheritDefault: c.InheritDefault, + DynamicSizingDisable: c.DynamicSizingDisable, + SessionTicketDisable: c.SessionTicketDisable, + } +} diff --git a/certificates/curves/models.go b/certificates/curves/models.go index 3a29a6b7..538e32eb 100644 --- a/certificates/curves/models.go +++ b/certificates/curves/models.go @@ -32,6 +32,21 @@ import ( libmap "github.com/mitchellh/mapstructure" ) +func (v Curves) Check() bool { + switch v { + case X25519: + return true + case P256: + return true + case P384: + return true + case P521: + return true + default: + return false + } +} + func ViperDecoderHook() libmap.DecodeHookFuncType { return func(from reflect.Type, to reflect.Type, data interface{}) (interface{}, error) { var ( diff --git a/certificates/interface.go b/certificates/interface.go index 2d00949b..14804a9c 100644 --- a/certificates/interface.go +++ b/certificates/interface.go @@ -43,10 +43,12 @@ import ( type FctHttpClient func(def TLSConfig, servername string) *http.Client type FctTLSDefault func() TLSConfig type FctRootCA func() []string +type FctRootCACert func() tlscas.Cert type TLSConfig interface { RegisterRand(rand io.Reader) + AddRootCA(rootCA tlscas.Cert) bool AddRootCAString(rootCA string) bool AddRootCAFile(pemFile string) error GetRootCA() []tlscas.Cert diff --git a/certificates/rootca.go b/certificates/rootca.go index 88708850..35cbc43c 100644 --- a/certificates/rootca.go +++ b/certificates/rootca.go @@ -50,6 +50,15 @@ func (o *config) GetRootCAPool() *x509.CertPool { return res } +func (o *config) AddRootCA(rootCA tlscas.Cert) bool { + if rootCA != nil && rootCA.Len() > 0 { + o.caRoot = append(o.caRoot, rootCA) + return true + } + + return false +} + func (o *config) AddRootCAString(rootCA string) bool { if rootCA != "" { if c, e := tlscas.Parse(rootCA); e == nil { diff --git a/certificates/tlsversion/models.go b/certificates/tlsversion/models.go index 3d896ad2..d0d0527a 100644 --- a/certificates/tlsversion/models.go +++ b/certificates/tlsversion/models.go @@ -32,6 +32,21 @@ import ( libmap "github.com/mitchellh/mapstructure" ) +func (v Version) Check() bool { + switch v { + case VersionTLS10: + return true + case VersionTLS11: + return true + case VersionTLS12: + return true + case VersionTLS13: + return true + default: + return false + } +} + func ViperDecoderHook() libmap.DecodeHookFuncType { return func(from reflect.Type, to reflect.Type, data interface{}) (interface{}, error) { var ( diff --git a/config/components/httpcli/dns.go b/config/components/httpcli/dns.go index a690a379..cc68f570 100644 --- a/config/components/httpcli/dns.go +++ b/config/components/httpcli/dns.go @@ -35,6 +35,14 @@ import ( htcdns "github.com/nabbar/golib/httpcli/dns-mapper" ) +func (o *componentHttpClient) Close() error { + if d := o.getDNSMapper(); d != nil { + return d.Close() + } + + return nil +} + func (o *componentHttpClient) Add(from string, to string) { if d := o.getDNSMapper(); d != nil { d.Add(from, to) diff --git a/config/components/httpcli/interface.go b/config/components/httpcli/interface.go index 920ca0ef..da3fad86 100644 --- a/config/components/httpcli/interface.go +++ b/config/components/httpcli/interface.go @@ -29,7 +29,9 @@ package httpcli import ( "sync/atomic" + libatm "github.com/nabbar/golib/atomic" libtls "github.com/nabbar/golib/certificates" + tlscas "github.com/nabbar/golib/certificates/ca" libcfg "github.com/nabbar/golib/config" cfgtps "github.com/nabbar/golib/config/types" libctx "github.com/nabbar/golib/context" @@ -46,19 +48,33 @@ type ComponentHTTPClient interface { SetFuncMessage(f htcdns.FuncMessage) } -func New(ctx libctx.FuncContext, defCARoot libtls.FctRootCA, isDeftHTTPClient bool, msg htcdns.FuncMessage) ComponentHTTPClient { +func GetRootCaCert(fct libtls.FctRootCA) tlscas.Cert { + var res tlscas.Cert + + for _, c := range fct() { + if res == nil { + res, _ = tlscas.Parse(c) + } else { + _ = res.AppendString(c) + } + } + + return res +} + +func New(ctx libctx.FuncContext, defCARoot libtls.FctRootCACert, isDeftHTTPClient bool, msg htcdns.FuncMessage) ComponentHTTPClient { c := &componentHttpClient{ x: libctx.NewConfig[uint8](ctx), - c: new(atomic.Value), - d: new(atomic.Value), - f: new(atomic.Value), + c: libatm.NewValue[*htcdns.Config](), + d: libatm.NewValue[htcdns.DNSMapper](), + f: libatm.NewValue[libtls.FctRootCACert](), + m: libatm.NewValue[htcdns.FuncMessage](), s: new(atomic.Bool), - m: new(atomic.Value), } if defCARoot == nil { - defCARoot = func() []string { - return make([]string, 0) + defCARoot = func() tlscas.Cert { + return nil } } @@ -77,7 +93,7 @@ func Register(cfg libcfg.Config, key string, cpt ComponentHTTPClient) { cfg.ComponentSet(key, cpt) } -func RegisterNew(ctx libctx.FuncContext, cfg libcfg.Config, key string, defCARoot libtls.FctRootCA, isDeftHTTPClient bool, msg htcdns.FuncMessage) { +func RegisterNew(ctx libctx.FuncContext, cfg libcfg.Config, key string, defCARoot libtls.FctRootCACert, isDeftHTTPClient bool, msg htcdns.FuncMessage) { cfg.ComponentSet(key, New(ctx, defCARoot, isDeftHTTPClient, msg)) } diff --git a/config/components/httpcli/model.go b/config/components/httpcli/model.go index b5fc24e1..d0559f69 100644 --- a/config/components/httpcli/model.go +++ b/config/components/httpcli/model.go @@ -29,7 +29,9 @@ package httpcli import ( "sync/atomic" + libatm "github.com/nabbar/golib/atomic" libtls "github.com/nabbar/golib/certificates" + tlscas "github.com/nabbar/golib/certificates/ca" libctx "github.com/nabbar/golib/context" libhtc "github.com/nabbar/golib/httpcli" htcdns "github.com/nabbar/golib/httpcli/dns-mapper" @@ -37,32 +39,30 @@ import ( type componentHttpClient struct { x libctx.Config[uint8] - c *atomic.Value // htcdns.Config - d *atomic.Value // htcdns.DNSMapper - f *atomic.Value // FuncDefaultCARoot - s *atomic.Bool // is Default at start / update - m *atomic.Value // htcdns.FctMessage + + c libatm.Value[*htcdns.Config] // htcdns.Config + d libatm.Value[htcdns.DNSMapper] // htcdns.DNSMapper + f libatm.Value[libtls.FctRootCACert] // FuncDefaultCARoot + m libatm.Value[htcdns.FuncMessage] // htcdns.FctMessage + + s *atomic.Bool // is Default at start / update } -func (o *componentHttpClient) getRootCA() []string { +func (o *componentHttpClient) getRootCA() tlscas.Cert { if i := o.f.Load(); i == nil { - return make([]string, 0) - } else if v, k := i.(libtls.FctRootCA); !k { - return make([]string, 0) - } else if r := v(); len(r) < 1 { - return make([]string, 0) + return nil + } else if v := i(); v != nil && v.Len() < 1 { + return nil } else { - return r + return v } } func (o *componentHttpClient) getMessage() htcdns.FuncMessage { if i := o.m.Load(); i == nil { return nil - } else if v, k := i.(htcdns.FuncMessage); !k { - return nil } else { - return v + return i } } @@ -87,18 +87,22 @@ func (o *componentHttpClient) setDNSMapper(dns htcdns.DNSMapper) { defer o.SetDefault() } + var old htcdns.DNSMapper + if dns != nil { - o.d.Store(dns) + old = o.d.Swap(dns) + } + + if old != nil { + _ = old.Close() } } func (o *componentHttpClient) Config() htcdns.Config { if i := o.c.Load(); i == nil { return htcdns.Config{} - } else if v, k := i.(*htcdns.Config); !k { - return htcdns.Config{} } else { - return *v + return *i } } diff --git a/config/components/tls/client.go b/config/components/tls/client.go index f8842d9d..65f61af4 100644 --- a/config/components/tls/client.go +++ b/config/components/tls/client.go @@ -164,8 +164,8 @@ func (o *componentTls) _runCli() error { } else if tls = cfg.New(); tls == nil { return prt.Error(fmt.Errorf("cannot use tls config for new instance")) } else if o.f != nil { - for _, s := range o.f() { - tls.AddRootCAString(s) + if v := o.f(); v != nil && v.Len() > 0 { + tls.AddRootCA(v) } } diff --git a/config/components/tls/interface.go b/config/components/tls/interface.go index ca1c8211..85ff58f0 100644 --- a/config/components/tls/interface.go +++ b/config/components/tls/interface.go @@ -30,6 +30,7 @@ import ( "sync" libtls "github.com/nabbar/golib/certificates" + tlscas "github.com/nabbar/golib/certificates/ca" libcfg "github.com/nabbar/golib/config" cfgtps "github.com/nabbar/golib/config/types" libctx "github.com/nabbar/golib/context" @@ -42,7 +43,21 @@ type ComponentTlS interface { SetTLS(tls libtls.TLSConfig) } -func New(ctx libctx.FuncContext, defCARoot libtls.FctRootCA) ComponentTlS { +func GetRootCaCert(fct libtls.FctRootCA) tlscas.Cert { + var res tlscas.Cert + + for _, c := range fct() { + if res == nil { + res, _ = tlscas.Parse(c) + } else { + _ = res.AppendString(c) + } + } + + return res +} + +func New(ctx libctx.FuncContext, defCARoot libtls.FctRootCACert) ComponentTlS { return &componentTls{ m: sync.RWMutex{}, x: libctx.NewConfig[uint8](ctx), @@ -56,7 +71,7 @@ func Register(cfg libcfg.Config, key string, cpt ComponentTlS) { cfg.ComponentSet(key, cpt) } -func RegisterNew(ctx libctx.FuncContext, cfg libcfg.Config, key string, defCARoot libtls.FctRootCA) { +func RegisterNew(ctx libctx.FuncContext, cfg libcfg.Config, key string, defCARoot libtls.FctRootCACert) { cfg.ComponentSet(key, New(ctx, defCARoot)) } diff --git a/config/components/tls/model.go b/config/components/tls/model.go index 926c8db5..04c5da51 100644 --- a/config/components/tls/model.go +++ b/config/components/tls/model.go @@ -38,7 +38,7 @@ type componentTls struct { x libctx.Config[uint8] t libtls.TLSConfig c *libtls.Config - f libtls.FctRootCA + f libtls.FctRootCACert } func (o *componentTls) Config() *libtls.Config { diff --git a/go.mod b/go.mod index b6601fdd..3bef4ad2 100644 --- a/go.mod +++ b/go.mod @@ -6,20 +6,20 @@ toolchain go1.23.3 require ( github.com/aws/aws-sdk-go v1.55.5 - github.com/aws/aws-sdk-go-v2 v1.32.5 - github.com/aws/aws-sdk-go-v2/config v1.28.5 - github.com/aws/aws-sdk-go-v2/credentials v1.17.46 - github.com/aws/aws-sdk-go-v2/service/iam v1.38.1 - github.com/aws/aws-sdk-go-v2/service/s3 v1.69.0 + github.com/aws/aws-sdk-go-v2 v1.32.7 + github.com/aws/aws-sdk-go-v2/config v1.28.7 + github.com/aws/aws-sdk-go-v2/credentials v1.17.48 + github.com/aws/aws-sdk-go-v2/service/iam v1.38.3 + github.com/aws/aws-sdk-go-v2/service/s3 v1.72.0 github.com/aws/smithy-go v1.22.1 - github.com/bits-and-blooms/bitset v1.17.0 + github.com/bits-and-blooms/bitset v1.20.0 github.com/c-bata/go-prompt v0.2.6 github.com/dsnet/compress v0.0.1 github.com/fatih/color v1.18.0 github.com/fsnotify/fsnotify v1.8.0 github.com/fxamacker/cbor/v2 v2.7.0 github.com/gin-gonic/gin v1.10.0 - github.com/go-ldap/ldap/v3 v3.4.8 + github.com/go-ldap/ldap/v3 v3.4.10 github.com/go-playground/validator/v10 v10.23.0 github.com/google/go-github/v33 v33.0.0 github.com/hashicorp/go-hclog v1.6.3 @@ -31,13 +31,14 @@ require ( github.com/mattn/go-colorable v0.1.13 github.com/mitchellh/go-homedir v1.1.0 github.com/mitchellh/mapstructure v1.5.0 - github.com/nats-io/jwt/v2 v2.7.2 - github.com/nats-io/nats-server/v2 v2.10.22 - github.com/nats-io/nats.go v1.37.0 - github.com/onsi/ginkgo/v2 v2.22.0 - github.com/onsi/gomega v1.36.0 + github.com/nats-io/jwt/v2 v2.7.3 + github.com/nats-io/nats-server/v2 v2.10.24 + github.com/nats-io/nats.go v1.38.0 + github.com/onsi/ginkgo/v2 v2.22.2 + github.com/onsi/gomega v1.36.2 github.com/pelletier/go-toml v1.9.5 - github.com/pierrec/lz4/v4 v4.1.21 + github.com/pelletier/go-toml/v2 v2.2.3 + github.com/pierrec/lz4/v4 v4.1.22 github.com/prometheus/client_golang v1.20.5 github.com/shirou/gopsutil v3.21.11+incompatible github.com/sirupsen/logrus v1.9.3 @@ -47,18 +48,18 @@ require ( github.com/ugorji/go/codec v1.2.12 github.com/ulikunitz/xz v0.5.12 github.com/vbauerster/mpb/v8 v8.8.3 - github.com/xanzy/go-gitlab v0.114.0 + github.com/xanzy/go-gitlab v0.115.0 github.com/xhit/go-simple-mail v2.2.2+incompatible - golang.org/x/net v0.31.0 - golang.org/x/oauth2 v0.24.0 - golang.org/x/sync v0.9.0 - golang.org/x/sys v0.27.0 - golang.org/x/term v0.26.0 + golang.org/x/net v0.33.0 + golang.org/x/oauth2 v0.25.0 + golang.org/x/sync v0.10.0 + golang.org/x/sys v0.29.0 + golang.org/x/term v0.28.0 gopkg.in/yaml.v3 v3.0.1 gorm.io/driver/clickhouse v0.6.1 gorm.io/driver/mysql v1.5.7 - gorm.io/driver/postgres v1.5.10 - gorm.io/driver/sqlite v1.5.6 + gorm.io/driver/postgres v1.5.11 + gorm.io/driver/sqlite v1.5.7 gorm.io/driver/sqlserver v1.5.4 gorm.io/gorm v1.25.12 ) @@ -71,32 +72,32 @@ require ( github.com/Masterminds/goutils v1.1.1 // indirect github.com/Masterminds/semver v1.5.0 // indirect github.com/Masterminds/sprig v2.22.0+incompatible // indirect - github.com/PuerkitoBio/goquery v1.10.0 // indirect + github.com/PuerkitoBio/goquery v1.10.1 // indirect github.com/VividCortex/ewma v1.2.0 // indirect github.com/acarl005/stripansi v0.0.0-20180116102854-5a71ef0e047d // indirect github.com/andybalholm/brotli v1.1.1 // indirect - github.com/andybalholm/cascadia v1.3.2 // indirect + github.com/andybalholm/cascadia v1.3.3 // indirect github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.7 // indirect - github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.20 // indirect - github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.24 // indirect - github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.24 // indirect + github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.22 // indirect + github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.26 // indirect + github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.26 // indirect github.com/aws/aws-sdk-go-v2/internal/ini v1.8.1 // indirect - github.com/aws/aws-sdk-go-v2/internal/v4a v1.3.24 // indirect + github.com/aws/aws-sdk-go-v2/internal/v4a v1.3.26 // indirect github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.1 // indirect - github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.4.5 // indirect - github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.5 // indirect - github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.18.5 // indirect - github.com/aws/aws-sdk-go-v2/service/sso v1.24.6 // indirect - github.com/aws/aws-sdk-go-v2/service/ssooidc v1.28.5 // indirect - github.com/aws/aws-sdk-go-v2/service/sts v1.33.1 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.4.7 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.7 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.18.7 // indirect + github.com/aws/aws-sdk-go-v2/service/sso v1.24.8 // indirect + github.com/aws/aws-sdk-go-v2/service/ssooidc v1.28.7 // indirect + github.com/aws/aws-sdk-go-v2/service/sts v1.33.3 // indirect github.com/beorn7/perks v1.0.1 // indirect - github.com/bytedance/sonic v1.12.5 // indirect + github.com/bytedance/sonic v1.12.6 // indirect github.com/bytedance/sonic/loader v0.2.1 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/cloudwego/base64x v0.1.4 // indirect github.com/cloudwego/iasm v0.2.0 // indirect - github.com/gabriel-vasile/mimetype v1.4.7 // indirect - github.com/gin-contrib/sse v0.1.0 // indirect + github.com/gabriel-vasile/mimetype v1.4.8 // indirect + github.com/gin-contrib/sse v1.0.0 // indirect github.com/go-asn1-ber/asn1-ber v1.5.7 // indirect github.com/go-faster/city v1.0.1 // indirect github.com/go-faster/errors v0.7.1 // indirect @@ -106,12 +107,12 @@ require ( github.com/go-playground/universal-translator v0.18.1 // indirect github.com/go-sql-driver/mysql v1.8.1 // indirect github.com/go-task/slim-sprig/v3 v3.0.0 // indirect - github.com/goccy/go-json v0.10.3 // indirect + github.com/goccy/go-json v0.10.4 // indirect github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 // indirect github.com/golang-sql/sqlexp v0.1.0 // indirect github.com/google/go-cmp v0.6.0 // indirect github.com/google/go-querystring v1.1.0 // indirect - github.com/google/pprof v0.0.0-20241122213907-cbe949e5a41b // indirect + github.com/google/pprof v0.0.0-20241210010833-40e02aabc2ad // indirect github.com/google/uuid v1.6.0 // indirect github.com/gorilla/css v1.0.1 // indirect github.com/hashicorp/errwrap v1.1.0 // indirect @@ -123,7 +124,7 @@ require ( github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect - github.com/jackc/pgx/v5 v5.7.1 // indirect + github.com/jackc/pgx/v5 v5.7.2 // indirect github.com/jackc/puddle/v2 v2.2.2 // indirect github.com/jaytaylor/html2text v0.0.0-20230321000545-74c2419ad056 // indirect github.com/jinzhu/inflection v1.0.0 // indirect @@ -133,23 +134,22 @@ require ( github.com/klauspost/compress v1.17.11 // indirect github.com/klauspost/cpuid/v2 v2.2.9 // indirect github.com/leodido/go-urn v1.4.0 // indirect - github.com/magiconair/properties v1.8.7 // indirect + github.com/magiconair/properties v1.8.9 // indirect github.com/mattn/go-isatty v0.0.20 // indirect github.com/mattn/go-runewidth v0.0.16 // indirect github.com/mattn/go-sqlite3 v1.14.24 // indirect github.com/mattn/go-tty v0.0.7 // indirect - github.com/microsoft/go-mssqldb v1.7.2 // indirect + github.com/microsoft/go-mssqldb v1.8.0 // indirect github.com/minio/highwayhash v1.0.3 // indirect github.com/mitchellh/copystructure v1.2.0 // indirect github.com/mitchellh/reflectwalk v1.0.2 // indirect github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/reflect2 v1.0.2 // indirect github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect - github.com/nats-io/nkeys v0.4.8 // indirect + github.com/nats-io/nkeys v0.4.9 // indirect github.com/nats-io/nuid v1.0.1 // indirect github.com/olekukonko/tablewriter v0.0.5 // indirect github.com/paulmach/orb v0.11.1 // indirect - github.com/pelletier/go-toml/v2 v2.2.3 // indirect github.com/pkg/errors v0.9.1 // indirect github.com/pkg/term v1.2.0-beta.2 // indirect github.com/prometheus/client_model v0.6.1 // indirect @@ -163,7 +163,7 @@ require ( github.com/shopspring/decimal v1.4.0 // indirect github.com/sourcegraph/conc v0.3.0 // indirect github.com/spf13/afero v1.11.0 // indirect - github.com/spf13/cast v1.7.0 // indirect + github.com/spf13/cast v1.7.1 // indirect github.com/spf13/pflag v1.0.5 // indirect github.com/ssor/bom v0.0.0-20170718123548-6386211fdfcf // indirect github.com/subosito/gotenv v1.6.0 // indirect @@ -172,15 +172,15 @@ require ( github.com/vanng822/go-premailer v1.22.0 // indirect github.com/x448/float16 v0.8.4 // indirect github.com/yusufpapurcu/wmi v1.2.3 // indirect - go.opentelemetry.io/otel v1.32.0 // indirect - go.opentelemetry.io/otel/trace v1.32.0 // indirect + go.opentelemetry.io/otel v1.33.0 // indirect + go.opentelemetry.io/otel/trace v1.33.0 // indirect go.uber.org/multierr v1.11.0 // indirect - golang.org/x/arch v0.12.0 // indirect - golang.org/x/crypto v0.29.0 // indirect - golang.org/x/exp v0.0.0-20241108190413-2d47ceb2692f // indirect - golang.org/x/text v0.20.0 // indirect - golang.org/x/time v0.8.0 // indirect - golang.org/x/tools v0.27.0 // indirect - google.golang.org/protobuf v1.35.2 // indirect + golang.org/x/arch v0.13.0 // indirect + golang.org/x/crypto v0.31.0 // indirect + golang.org/x/exp v0.0.0-20250103183323-7d7fa50e5329 // indirect + golang.org/x/text v0.21.0 // indirect + golang.org/x/time v0.9.0 // indirect + golang.org/x/tools v0.28.0 // indirect + google.golang.org/protobuf v1.36.1 // indirect gopkg.in/ini.v1 v1.67.0 // indirect ) diff --git a/httpcli/cli.go b/httpcli/cli.go index c1858f13..d78e0bfb 100644 --- a/httpcli/cli.go +++ b/httpcli/cli.go @@ -29,9 +29,9 @@ package httpcli import ( "context" "net/http" - "sync/atomic" "time" + libatm "github.com/nabbar/golib/atomic" libtls "github.com/nabbar/golib/certificates" libdur "github.com/nabbar/golib/duration" htcdns "github.com/nabbar/golib/httpcli/dns-mapper" @@ -41,20 +41,10 @@ const ( ClientTimeout5Sec = 5 * time.Second ) -var ( - dns = new(atomic.Value) - ctx context.Context - cnl context.CancelFunc -) +var dns = libatm.NewValue[htcdns.DNSMapper]() func initDNSMapper() htcdns.DNSMapper { - if cnl != nil { - cnl() - } - - ctx, cnl = context.WithCancel(context.Background()) - - return htcdns.New(ctx, &htcdns.Config{ + return htcdns.New(context.Background(), &htcdns.Config{ DNSMapper: make(map[string]string), TimerClean: libdur.ParseDuration(3 * time.Minute), Transport: htcdns.TransportConfig{ @@ -77,17 +67,11 @@ func initDNSMapper() htcdns.DNSMapper { } func DefaultDNSMapper() htcdns.DNSMapper { - if i := dns.Load(); i == nil { - d := initDNSMapper() - dns.Store(d) - return d - } else if d, k := i.(htcdns.DNSMapper); !k { - d = initDNSMapper() - dns.Store(d) - return d - } else { - return d + if dns.Load() == nil { + SetDefaultDNSMapper(initDNSMapper()) } + + return dns.Load() } func SetDefaultDNSMapper(d htcdns.DNSMapper) { @@ -95,11 +79,9 @@ func SetDefaultDNSMapper(d htcdns.DNSMapper) { return } - if cnl != nil { - cnl() + if o := dns.Swap(d); o != nil { + _ = o.Close() } - - dns.Store(d) } type FctHttpClient func() *http.Client diff --git a/httpcli/dns-mapper/config.go b/httpcli/dns-mapper/config.go index c87d0520..c04589f2 100644 --- a/httpcli/dns-mapper/config.go +++ b/httpcli/dns-mapper/config.go @@ -121,6 +121,6 @@ func (o Config) Validate() liberr.Error { return e } -func (o Config) New(ctx context.Context, fct libtls.FctRootCA, msg FuncMessage) DNSMapper { +func (o Config) New(ctx context.Context, fct libtls.FctRootCACert, msg FuncMessage) DNSMapper { return New(ctx, &o, fct, msg) } diff --git a/httpcli/dns-mapper/interface.go b/httpcli/dns-mapper/interface.go index 81c25caf..df959c2d 100644 --- a/httpcli/dns-mapper/interface.go +++ b/httpcli/dns-mapper/interface.go @@ -31,10 +31,11 @@ import ( "net" "net/http" "sync" - "sync/atomic" "time" + libatm "github.com/nabbar/golib/atomic" libtls "github.com/nabbar/golib/certificates" + tlscas "github.com/nabbar/golib/certificates/ca" libdur "github.com/nabbar/golib/duration" ) @@ -58,9 +59,24 @@ type DNSMapper interface { DefaultClient() *http.Client TimeCleaner(ctx context.Context, dur time.Duration) + Close() error } -func New(ctx context.Context, cfg *Config, fct libtls.FctRootCA, msg FuncMessage) DNSMapper { +func GetRootCaCert(fct libtls.FctRootCA) tlscas.Cert { + var res tlscas.Cert + + for _, c := range fct() { + if res == nil { + res, _ = tlscas.Parse(c) + } else { + _ = res.AppendString(c) + } + } + + return res +} + +func New(ctx context.Context, cfg *Config, fct libtls.FctRootCACert, msg FuncMessage) DNSMapper { if cfg == nil { cfg = &Config{ DNSMapper: make(map[string]string), @@ -73,8 +89,8 @@ func New(ctx context.Context, cfg *Config, fct libtls.FctRootCA, msg FuncMessage } if fct == nil { - fct = func() []string { - return make([]string, 0) + fct = func() tlscas.Cert { + return nil } } @@ -85,10 +101,12 @@ func New(ctx context.Context, cfg *Config, fct libtls.FctRootCA, msg FuncMessage d := &dmp{ d: new(sync.Map), z: new(sync.Map), - c: new(atomic.Value), - t: new(atomic.Value), + c: libatm.NewValue[*Config](), + t: libatm.NewValue[*http.Transport](), f: fct, i: msg, + n: libatm.NewValue[context.CancelFunc](), + x: libatm.NewValue[context.Context](), } for edp, adr := range cfg.DNSMapper { diff --git a/httpcli/dns-mapper/model.go b/httpcli/dns-mapper/model.go index 02c9033f..6405c58b 100644 --- a/httpcli/dns-mapper/model.go +++ b/httpcli/dns-mapper/model.go @@ -28,28 +28,37 @@ package dns_mapper import ( "context" + "net/http" "sync" - "sync/atomic" "time" + libatm "github.com/nabbar/golib/atomic" libtls "github.com/nabbar/golib/certificates" ) type dmp struct { d *sync.Map z *sync.Map - c *atomic.Value // *Config - t *atomic.Value // *http transport - f libtls.FctRootCA + c libatm.Value[*Config] + t libatm.Value[*http.Transport] + n libatm.Value[context.CancelFunc] + x libatm.Value[context.Context] + f libtls.FctRootCACert i func(msg string) } +func (o *dmp) Close() error { + if i := o.n.Swap(func() {}); i != nil { + i() + } + + return nil +} + func (o *dmp) config() *Config { var cfg = &Config{} - if i := o.c.Load(); i == nil { - return cfg - } else if c, k := i.(*Config); !k { + if c := o.c.Load(); c == nil { return cfg } else { *cfg = *c @@ -82,20 +91,35 @@ func (o *dmp) TimeCleaner(ctx context.Context, dur time.Duration) { dur = 5 * time.Minute } + var ( + x context.Context + n context.CancelFunc + ) + + if ctx != nil { + x, n = context.WithCancel(ctx) + o.x.Store(x) + if i := o.n.Swap(n); i != nil { + i() + } + } + go func() { - var tck = time.NewTicker(dur) - defer tck.Stop() + var ( + tk = time.NewTicker(dur) + cx = o.x.Load() + ) - for { - if ctx.Err() != nil { - return - } + defer func() { + tk.Stop() + }() + for { select { - case <-tck.C: - o.DefaultTransport().CloseIdleConnections() - case <-ctx.Done(): + case <-cx.Done(): return + case <-tk.C: + o.DefaultTransport().CloseIdleConnections() } } }() diff --git a/httpcli/dns-mapper/transport.go b/httpcli/dns-mapper/transport.go index 65d996b7..e3ef6123 100644 --- a/httpcli/dns-mapper/transport.go +++ b/httpcli/dns-mapper/transport.go @@ -87,8 +87,8 @@ func (o *dmp) Transport(cfg TransportConfig) *http.Transport { ssl.SetVersionMax(tls.VersionTLS13) } - for _, c := range o.f() { - ssl.AddRootCAString(c) + if v := o.f(); v != nil && v.Len() > 0 { + ssl.AddRootCA(v) } if cfg.TimeoutGlobal == 0 { @@ -150,9 +150,7 @@ func (o *dmp) Client(cfg TransportConfig) *http.Client { func (o *dmp) DefaultTransport() *http.Transport { i := o.t.Load() if i != nil { - if t, k := i.(*http.Transport); k { - return t - } + return i } t := o.Transport(o.config().Transport) diff --git a/ioutils/mapCloser/interface.go b/ioutils/mapCloser/interface.go index e9ce8baa..69a68072 100644 --- a/ioutils/mapCloser/interface.go +++ b/ioutils/mapCloser/interface.go @@ -28,8 +28,10 @@ package mapCloser import ( + "context" "io" "sync/atomic" + "time" libctx "github.com/nabbar/golib/context" ) @@ -44,19 +46,33 @@ type Closer interface { Close() error } -func New(ctx libctx.FuncContext) Closer { +func New(ctx context.Context) Closer { + var ( + x, n = context.WithCancel(ctx) + fx = func() context.Context { + return x + } + ) + c := &closer{ + f: n, i: new(atomic.Uint64), - x: libctx.NewConfig[uint64](ctx), + c: new(atomic.Bool), + x: libctx.NewConfig[uint64](fx), } + c.c.Store(false) c.i.Store(0) go func() { - select { - case <-c.x.Done(): - _ = c.Close() - return + for !c.c.Load() { + select { + case <-c.x.Done(): + _ = c.Close() + return + default: + time.Sleep(time.Millisecond * 100) + } } }() diff --git a/ioutils/mapCloser/model.go b/ioutils/mapCloser/model.go index 7a6ce367..fb93ec31 100644 --- a/ioutils/mapCloser/model.go +++ b/ioutils/mapCloser/model.go @@ -38,6 +38,8 @@ import ( ) type closer struct { + c *atomic.Bool + f func() // Context Func Cancel i *atomic.Uint64 x libctx.Config[uint64] } @@ -130,7 +132,12 @@ func (o *closer) Clone() Closer { i := new(atomic.Uint64) i.Store(o.idx()) + c := new(atomic.Bool) + c.Store(o.c.Load()) + return &closer{ + c: c, + f: o.f, i: i, x: o.x.Clone(nil), } @@ -141,7 +148,15 @@ func (o *closer) Close() error { if o == nil { return fmt.Errorf("not initialized") - } else if o.x == nil { + } + + o.c.Store(true) + + if o.f != nil { + defer o.f() + } + + if o.x == nil { return fmt.Errorf("not initialized") } else if o.x.Err() != nil { return o.x.Err() diff --git a/logger/interface.go b/logger/interface.go index b6d026fb..ff75924d 100644 --- a/logger/interface.go +++ b/logger/interface.go @@ -34,8 +34,6 @@ import ( "sync/atomic" "time" - iotclo "github.com/nabbar/golib/ioutils/mapCloser" - libctx "github.com/nabbar/golib/context" logcfg "github.com/nabbar/golib/logger/config" logent "github.com/nabbar/golib/logger/entry" @@ -133,7 +131,6 @@ func New(ctx libctx.FuncContext) Logger { c: new(atomic.Value), } - l.c.Store(iotclo.New(ctx)) l.SetLevel(loglvl.InfoLevel) return l diff --git a/logger/iowritecloser.go b/logger/iowritecloser.go index 04479043..517b560f 100644 --- a/logger/iowritecloser.go +++ b/logger/iowritecloser.go @@ -35,21 +35,11 @@ import ( ) func (o *logger) Close() error { - if o == nil { - return nil - } - - c := o.newCloser() - if c == nil { - return nil - } - - s := o.switchCloser(c) - if s == nil { - return nil + if o != nil && o.hasCloser() { + o.switchCloser(nil) } - return s.Close() + return nil } func (o *logger) Write(p []byte) (n int, err error) { diff --git a/logger/manage.go b/logger/manage.go index 552d545b..0ce43070 100644 --- a/logger/manage.go +++ b/logger/manage.go @@ -52,16 +52,31 @@ func (o *logger) getCloser() iotclo.Closer { return c } - c := iotclo.New(o.x.GetContext) + c := o.newCloser() o.c.Store(c) return c } -func (o *logger) switchCloser(c iotclo.Closer) iotclo.Closer { - b := o.getCloser() - o.c.Store(c) - return b +func (o *logger) switchCloser(c iotclo.Closer) { + if o == nil { + return + } else if c == nil { + c = o.newCloser() + } + + i := o.c.Swap(c) + + if i == nil { + return + } else if v, k := i.(iotclo.Closer); k && v != nil { + go func() { + // temp waiting all still calling log finish + time.Sleep(10 * time.Second) + _ = v.Close() + v = nil + }() + } } func (o *logger) newCloser() iotclo.Closer { @@ -69,7 +84,21 @@ func (o *logger) newCloser() iotclo.Closer { return nil } - return iotclo.New(o.x.GetContext) + return iotclo.New(o.x.GetContext()) +} + +func (o *logger) hasCloser() bool { + if o == nil || o.x == nil { + return false + } + + if i := o.c.Load(); i != nil { + if _, k := i.(iotclo.Closer); k { + return true + } + } + + return false } func (o *logger) Clone() Logger { @@ -222,24 +251,24 @@ func (o *logger) SetOptions(opt *logcfg.Options) error { } } - var clo = o.newCloser() + if len(hkl) > 0 { + var clo = o.newCloser() + + for _, h := range hkl { + clo.Add(h) + h.RegisterHook(obj) + go h.Run(o.x.GetContext()) + } - for _, h := range hkl { - clo.Add(h) - h.RegisterHook(obj) - go h.Run(o.x.GetContext()) + o.switchCloser(clo) + } else if o.hasCloser() { + o.switchCloser(nil) } o.x.Store(keyOptions, opt) o.x.Store(keyLogrus, obj) o.runFuncUpdateLogger() - clo = o.switchCloser(clo) - go func(c iotclo.Closer) { - time.Sleep(3 * time.Second) - _ = c.Close() - }(clo) - return nil } diff --git a/pprof/tools.go b/pprof/tools.go index f25e524f..7e6284fc 100644 --- a/pprof/tools.go +++ b/pprof/tools.go @@ -33,6 +33,7 @@ import ( "os" "path/filepath" "runtime" + "runtime/debug" "runtime/pprof" "time" @@ -132,6 +133,7 @@ func ProfilingMemRun(ctx context.Context, tck *time.Ticker) error { }() runtime.GC() + debug.FreeOSMemory() if e = pprof.WriteHeapProfile(h); e != nil { return e diff --git a/request/DoRequest.go b/request/DoRequest.go new file mode 100644 index 00000000..82782511 --- /dev/null +++ b/request/DoRequest.go @@ -0,0 +1,52 @@ +/* + * MIT License + * + * Copyright (c) 2022 Nicolas JUHEL + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + * + */ + +package request + +type DoRequestOptions struct { + Retry int + Checksum256 []byte + ValidStatusCodes []int + Model interface{} +} + +func (r *request) DoParse(model interface{}, validStatus ...int) error { + return r.DoWithOptions(&DoRequestOptions{ + Retry: 0, + Checksum256: nil, + ValidStatusCodes: validStatus, + Model: model, + }) +} + +func (r *request) DoParseRetry(retry int, model interface{}, validStatus ...int) error { + return r.DoWithOptions(&DoRequestOptions{ + Retry: retry, + Checksum256: nil, + ValidStatusCodes: validStatus, + Model: model, + }) +} diff --git a/request/error.go b/request/error.go index 275d047f..425e6254 100644 --- a/request/error.go +++ b/request/error.go @@ -35,8 +35,10 @@ type requestError struct { code int status string statusErr bool - bufBody *bytes.Buffer + bufBody []byte bodyErr bool + checksum []byte + isSame bool err error } @@ -49,7 +51,16 @@ func (r *requestError) Status() string { } func (r *requestError) Body() *bytes.Buffer { - return r.bufBody + if len(r.bufBody) > 0 { + b := make([]byte, 0, len(r.bufBody)) + copy(b, r.bufBody) + return bytes.NewBuffer(b) + } + return bytes.NewBuffer(make([]byte, 0)) +} + +func (r *requestError) CheckSum() []byte { + return r.checksum } func (r *requestError) Error() error { @@ -68,9 +79,13 @@ func (r *requestError) IsBodyError() bool { return r.bodyErr } +func (r *requestError) IsBodySame() bool { + return r.isSame +} + func (r *requestError) ParseBody(i interface{}) bool { - if r.bufBody != nil && r.bufBody.Len() > 0 { - if e := json.Unmarshal(r.bufBody.Bytes(), i); e == nil { + if len(r.bufBody) > 0 { + if e := json.Unmarshal(r.bufBody, i); e == nil { return true } } @@ -78,13 +93,22 @@ func (r *requestError) ParseBody(i interface{}) bool { return false } +func (r *requestError) Free() { + if len(r.bufBody) > 0 { + r.bufBody = r.bufBody[:0] + } + r.bufBody = nil +} + func (r *request) newError() { r.err.Store(&requestError{ code: 0, status: "", statusErr: false, - bufBody: bytes.NewBuffer(make([]byte, 0)), + bufBody: nil, bodyErr: false, + checksum: nil, + isSame: false, err: nil, }) } @@ -100,8 +124,10 @@ func (r *request) getError() *requestError { code: 0, status: "", statusErr: false, - bufBody: bytes.NewBuffer(make([]byte, 0)), + bufBody: nil, bodyErr: false, + checksum: nil, + isSame: false, err: nil, } } @@ -112,8 +138,10 @@ func (r *request) setError(e *requestError) { code: 0, status: "", statusErr: false, - bufBody: bytes.NewBuffer(make([]byte, 0)), + bufBody: nil, bodyErr: false, + checksum: nil, + isSame: false, err: nil, } } diff --git a/request/interface.go b/request/interface.go index 678c666c..18be08e7 100644 --- a/request/interface.go +++ b/request/interface.go @@ -46,11 +46,13 @@ type Error interface { StatusCode() int Status() string Body() *bytes.Buffer + CheckSum() []byte Error() error IsError() bool IsStatusError() bool IsBodyError() bool + IsBodySame() bool ParseBody(i interface{}) bool } @@ -105,6 +107,7 @@ type Request interface { Clone() (Request, error) New() (Request, error) + Free() GetOption() *Options SetOption(opt *Options) error @@ -116,7 +119,13 @@ type Request interface { IsError() bool Do() (*http.Response, error) + DoWithOptions(opt *DoRequestOptions) error + + // DoParse + //Deprecated: use DoWithOptions instead of DoParse DoParse(model interface{}, validStatus ...int) error + // DoParseRetry + //Deprecated: use DoWithOptions instead of DoParseRetry DoParseRetry(retry int, model interface{}, validStatus ...int) error Monitor(ctx context.Context, vrs libver.Version) (montps.Monitor, error) diff --git a/request/model.go b/request/model.go index 40e1e3df..14ac084b 100644 --- a/request/model.go +++ b/request/model.go @@ -128,6 +128,15 @@ func (r *request) New() (Request, error) { return n, nil } +func (r *request) Free() { + if i := r.err.Load(); i != nil { + if v, k := i.(*requestError); !k { + v.Free() + r.err.Store(&requestError{}) + } + } +} + func (r *request) RegisterDefaultLogger(fct liblog.FuncLog) { if fct == nil { fct = func() liblog.Logger { diff --git a/request/monitor.go b/request/monitor.go index 8b073704..0587ee2c 100644 --- a/request/monitor.go +++ b/request/monitor.go @@ -27,7 +27,6 @@ package request import ( - "bytes" "context" "encoding/base64" "fmt" @@ -76,13 +75,19 @@ func (r *request) HealthCheck(ctx context.Context) error { } var ( - e error err error - buf *bytes.Buffer + buf []byte req *http.Request rsp *http.Response ) + defer func() { + if len(buf) > 0 { + buf = buf[:0] + buf = nil + } + }() + if ent != nil { ent.FieldAdd("endpoint", ednp) ent.FieldAdd("method", http.MethodGet) @@ -92,23 +97,23 @@ func (r *request) HealthCheck(ctx context.Context) error { if err != nil { if ent != nil { - ent.ErrorAdd(true, err).Check(loglvl.NilLevel) + ent.ErrorAdd(true, err).Log() } return err } - rsp, e = r.client().Do(req) - if e != nil { - err = ErrorSendRequest.Error(e) + rsp, err = r.client().Do(req) + if err != nil { + err = ErrorSendRequest.Error(err) if ent != nil { - ent.ErrorAdd(true, err).Check(loglvl.NilLevel) + ent.ErrorAdd(true, err).Log() } return err } if buf, err = r.checkResponse(rsp); err != nil { if ent != nil { - ent.ErrorAdd(true, err).Check(loglvl.NilLevel) + ent.ErrorAdd(true, err).Log() } return err } @@ -117,7 +122,7 @@ func (r *request) HealthCheck(ctx context.Context) error { if !r.isValidCode(opts.Health.Result.ValidHTTPCode, rsp.StatusCode) { err = ErrorResponseStatus.Error(fmt.Errorf("status: %s", rsp.Status)) if ent != nil { - ent.ErrorAdd(true, err).Check(loglvl.NilLevel) + ent.ErrorAdd(true, err).Log() } return err } @@ -125,7 +130,7 @@ func (r *request) HealthCheck(ctx context.Context) error { if r.isValidCode(opts.Health.Result.InvalidHTTPCode, rsp.StatusCode) { err = ErrorResponseStatus.Error(fmt.Errorf("status: %s", rsp.Status)) if ent != nil { - ent.ErrorAdd(true, err).Check(loglvl.NilLevel) + ent.ErrorAdd(true, err).Log() } return err } @@ -135,7 +140,7 @@ func (r *request) HealthCheck(ctx context.Context) error { if !r.isValidContents(opts.Health.Result.Contain, buf) { err = ErrorResponseContainsNotFound.Error(nil) if ent != nil { - ent.ErrorAdd(true, err).Check(loglvl.NilLevel) + ent.ErrorAdd(true, err).Log() } return err } @@ -143,7 +148,7 @@ func (r *request) HealthCheck(ctx context.Context) error { if r.isValidContents(opts.Health.Result.NotContain, buf) { err = ErrorResponseNotContainsFound.Error(nil) if ent != nil { - ent.ErrorAdd(true, err).Check(loglvl.NilLevel) + ent.ErrorAdd(true, err).Log() } return err } diff --git a/request/request.go b/request/request.go index c507bd77..fb469c5c 100644 --- a/request/request.go +++ b/request/request.go @@ -29,11 +29,11 @@ package request import ( "bytes" "context" + "crypto/sha256" "encoding/json" "io" "net/http" "net/url" - "strings" liberr "github.com/nabbar/golib/errors" ) @@ -67,10 +67,10 @@ func (r *request) makeRequest(ctx context.Context, u *url.URL, mtd string, body return req, nil } -func (r *request) checkResponse(rsp *http.Response, validStatus ...int) (*bytes.Buffer, error) { +func (r *request) checkResponse(rsp *http.Response, validStatus ...int) ([]byte, error) { var ( e error - b = bytes.NewBuffer(make([]byte, 0)) + b []byte ) defer func() { @@ -80,11 +80,11 @@ func (r *request) checkResponse(rsp *http.Response, validStatus ...int) (*bytes. }() if rsp == nil { - return b, ErrorResponseInvalid.Error(nil) + return nil, ErrorResponseInvalid.Error(nil) } if rsp.Body != nil { - if _, e = io.Copy(b, rsp.Body); e != nil { + if b, e = io.ReadAll(rsp.Body); e != nil { return b, ErrorResponseLoadBody.Error(e) } } @@ -110,15 +110,15 @@ func (r *request) isValidCode(listValid []int, statusCode int) bool { return false } -func (r *request) isValidContents(contains []string, buf *bytes.Buffer) bool { +func (r *request) isValidContents(contains []string, buf []byte) bool { if len(contains) < 1 { return true - } else if buf.Len() < 1 { + } else if len(buf) < 1 { return false } for _, c := range contains { - if strings.Contains(buf.String(), c) { + if bytes.Contains(buf, []byte(c)) { return true } } @@ -135,7 +135,6 @@ func (r *request) Do() (*http.Response, error) { } var ( - e error req *http.Request rer *requestError rsp *http.Response @@ -153,22 +152,19 @@ func (r *request) Do() (*http.Response, error) { return nil, ErrorCreateRequest.Error(err) } - rsp, e = r.client().Do(req) + rsp, err = r.client().Do(req) - if e != nil { - rer.err = e + if err != nil { + rer.err = err r.setError(rer) - return nil, ErrorSendRequest.Error(e) + return nil, ErrorSendRequest.Error(err) } return rsp, nil } -func (r *request) DoParse(model interface{}, validStatus ...int) error { +func (r *request) doParse(opt *DoRequestOptions) error { var ( - e error - b = bytes.NewBuffer(make([]byte, 0)) - err error rsp *http.Response rer *requestError @@ -184,10 +180,9 @@ func (r *request) DoParse(model interface{}, validStatus ...int) error { rer.status = rsp.Status } - b, err = r.checkResponse(rsp, validStatus...) - rer.bufBody = b + rer.bufBody, err = r.checkResponse(rsp, opt.ValidStatusCodes...) - if er := liberr.Get(err); er != nil && er.HasCode(ErrorResponseStatus) { + if liberr.Has(err, ErrorResponseStatus) { rer.statusErr = true } else if err != nil { rer.err = err @@ -195,23 +190,39 @@ func (r *request) DoParse(model interface{}, validStatus ...int) error { return err } - if b.Len() > 0 { - if e = json.Unmarshal(b.Bytes(), model); e != nil { + if len(rer.bufBody) > 0 { + v := sha256.Sum256(rer.bufBody) + rer.checksum = v[:] + + if len(opt.Checksum256) > 0 && len(rer.checksum) > 0 { + if bytes.Equal(rer.CheckSum(), opt.Checksum256) { + rer.isSame = true + return nil + } + } + + if err = json.Unmarshal(rer.bufBody, opt.Model); err != nil { rer.bodyErr = true - rer.err = e + rer.err = err r.setError(rer) - return ErrorResponseUnmarshall.Error(e) + return ErrorResponseUnmarshall.Error(err) } } return nil } -func (r *request) DoParseRetry(retry int, model interface{}, validStatus ...int) error { - var e error +func (r *request) DoWithOptions(opt *DoRequestOptions) error { + var ( + e error + ) + + if opt.Retry < 1 { + opt.Retry = 1 + } - for i := 0; i < retry; i++ { - if e = r.DoParse(model, validStatus...); e != nil { + for i := 0; i < opt.Retry; i++ { + if e = r.doParse(opt); e != nil { continue } else if r.IsError() { continue diff --git a/retro/retro_test.go b/retro/retro_test.go index a3cb6b37..d57195ef 100644 --- a/retro/retro_test.go +++ b/retro/retro_test.go @@ -64,15 +64,15 @@ type Test struct { } type Standard struct { - A int `json:"a" yaml:"a" toml:"a"` - b int + A int `json:"a" yaml:"a" toml:"a"` + b int C string `json:"C" yaml:"C" toml:"C"` D []string `json:"d" yaml:"d" toml:"d"` } type Address struct { - Street string `json:"street" ` - City string `json:"city,omitempty"` + Street string `json:"street" ` + City string `json:"city,omitempty"` } type Status int diff --git a/socket/config/server.go b/socket/config/server.go index c0854c5e..e3b3a53f 100644 --- a/socket/config/server.go +++ b/socket/config/server.go @@ -49,6 +49,6 @@ type ServerConfig struct { // New returns a new server with the given handler and based on the ServerConfig // handler libsck.Handler // (libsck.Server, error) -func (o ServerConfig) New(handler libsck.Handler) (libsck.Server, error) { - return scksrv.New(handler, o.Network, o.Address, o.PermFile, o.GroupPerm) +func (o ServerConfig) New(updateCon libsck.UpdateConn, handler libsck.Handler) (libsck.Server, error) { + return scksrv.New(updateCon, handler, o.Network, o.Address, o.PermFile, o.GroupPerm) } diff --git a/socket/delim/interface.go b/socket/delim/interface.go index ac72b020..0c0e86a9 100644 --- a/socket/delim/interface.go +++ b/socket/delim/interface.go @@ -27,8 +27,8 @@ package delim import ( + "bufio" "io" - "sync/atomic" libsiz "github.com/nabbar/golib/size" ) @@ -37,29 +37,25 @@ type BufferDelim interface { io.ReadCloser io.WriterTo - SetDelim(d rune) - GetDelim() rune - - SetBufferSize(b libsiz.Size) - GetBufferSize() libsiz.Size - - SetInput(i io.ReadCloser) + Delim() rune Reader() io.ReadCloser Copy(w io.Writer) (n int64, err error) ReadBytes() ([]byte, error) + UnRead() ([]byte, error) } func New(r io.ReadCloser, delim rune, sizeBufferRead libsiz.Size) BufferDelim { - d := &dlm{ - i: new(atomic.Value), - d: new(atomic.Int32), - s: new(atomic.Uint64), - r: new(atomic.Value), - } + var b *bufio.Reader - d.SetDelim(delim) - d.SetBufferSize(sizeBufferRead) - d.SetInput(r) + if sizeBufferRead > 0 { + b = bufio.NewReaderSize(r, sizeBufferRead.Int()) + } else { + b = bufio.NewReader(r) + } - return d + return &dlm{ + i: r, + r: b, + d: delim, + } } diff --git a/socket/delim/io.go b/socket/delim/io.go index 1affe493..5c83640d 100644 --- a/socket/delim/io.go +++ b/socket/delim/io.go @@ -37,44 +37,67 @@ func (o *dlm) Copy(w io.Writer) (n int64, err error) { } func (o *dlm) Read(p []byte) (n int, err error) { - if r := o.getReader(); r == nil { + if o == nil || o.r == nil { return 0, ErrInstance - } else { - b, e := r.ReadBytes(o.getDelimByte()) - - if len(b) > 0 { - if cap(p) < len(b) { - p = append(p, make([]byte, len(b)-len(p))...) - } - copy(p, b) + } + + b, e := o.r.ReadBytes(byte(o.d)) + + if len(b) > 0 { + if cap(p) < len(b) { + p = append(p, make([]byte, len(b)-len(p))...) } + copy(p, b) + } + + return len(b), e +} + +func (o *dlm) UnRead() ([]byte, error) { + if o == nil || o.r == nil { + return nil, ErrInstance + } - return len(b), e + if s := o.r.Buffered(); s > 0 { + b := make([]byte, s) + _, e := o.r.Read(b) + return b, e } + + return nil, nil } func (o *dlm) ReadBytes() ([]byte, error) { - if r := o.getReader(); r == nil { - return make([]byte, 0), ErrInstance - } else { - return r.ReadBytes(o.getDelimByte()) + if o.r == nil { + return nil, ErrInstance } + + return o.r.ReadBytes(byte(o.d)) } func (o *dlm) Close() error { - return o.getInput().Close() + o.r.Reset(nil) + o.r = nil + + return o.i.Close() } func (o *dlm) WriteTo(w io.Writer) (n int64, err error) { var ( - i int - s = 1 e error + i int b []byte + + s = 1 + d = o.getDelimByte() ) - for err == nil && s > 0 { - b, err = o.ReadBytes() + if o.r == nil { + return 0, ErrInstance + } + + for err == nil { + b, err = o.r.ReadBytes(d) s = len(b) if s > 0 { diff --git a/socket/delim/model.go b/socket/delim/model.go index dd1b6c20..fe04cbab 100644 --- a/socket/delim/model.go +++ b/socket/delim/model.go @@ -29,82 +29,18 @@ package delim import ( "bufio" "io" - "sync/atomic" - - libsiz "github.com/nabbar/golib/size" ) type dlm struct { - i *atomic.Value // input io.ReadCloser - d *atomic.Int32 // delimiter rune - s *atomic.Uint64 // buffer libsiz.Size - r *atomic.Value // *bufio.Reader -} - -func (o *dlm) SetDelim(delim rune) { - o.d.Store(delim) + i io.ReadCloser // input io.ReadCloser + r *bufio.Reader // *bufio.Reader + d rune // delimiter rune } -func (o *dlm) GetDelim() rune { - return o.d.Load() +func (o *dlm) Delim() rune { + return o.d } func (o *dlm) getDelimByte() byte { - return byte(o.GetDelim()) -} - -func (o *dlm) SetBufferSize(b libsiz.Size) { - o.s.Store(b.Uint64()) -} - -func (o *dlm) GetBufferSize() libsiz.Size { - return libsiz.Size(o.s.Load()) -} - -func (o *dlm) SetInput(i io.ReadCloser) { - if i == nil { - i = &DiscardCloser{} - } - - o.i.Store(i) -} - -func (o *dlm) getInput() io.ReadCloser { - if i := o.i.Load(); i == nil { - return &DiscardCloser{} - } else if v, k := i.(io.ReadCloser); !k { - return &DiscardCloser{} - } else { - return v - } -} - -func (o *dlm) newReader() *bufio.Reader { - if siz := o.GetBufferSize(); siz > 0 { - return bufio.NewReaderSize(o.getInput(), siz.Int()) - } else { - return bufio.NewReader(o.getInput()) - } -} - -func (o *dlm) setReader(r *bufio.Reader) { - if r == nil { - r = o.newReader() - } - - o.r.Store(r) -} - -func (o *dlm) getReader() *bufio.Reader { - if i := o.r.Load(); i == nil { - r := o.newReader() - o.setReader(r) - return r - } else if v, k := i.(*bufio.Reader); !k { - r := o.newReader() - o.setReader(r) - return r - } else { - return v - } + return byte(o.d) } diff --git a/socket/interface.go b/socket/interface.go index 50b38cee..939d0802 100644 --- a/socket/interface.go +++ b/socket/interface.go @@ -40,6 +40,11 @@ const DefaultBufferSize = 32 * 1024 // EOL is the end of line, default delimiter of the socket const EOL byte = '\n' +// error to be filtered +var ( + errFilterClosed = "use of closed network connection" +) + // ConnState is used to process state connection type ConnState uint8 @@ -89,6 +94,9 @@ type FuncInfo func(local, remote net.Addr, state ConnState) // Handler is used to process request type Handler func(request Reader, response Writer) +// UpdateConn is used to update new connection before used it +type UpdateConn func(co net.Conn) + // Response is used to process response type Response func(r io.Reader) @@ -167,3 +175,15 @@ type Client interface { // error. Once(ctx context.Context, request io.Reader, fct Response) error } + +func ErrorFilter(err error) error { + if err == nil { + return nil + } + + if err.Error() == errFilterClosed { + return nil + } + + return err +} diff --git a/socket/server/interface_linux.go b/socket/server/interface_linux.go index 3e3c9dd3..278123eb 100644 --- a/socket/server/interface_linux.go +++ b/socket/server/interface_linux.go @@ -46,6 +46,7 @@ import ( // New creates a new server based on the provided network protocol. // // Parameters: +// - upd: a Update Connection function or nil // - handler: the handler for the server // - delim: the delimiter to use to separate messages // - proto: the network protocol to use @@ -56,26 +57,26 @@ import ( // Return type(s): // - libsck.Server: the created server // - error: an error if any occurred during server creation -func New(handler libsck.Handler, proto libptc.NetworkProtocol, address string, perm os.FileMode, gid int32) (libsck.Server, error) { +func New(upd libsck.UpdateConn, handler libsck.Handler, proto libptc.NetworkProtocol, address string, perm os.FileMode, gid int32) (libsck.Server, error) { switch proto { case libptc.NetworkUnix: if strings.EqualFold(runtime.GOOS, "linux") { - s := scksrx.New(handler) + s := scksrx.New(upd, handler) e := s.RegisterSocket(address, perm, gid) return s, e } case libptc.NetworkUnixGram: if strings.EqualFold(runtime.GOOS, "linux") { - s := sckgrm.New(handler) + s := sckgrm.New(upd, handler) e := s.RegisterSocket(address, perm, gid) return s, e } case libptc.NetworkTCP, libptc.NetworkTCP4, libptc.NetworkTCP6: - s := scksrt.New(handler) + s := scksrt.New(upd, handler) e := s.RegisterServer(address) return s, e case libptc.NetworkUDP, libptc.NetworkUDP4, libptc.NetworkUDP6: - s := scksru.New(handler) + s := scksru.New(upd, handler) e := s.RegisterServer(address) return s, e } diff --git a/socket/server/interface_other.go b/socket/server/interface_other.go index 1af2314c..985815c7 100644 --- a/socket/server/interface_other.go +++ b/socket/server/interface_other.go @@ -42,6 +42,7 @@ import ( // New creates a new server based on the network protocol specified. // // Parameters: +// - upd: a Update Connection function or nil // - handler: the handler for the server // - delim: the delimiter to use to separate messages // - proto: the network protocol to use @@ -50,14 +51,14 @@ import ( // - perm: the file mode permissions for the socket, not applicable for non unix // - gid: the group ID for the socket permissions, not applicable for non unix // Return type(s): libsck.Server, error -func New(handler libsck.Handler, proto libptc.NetworkProtocol, address string, perm os.FileMode, gid int32) (libsck.Server, error) { +func New(upd libsck.UpdateConn, handler libsck.Handler, proto libptc.NetworkProtocol, address string, perm os.FileMode, gid int32) (libsck.Server, error) { switch proto { case libptc.NetworkTCP, libptc.NetworkTCP4, libptc.NetworkTCP6: - s := scksrt.New(handler) + s := scksrt.New(upd, handler) e := s.RegisterServer(address) return s, e case libptc.NetworkUDP, libptc.NetworkUDP4, libptc.NetworkUDP6: - s := scksru.New(handler) + s := scksru.New(upd, handler) e := s.RegisterServer(address) return s, e } diff --git a/socket/server/tcp/interface.go b/socket/server/tcp/interface.go index 03e8cc73..bf193415 100644 --- a/socket/server/tcp/interface.go +++ b/socket/server/tcp/interface.go @@ -37,7 +37,7 @@ type ServerTcp interface { RegisterServer(address string) error } -func New(h libsck.Handler) ServerTcp { +func New(u libsck.UpdateConn, h libsck.Handler) ServerTcp { c := new(atomic.Value) c.Store(make(chan []byte)) @@ -47,12 +47,10 @@ func New(h libsck.Handler) ServerTcp { r := new(atomic.Value) r.Store(make(chan struct{})) - f := new(atomic.Value) - f.Store(h) - return &srv{ ssl: new(atomic.Value), - hdl: f, + upd: u, + hdl: h, msg: c, stp: s, rst: r, diff --git a/socket/server/tcp/listener.go b/socket/server/tcp/listener.go index 7789e212..057a3509 100644 --- a/socket/server/tcp/listener.go +++ b/socket/server/tcp/listener.go @@ -76,7 +76,7 @@ func (o *srv) Listen(ctx context.Context) error { if len(a) == 0 { o.fctError(ErrInvalidHandler) return ErrInvalidAddress - } else if hdl := o.handler(); hdl == nil { + } else if o.hdl == nil { o.fctError(ErrInvalidHandler) return ErrInvalidHandler } else if l, e = o.getListen(a); e != nil { @@ -141,13 +141,17 @@ func (o *srv) Listen(ctx context.Context) error { func (o *srv) Conn(ctx context.Context, con net.Conn) { var ( - hdl libsck.Handler cnl context.CancelFunc cor libsck.Reader cow libsck.Writer ) o.nc.Add(1) // inc nb connection + + if o.upd != nil { + o.upd(con) + } + ctx, cnl = context.WithCancel(ctx) cor, cow = o.getReadWriter(ctx, cnl, con) @@ -181,10 +185,10 @@ func (o *srv) Conn(ctx context.Context, con net.Conn) { }() // get handler or exit if nil - if hdl = o.handler(); hdl == nil { + if o.hdl == nil { return } else { - go hdl(cor, cow) + go o.hdl(cor, cow) } for { @@ -213,12 +217,12 @@ func (o *srv) getReadWriter(ctx context.Context, cnl context.CancelFunc, con net if cr, ok := con.(*net.TCPConn); ok { rc.Store(true) o.fctInfo(con.LocalAddr(), con.RemoteAddr(), libsck.ConnectionCloseRead) - return cr.CloseRead() + return libsck.ErrorFilter(cr.CloseRead()) } else { rc.Store(true) rw.Store(true) o.fctInfo(con.LocalAddr(), con.RemoteAddr(), libsck.ConnectionClose) - return con.Close() + return libsck.ErrorFilter(con.Close()) } } @@ -232,12 +236,12 @@ func (o *srv) getReadWriter(ctx context.Context, cnl context.CancelFunc, con net if cr, ok := con.(*net.TCPConn); ok { rw.Store(true) o.fctInfo(con.LocalAddr(), con.RemoteAddr(), libsck.ConnectionCloseWrite) - return cr.CloseRead() + return libsck.ErrorFilter(cr.CloseRead()) } else { rc.Store(true) rw.Store(true) o.fctInfo(con.LocalAddr(), con.RemoteAddr(), libsck.ConnectionClose) - return con.Close() + return libsck.ErrorFilter(con.Close()) } } diff --git a/socket/server/tcp/model.go b/socket/server/tcp/model.go index 8fc1a15b..c58a82bf 100644 --- a/socket/server/tcp/model.go +++ b/socket/server/tcp/model.go @@ -49,13 +49,14 @@ func init() { } type srv struct { - ssl *atomic.Value // tls config - hdl *atomic.Value // handler - msg *atomic.Value // chan []byte - stp *atomic.Value // chan struct{} - rst *atomic.Value // chan struct{} - run *atomic.Bool // is Running - gon *atomic.Bool // is Running + ssl *atomic.Value // tls config + upd libsck.UpdateConn // updateConn + hdl libsck.Handler // handler + msg *atomic.Value // chan []byte + stp *atomic.Value // chan struct{} + rst *atomic.Value // chan struct{} + run *atomic.Bool // is Running + gon *atomic.Bool // is Running fe *atomic.Value // function error fi *atomic.Value // function info @@ -294,19 +295,6 @@ func (o *srv) fctInfoSrv(msg string, args ...interface{}) { } } -func (o *srv) handler() libsck.Handler { - if o == nil { - return nil - } - - v := o.hdl.Load() - if v != nil { - return v.(libsck.Handler) - } - - return nil -} - func (o *srv) getTLS() *tls.Config { i := o.ssl.Load() diff --git a/socket/server/tcp/tcp_test.go b/socket/server/tcp/tcp_test.go index 47da2181..1ce2fb50 100644 --- a/socket/server/tcp/tcp_test.go +++ b/socket/server/tcp/tcp_test.go @@ -130,7 +130,7 @@ var _ = Describe("socket/server/tcp", func() { }) It("Create new server based on config must succeed", func() { - sck, err = srv.New(Handler) + sck, err = srv.New(nil, Handler) Expect(err).ToNot(HaveOccurred()) Expect(sck).ToNot(BeNil()) }) @@ -226,7 +226,7 @@ var _ = Describe("socket/server/tcp", func() { }) It("Create new server based on config must succeed", func() { - sck, err = srv.New(Handler) + sck, err = srv.New(nil, Handler) Expect(err).ToNot(HaveOccurred()) Expect(sck).ToNot(BeNil()) }) diff --git a/socket/server/udp/interface.go b/socket/server/udp/interface.go index bca2ca40..fc9772f4 100644 --- a/socket/server/udp/interface.go +++ b/socket/server/udp/interface.go @@ -37,18 +37,16 @@ type ServerTcp interface { RegisterServer(address string) error } -func New(h libsck.Handler) ServerTcp { +func New(u libsck.UpdateConn, h libsck.Handler) ServerTcp { c := new(atomic.Value) c.Store(make(chan []byte)) s := new(atomic.Value) s.Store(make(chan struct{})) - f := new(atomic.Value) - f.Store(h) - return &srv{ - hdl: f, + upd: u, + hdl: h, msg: c, stp: s, run: new(atomic.Bool), diff --git a/socket/server/udp/listener.go b/socket/server/udp/listener.go index fe515c1b..4ed21a35 100644 --- a/socket/server/udp/listener.go +++ b/socket/server/udp/listener.go @@ -74,7 +74,6 @@ func (o *srv) Listen(ctx context.Context) error { loc *net.UDPAddr con *net.UDPConn - hdl libsck.Handler cnl context.CancelFunc cor libsck.Reader cow libsck.Writer @@ -85,7 +84,7 @@ func (o *srv) Listen(ctx context.Context) error { if len(a) == 0 { o.fctError(ErrInvalidInstance) return ErrInvalidAddress - } else if hdl = o.handler(); hdl == nil { + } else if o.hdl == nil { o.fctError(ErrInvalidInstance) return ErrInvalidHandler } else if loc, con, e = o.getListen(a); e != nil { @@ -93,6 +92,10 @@ func (o *srv) Listen(ctx context.Context) error { return e } + if o.upd != nil { + o.upd(con) + } + ctx, cnl = context.WithCancel(ctx) cor, cow = o.getReadWriter(ctx, con, loc) @@ -114,7 +117,7 @@ func (o *srv) Listen(ctx context.Context) error { }() // get handler or exit if nil - go hdl(cor, cow) + go o.hdl(cor, cow) for { select { @@ -143,7 +146,7 @@ func (o *srv) getReadWriter(ctx context.Context, con *net.UDPConn, loc net.Addr) fctClose := func() error { o.fctInfo(loc, fg(), libsck.ConnectionClose) - return con.Close() + return libsck.ErrorFilter(con.Close()) } rdr := libsck.NewReader( diff --git a/socket/server/udp/model.go b/socket/server/udp/model.go index 514211e5..028193f6 100644 --- a/socket/server/udp/model.go +++ b/socket/server/udp/model.go @@ -48,10 +48,11 @@ func init() { } type srv struct { - hdl *atomic.Value // handler - msg *atomic.Value // chan []byte - stp *atomic.Value // chan struct{} - run *atomic.Bool // is Running + upd libsck.UpdateConn // updateConn + hdl libsck.Handler // handler + msg *atomic.Value // chan []byte + stp *atomic.Value // chan struct{} + run *atomic.Bool // is Running fe *atomic.Value // function error fi *atomic.Value // function info @@ -218,16 +219,3 @@ func (o *srv) fctInfoSrv(msg string, args ...interface{}) { v.(libsck.FuncInfoSrv)(fmt.Sprintf(msg, args...)) } } - -func (o *srv) handler() libsck.Handler { - if o == nil { - return nil - } - - v := o.hdl.Load() - if v != nil { - return v.(libsck.Handler) - } - - return nil -} diff --git a/socket/server/udp/udp_test.go b/socket/server/udp/udp_test.go index fe64d070..821159c4 100644 --- a/socket/server/udp/udp_test.go +++ b/socket/server/udp/udp_test.go @@ -129,7 +129,7 @@ var _ = Describe("socket/server/udp", func() { }) It("Create new server based on config must succeed", func() { - sck, err = srv.New(Handler) + sck, err = srv.New(nil, Handler) Expect(err).ToNot(HaveOccurred()) Expect(sck).ToNot(BeNil()) }) diff --git a/socket/server/unix/interface.go b/socket/server/unix/interface.go index c8a6f931..9c299a49 100644 --- a/socket/server/unix/interface.go +++ b/socket/server/unix/interface.go @@ -43,7 +43,7 @@ type ServerUnix interface { RegisterSocket(unixFile string, perm os.FileMode, gid int32) error } -func New(h libsck.Handler) ServerUnix { +func New(u libsck.UpdateConn, h libsck.Handler) ServerUnix { c := new(atomic.Value) c.Store(make(chan []byte)) @@ -53,9 +53,6 @@ func New(h libsck.Handler) ServerUnix { r := new(atomic.Value) r.Store(make(chan struct{})) - f := new(atomic.Value) - f.Store(h) - // socket file sf := new(atomic.Value) sf.Store("") @@ -69,7 +66,8 @@ func New(h libsck.Handler) ServerUnix { sg.Store(0) return &srv{ - hdl: f, + upd: u, + hdl: h, msg: c, stp: s, rst: r, diff --git a/socket/server/unix/listener.go b/socket/server/unix/listener.go index b2917a77..e5c420de 100644 --- a/socket/server/unix/listener.go +++ b/socket/server/unix/listener.go @@ -146,7 +146,7 @@ func (o *srv) Listen(ctx context.Context) error { if f, e = o.getSocketFile(); e != nil { o.fctError(e) return e - } else if hdl := o.handler(); hdl == nil { + } else if o.hdl == nil { o.fctError(ErrInvalidHandler) return ErrInvalidHandler } else if l, e = o.getListen(f); e != nil { @@ -211,13 +211,17 @@ func (o *srv) Listen(ctx context.Context) error { func (o *srv) Conn(ctx context.Context, con net.Conn) { var ( - hdl libsck.Handler cnl context.CancelFunc cor libsck.Reader cow libsck.Writer ) o.nc.Add(1) // inc nb connection + + if o.upd != nil { + o.upd(con) + } + ctx, cnl = context.WithCancel(ctx) cor, cow = o.getReadWriter(ctx, cnl, con) @@ -251,10 +255,10 @@ func (o *srv) Conn(ctx context.Context, con net.Conn) { }() // get handler or exit if nil - if hdl = o.handler(); hdl == nil { + if o.hdl == nil { return } else { - go hdl(cor, cow) + go o.hdl(cor, cow) } for { @@ -283,12 +287,12 @@ func (o *srv) getReadWriter(ctx context.Context, cnl context.CancelFunc, con net if cr, ok := con.(*net.UnixConn); ok { rc.Store(true) o.fctInfo(con.LocalAddr(), con.RemoteAddr(), libsck.ConnectionCloseRead) - return cr.CloseRead() + return libsck.ErrorFilter(cr.CloseRead()) } else { rc.Store(true) rw.Store(true) o.fctInfo(con.LocalAddr(), con.RemoteAddr(), libsck.ConnectionClose) - return con.Close() + return libsck.ErrorFilter(con.Close()) } } @@ -302,12 +306,12 @@ func (o *srv) getReadWriter(ctx context.Context, cnl context.CancelFunc, con net if cr, ok := con.(*net.UnixConn); ok { rw.Store(true) o.fctInfo(con.LocalAddr(), con.RemoteAddr(), libsck.ConnectionCloseWrite) - return cr.CloseRead() + return libsck.ErrorFilter(cr.CloseRead()) } else { rc.Store(true) rw.Store(true) o.fctInfo(con.LocalAddr(), con.RemoteAddr(), libsck.ConnectionClose) - return con.Close() + return libsck.ErrorFilter(con.Close()) } } diff --git a/socket/server/unix/model.go b/socket/server/unix/model.go index 71474fa8..3d82acb7 100644 --- a/socket/server/unix/model.go +++ b/socket/server/unix/model.go @@ -52,12 +52,13 @@ func init() { } type srv struct { - hdl *atomic.Value // handler - msg *atomic.Value // chan []byte - stp *atomic.Value // chan struct{} - rst *atomic.Value // chan struct{} - run *atomic.Bool // is Running - gon *atomic.Bool // is Running + upd libsck.UpdateConn // updateConn + hdl libsck.Handler // handler + msg *atomic.Value // chan []byte + stp *atomic.Value // chan struct{} + rst *atomic.Value // chan struct{} + run *atomic.Bool // is Running + gon *atomic.Bool // is Running fe *atomic.Value // function error fi *atomic.Value // function info @@ -284,16 +285,3 @@ func (o *srv) fctInfoSrv(msg string, args ...interface{}) { v.(libsck.FuncInfoSrv)(fmt.Sprintf(msg, args...)) } } - -func (o *srv) handler() libsck.Handler { - if o == nil { - return nil - } - - v := o.hdl.Load() - if v != nil { - return v.(libsck.Handler) - } - - return nil -} diff --git a/socket/server/unix/unix_test.go b/socket/server/unix/unix_test.go index 8224ad2f..7ec48e8e 100644 --- a/socket/server/unix/unix_test.go +++ b/socket/server/unix/unix_test.go @@ -124,7 +124,7 @@ var _ = Describe("socket/server/unix", func() { }) It("Create new server based on config must succeed", func() { - sck, err = srv.New(Handler) + sck, err = srv.New(nil, Handler) Expect(err).ToNot(HaveOccurred()) Expect(sck).ToNot(BeNil()) }) @@ -226,7 +226,7 @@ var _ = Describe("socket/server/unix", func() { }) It("Create new server based on config must succeed", func() { - sck, err = srv.New(Handler) + sck, err = srv.New(nil, Handler) Expect(err).ToNot(HaveOccurred()) Expect(sck).ToNot(BeNil()) }) diff --git a/socket/server/unixgram/interface.go b/socket/server/unixgram/interface.go index d2e71a58..5b36a4ac 100644 --- a/socket/server/unixgram/interface.go +++ b/socket/server/unixgram/interface.go @@ -43,16 +43,13 @@ type ServerUnixGram interface { RegisterSocket(unixFile string, perm os.FileMode, gid int32) error } -func New(h libsck.Handler) ServerUnixGram { +func New(u libsck.UpdateConn, h libsck.Handler) ServerUnixGram { c := new(atomic.Value) c.Store(make(chan []byte)) s := new(atomic.Value) s.Store(make(chan struct{})) - f := new(atomic.Value) - f.Store(h) - // socket file sf := new(atomic.Value) sf.Store("") @@ -66,7 +63,8 @@ func New(h libsck.Handler) ServerUnixGram { sg.Store(0) return &srv{ - hdl: f, + upd: u, + hdl: h, msg: c, stp: s, run: new(atomic.Bool), diff --git a/socket/server/unixgram/listener.go b/socket/server/unixgram/listener.go index 328c3b05..09c24592 100644 --- a/socket/server/unixgram/listener.go +++ b/socket/server/unixgram/listener.go @@ -142,7 +142,6 @@ func (o *srv) Listen(ctx context.Context) error { loc *net.UnixAddr con *net.UnixConn - hdl libsck.Handler cnl context.CancelFunc cor libsck.Reader cow libsck.Writer @@ -153,7 +152,7 @@ func (o *srv) Listen(ctx context.Context) error { if u, e = o.getSocketFile(); e != nil { o.fctError(e) return e - } else if hdl = o.handler(); hdl == nil { + } else if o.hdl == nil { o.fctError(ErrInvalidHandler) return ErrInvalidHandler } else if loc, e = net.ResolveUnixAddr(libptc.NetworkUnixGram.Code(), u); e != nil { @@ -164,6 +163,10 @@ func (o *srv) Listen(ctx context.Context) error { return e } + if o.upd != nil { + o.upd(con) + } + ctx, cnl = context.WithCancel(ctx) cor, cow = o.getReadWriter(ctx, con, loc) @@ -189,7 +192,7 @@ func (o *srv) Listen(ctx context.Context) error { }() // get handler or exit if nil - go hdl(cor, cow) + go o.hdl(cor, cow) for { select { @@ -218,7 +221,7 @@ func (o *srv) getReadWriter(ctx context.Context, con *net.UnixConn, loc net.Addr fctClose := func() error { o.fctInfo(loc, fg(), libsck.ConnectionClose) - return con.Close() + return libsck.ErrorFilter(con.Close()) } rdr := libsck.NewReader( diff --git a/socket/server/unixgram/model.go b/socket/server/unixgram/model.go index 8c1a74f4..1410cdc4 100644 --- a/socket/server/unixgram/model.go +++ b/socket/server/unixgram/model.go @@ -52,10 +52,11 @@ func init() { } type srv struct { - hdl *atomic.Value // handler - msg *atomic.Value // chan []byte - stp *atomic.Value // chan struct{} - run *atomic.Bool // is Running + upd libsck.UpdateConn // updateConn + hdl libsck.Handler // handler + msg *atomic.Value // chan []byte + stp *atomic.Value // chan struct{} + run *atomic.Bool // is Running fe *atomic.Value // function error fi *atomic.Value // function info @@ -228,16 +229,3 @@ func (o *srv) fctInfoSrv(msg string, args ...interface{}) { v.(libsck.FuncInfoSrv)(fmt.Sprintf(msg, args...)) } } - -func (o *srv) handler() libsck.Handler { - if o == nil { - return nil - } - - v := o.hdl.Load() - if v != nil { - return v.(libsck.Handler) - } - - return nil -} diff --git a/socket/server/unixgram/unixgram_test.go b/socket/server/unixgram/unixgram_test.go index a7278f5c..000edb41 100644 --- a/socket/server/unixgram/unixgram_test.go +++ b/socket/server/unixgram/unixgram_test.go @@ -117,7 +117,7 @@ var _ = Describe("socket/server/unixgram", func() { }) It("Create new server based on config must succeed", func() { - sck, err = srv.New(Handler) + sck, err = srv.New(nil, Handler) Expect(err).ToNot(HaveOccurred()) Expect(sck).ToNot(BeNil()) }) diff --git a/test/test-socket-server-tcp/main.go b/test/test-socket-server-tcp/main.go index 5544fa2b..bbb62492 100644 --- a/test/test-socket-server-tcp/main.go +++ b/test/test-socket-server-tcp/main.go @@ -78,7 +78,7 @@ func checkPanic(err ...error) { } func main() { - srv, err := config().New(Handler) + srv, err := config().New(nil, Handler) checkPanic(err) srv.RegisterFuncError(func(e ...error) { diff --git a/test/test-socket-server-udp/main.go b/test/test-socket-server-udp/main.go index e86db601..7a72f6e4 100644 --- a/test/test-socket-server-udp/main.go +++ b/test/test-socket-server-udp/main.go @@ -78,7 +78,7 @@ func checkPanic(err ...error) { } func main() { - srv, err := config().New(Handler) + srv, err := config().New(nil, Handler) checkPanic(err) srv.RegisterFuncError(func(e ...error) { diff --git a/test/test-socket-server-unix/main.go b/test/test-socket-server-unix/main.go index 391c1802..08c0530c 100644 --- a/test/test-socket-server-unix/main.go +++ b/test/test-socket-server-unix/main.go @@ -79,7 +79,7 @@ func checkPanic(err ...error) { } func main() { - srv, err := config().New(Handler) + srv, err := config().New(nil, Handler) checkPanic(err) srv.RegisterFuncError(func(e ...error) {