From cce6ba534819b23d7ea2092bf4f80a0ebe0e39a7 Mon Sep 17 00:00:00 2001 From: Nicolas JUHEL Date: Wed, 27 Mar 2024 17:33:32 +0100 Subject: [PATCH] Package Config: - component/request: fix IsStarted function - component/head: replace mutex to atomic Package LDAP: - fix bug if ldapPort for TLS is not set - fix bug if ldapPort for StartTLS or noTLS is not set --- config/components/head/client.go | 20 +----------------- config/components/head/component.go | 29 +++----------------------- config/components/head/interface.go | 14 +++++++------ config/components/head/model.go | 23 +++++++++++--------- config/components/request/component.go | 2 +- ldap/ldap.go | 14 +++++++++++-- ldap/model.go | 6 ++++-- 7 files changed, 42 insertions(+), 66 deletions(-) diff --git a/config/components/head/client.go b/config/components/head/client.go index d45ad161..51b7c5cc 100644 --- a/config/components/head/client.go +++ b/config/components/head/client.go @@ -34,9 +34,6 @@ import ( ) func (o *componentHead) _getKey() string { - o.m.RLock() - defer o.m.RUnlock() - if i, l := o.x.Load(keyCptKey); !l { return "" } else if i == nil { @@ -49,9 +46,6 @@ func (o *componentHead) _getKey() string { } func (o *componentHead) _getFctVpr() libvpr.FuncViper { - o.m.RLock() - defer o.m.RUnlock() - if i, l := o.x.Load(keyFctViper); !l { return nil } else if i == nil { @@ -84,9 +78,6 @@ func (o *componentHead) _getSPFViper() *spfvbr.Viper { } func (o *componentHead) _getFctCpt() cfgtps.FuncCptGet { - o.m.RLock() - defer o.m.RUnlock() - if i, l := o.x.Load(keyFctGetCpt); !l { return nil } else if i == nil { @@ -99,9 +90,6 @@ func (o *componentHead) _getFctCpt() cfgtps.FuncCptGet { } func (o *componentHead) _getVersion() libver.Version { - o.m.RLock() - defer o.m.RUnlock() - if i, l := o.x.Load(keyCptVersion); !l { return nil } else if i == nil { @@ -122,9 +110,6 @@ func (o *componentHead) _getFct() (cfgtps.FuncCptEvent, cfgtps.FuncCptEvent) { } func (o *componentHead) _getFctEvt(key uint8) cfgtps.FuncCptEvent { - o.m.RLock() - defer o.m.RUnlock() - if i, l := o.x.Load(key); !l { return nil } else if i == nil { @@ -148,10 +133,7 @@ func (o *componentHead) _runCli() error { if cfg, err := o._getConfig(); err != nil { return ErrorParamInvalid.Error(err) } else { - o.m.Lock() - defer o.m.Unlock() - - o.h = cfg.New() + o.SetHeaders(cfg.New()) return nil } } diff --git a/config/components/head/component.go b/config/components/head/component.go index cbc72c2e..e62ae87e 100644 --- a/config/components/head/component.go +++ b/config/components/head/component.go @@ -55,17 +55,6 @@ func (o *componentHead) Type() string { } func (o *componentHead) Init(key string, ctx libctx.FuncContext, get cfgtps.FuncCptGet, vpr libvpr.FuncViper, vrs libver.Version, log liblog.FuncLog) { - o.m.Lock() - defer o.m.Unlock() - - if o.x == nil { - o.x = libctx.NewConfig[uint8](ctx) - } else { - x := libctx.NewConfig[uint8](ctx) - x.Merge(o.x) - o.x = x - } - o.x.Store(keyCptKey, key) o.x.Store(keyFctGetCpt, get) o.x.Store(keyFctViper, vpr) @@ -84,10 +73,7 @@ func (o *componentHead) RegisterFuncReload(before, after cfgtps.FuncCptEvent) { } func (o *componentHead) IsStarted() bool { - o.m.RLock() - defer o.m.RUnlock() - - return o.h != nil + return len(o.GetHeaders().Header()) > 0 } func (o *componentHead) IsRunning() bool { @@ -103,17 +89,11 @@ func (o *componentHead) Reload() error { } func (o *componentHead) Stop() { - o.m.Lock() - defer o.m.Unlock() - - o.h = nil + o.SetHeaders(nil) return } func (o *componentHead) Dependencies() []string { - o.m.RLock() - defer o.m.RUnlock() - var def = make([]string, 0) if o == nil { @@ -132,10 +112,7 @@ func (o *componentHead) Dependencies() []string { } func (o *componentHead) SetDependencies(d []string) error { - o.m.RLock() - defer o.m.RUnlock() - - if o.x == nil { + if o == nil { return ErrorComponentNotInitialized.Error(nil) } else { o.x.Store(keyCptDependencies, d) diff --git a/config/components/head/interface.go b/config/components/head/interface.go index 53a516cd..aa18543b 100644 --- a/config/components/head/interface.go +++ b/config/components/head/interface.go @@ -27,7 +27,9 @@ package head import ( - "sync" + "sync/atomic" + + libctx "github.com/nabbar/golib/context" libcfg "github.com/nabbar/golib/config" cfgtps "github.com/nabbar/golib/config/types" @@ -41,10 +43,10 @@ type ComponentHead interface { SetHeaders(head librtr.Headers) } -func New() ComponentHead { +func New(ctx libctx.FuncContext) ComponentHead { return &componentHead{ - m: sync.RWMutex{}, - h: nil, + x: libctx.NewConfig[uint8](ctx), + h: new(atomic.Value), } } @@ -52,8 +54,8 @@ func Register(cfg libcfg.Config, key string, cpt ComponentHead) { cfg.ComponentSet(key, cpt) } -func RegisterNew(cfg libcfg.Config, key string) { - cfg.ComponentSet(key, New()) +func RegisterNew(ctx libctx.FuncContext, cfg libcfg.Config, key string) { + cfg.ComponentSet(key, New(ctx)) } func Load(getCpt cfgtps.FuncCptGet, key string) ComponentHead { diff --git a/config/components/head/model.go b/config/components/head/model.go index 27b85984..f1b2ec21 100644 --- a/config/components/head/model.go +++ b/config/components/head/model.go @@ -27,28 +27,31 @@ package head import ( - "sync" + "sync/atomic" libctx "github.com/nabbar/golib/context" librtr "github.com/nabbar/golib/router/header" ) type componentHead struct { - m sync.RWMutex x libctx.Config[uint8] - h librtr.Headers + h *atomic.Value // librtr.Headers } func (o *componentHead) GetHeaders() librtr.Headers { - o.m.RLock() - defer o.m.RUnlock() - - return o.h + if i := o.h.Load(); i == nil { + return librtr.NewHeaders() + } else if v, k := i.(librtr.Headers); !k { + return librtr.NewHeaders() + } else { + return v + } } func (o *componentHead) SetHeaders(head librtr.Headers) { - o.m.Lock() - defer o.m.Unlock() + if head == nil { + head = librtr.NewHeaders() + } - o.h = head + o.h.Store(head) } diff --git a/config/components/request/component.go b/config/components/request/component.go index d39db6c9..0155b1fb 100644 --- a/config/components/request/component.go +++ b/config/components/request/component.go @@ -74,7 +74,7 @@ func (o *componentRequest) RegisterFuncReload(before, after cfgtps.FuncCptEvent) } func (o *componentRequest) IsStarted() bool { - return o != nil && o.r != nil + return o.getRequest() != nil } func (o *componentRequest) IsRunning() bool { diff --git a/ldap/ldap.go b/ldap/ldap.go index 4ff2b11d..dd65136f 100644 --- a/ldap/ldap.go +++ b/ldap/ldap.go @@ -181,8 +181,13 @@ func (lc *HelperLDAP) ForceTLSMode(tlsMode TLSMode, tlsConfig *tls.Config) { func (lc *HelperLDAP) dialTLS() (*ldap.Conn, liberr.Error) { d := net.Dialer{} + adr := lc.config.ServerAddr(true) - c, err := d.DialContext(lc.ctx, "tcp", lc.config.ServerAddr(true)) + if len(adr) < 3 { + return nil, ErrorLDAPServerTLS.Error(fmt.Errorf("invalid port for LDAPS")) + } + + c, err := d.DialContext(lc.ctx, "tcp", adr) if err != nil { if c != nil { @@ -218,8 +223,13 @@ func (lc *HelperLDAP) dialTLS() (*ldap.Conn, liberr.Error) { func (lc *HelperLDAP) dial() (*ldap.Conn, liberr.Error) { d := net.Dialer{} + adr := lc.config.ServerAddr(false) + + if len(adr) < 3 { + return nil, ErrorLDAPServerTLS.Error(fmt.Errorf("invalid port for LDAP / LDAP+STARTLS")) + } - c, err := d.DialContext(lc.ctx, "tcp", lc.config.ServerAddr(false)) + c, err := d.DialContext(lc.ctx, "tcp", adr) if err != nil { if c != nil { diff --git a/ldap/model.go b/ldap/model.go index d8e13bbb..09e2ccf7 100644 --- a/ldap/model.go +++ b/ldap/model.go @@ -99,11 +99,13 @@ func (cnf Config) BaseDN() string { } func (cnf Config) ServerAddr(withTls bool) string { - if withTls { + if withTls && cnf.Portldaps > 0 { return fmt.Sprintf("%s:%d", cnf.Uri, cnf.Portldaps) + } else if !withTls && cnf.PortLdap > 0 { + return fmt.Sprintf("%s:%d", cnf.Uri, cnf.PortLdap) } - return fmt.Sprintf("%s:%d", cnf.Uri, cnf.PortLdap) + return "" } func (cnf Config) PatternFilterGroup() string {