From 359bafcfe741f9ccb3e9dc73e46c74be16510103 Mon Sep 17 00:00:00 2001 From: vsaliieva <91525276+vsaliieva@users.noreply.github.com> Date: Wed, 7 Aug 2024 17:25:29 +0300 Subject: [PATCH 01/17] Allow to proceed devices with duplicated engineID (#38) * Allow to proceed devices with duplicated engineID Fixes ZEN-21001. --- pynetsnmp/netsnmp.py | 28 +++++++++++++++++++++++----- 1 file changed, 23 insertions(+), 5 deletions(-) diff --git a/pynetsnmp/netsnmp.py b/pynetsnmp/netsnmp.py index afee683..1ef00c3 100644 --- a/pynetsnmp/netsnmp.py +++ b/pynetsnmp/netsnmp.py @@ -717,14 +717,32 @@ class Session(object): cb = None - def __init__(self, cmdLineArgs=(), **kw): + def __init__(self, cmdLineArgs=(), freeEtimelist=True, **kw): self.cmdLineArgs = cmdLineArgs + self.freeEtimelist = freeEtimelist self.kw = kw self.sess = None self.args = None self._data = None # ref to _CallbackData object self._log = _getLogger("session") + def _snmp_send(self, session, pdu): + """Allows execution of free_etimelist() after each snmp_send() call. + + Executes lib.free_etimelist() after each lib.snmp_send() call if the + `freeEtimelist` attribute is set, or re-calls lib.snmp_send() otherwise. + This frees all the memory used by entries in the etimelist inside t he + net-snmp library, allowing the processing of devices with duplicated engineID. + + Note: This feature is not supported by RFC. + """ + + try: + return lib.snmp_send(session, pdu) + finally: + if self.freeEtimelist: + lib.free_etimelist() + def open(self): sess = netsnmp_session() self.args = initialize_session(sess, self.cmdLineArgs, self.kw) @@ -857,7 +875,7 @@ def sendTrap(self, trapoid, varbinds=None): n = strToOid(n) lib.snmp_add_var(pdu, n, len(n), t, v) - lib.snmp_send(self.sess, pdu) + self._snmp_send(self.sess, pdu) def close(self): if self.sess is not None: @@ -921,7 +939,7 @@ def get(self, oids): for oid in oids: oid = mkoid(oid) lib.snmp_add_null_var(req, oid, len(oid)) - send_status = lib.snmp_send(self.sess, req) + send_status = self._snmp_send(self.sess, req) self._handle_send_status(req, send_status, "get") return req.contents.reqid @@ -933,7 +951,7 @@ def getbulk(self, nonrepeaters, maxrepetitions, oids): for oid in oids: oid = mkoid(oid) lib.snmp_add_null_var(req, oid, len(oid)) - send_status = lib.snmp_send(self.sess, req) + send_status = self._snmp_send(self.sess, req) self._handle_send_status(req, send_status, "get") return req.contents.reqid @@ -941,7 +959,7 @@ def walk(self, root): req = self._create_request(SNMP_MSG_GETNEXT) oid = mkoid(root) lib.snmp_add_null_var(req, oid, len(oid)) - send_status = lib.snmp_send(self.sess, req) + send_status = self._snmp_send(self.sess, req) self._log.debug("walk: send_status=%s", send_status) self._handle_send_status(req, send_status, "walk") return req.contents.reqid From 787d7956ac109f970ec023a041a7e19a15de7f09 Mon Sep 17 00:00:00 2001 From: Jason Peacock Date: Tue, 8 Oct 2024 13:26:05 -0500 Subject: [PATCH 02/17] bump version to 0.42.1 --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index fdd4f30..feb336c 100644 --- a/setup.py +++ b/setup.py @@ -4,7 +4,7 @@ setup( name="pynetsnmp", - version="0.42.0", + version="0.42.1", packages=find_packages(), install_requires=["setuptools"], include_package_data=True, From d3e4b49d20eb76ee327bce4dae9a776cb62e827f Mon Sep 17 00:00:00 2001 From: Jason Peacock Date: Tue, 8 Oct 2024 13:26:54 -0500 Subject: [PATCH 03/17] Refactored for formatting, better type matching, reduce function complexity, etc. --- pynetsnmp/netsnmp.py | 350 +++++++++++++++++++++++++------------------ 1 file changed, 201 insertions(+), 149 deletions(-) diff --git a/pynetsnmp/netsnmp.py b/pynetsnmp/netsnmp.py index 1ef00c3..cf05b81 100644 --- a/pynetsnmp/netsnmp.py +++ b/pynetsnmp/netsnmp.py @@ -12,7 +12,6 @@ Structure, Union, byref, - c_byte, c_char, c_char_p, c_double, @@ -97,12 +96,12 @@ def _getLogger(name): find_library_orig = find_library def find_library(name): - for name in [ + for filename in [ "/usr/lib/lib%s.so" % name, "/usr/local/lib/lib%s.so" % name, ]: - if os.path.exists(name): - return name + if os.path.exists(filename): + return filename return find_library_orig(name) @@ -122,8 +121,13 @@ def find_library(name): return find_library_orig(name) -c_int_p = c_void_p -authenticator = CFUNCTYPE(c_char_p, c_int_p, c_char_p, c_int) +oid = c_long +size_t = c_size_t +u_char = c_ubyte +u_char_p = POINTER(c_ubyte) +u_int = c_uint +u_long = c_ulong +u_short = c_ushort try: # needed by newer netsnmp's @@ -131,18 +135,17 @@ def find_library(name): except Exception: import warnings - warnings.warn("Unable to load crypto library") + warnings.warn("Unable to load crypto library", stacklevel=1) lib = CDLL(find_library("netsnmp"), RTLD_GLOBAL) lib.netsnmp_get_version.restype = c_char_p -oid = c_long -u_long = c_ulong -u_short = c_ushort -u_char_p = c_char_p -u_int = c_uint -size_t = c_size_t -u_char = c_byte +version = lib.netsnmp_get_version() +float_version = float(".".join(version.split(".")[:2])) +_netsnmp_str_version = tuple(str(v) for v in version.split(".")) + +if float_version < 5.099: + raise ImportError("netsnmp version 5.1 or greater is required") class netsnmp_session(Structure): @@ -173,7 +176,13 @@ class netsnmp_trap_stats(Structure): ] -# include/net-snmp/types.h -> int (*netsnmp_callback) (int, netsnmp_session *, int, netsnmp_pdu *, void *); +authenticator = CFUNCTYPE( + u_char_p, u_char_p, POINTER(c_size_t), u_char_p, c_size_t +) + + +# include/net-snmp/types.h +# int (*netsnmp_callback) (int, netsnmp_session *, int, netsnmp_pdu *, void *); # the first argument is the return type in CFUNCTYPE notation. netsnmp_callback = CFUNCTYPE( c_int, @@ -187,9 +196,6 @@ class netsnmp_trap_stats(Structure): # int (*proc)(int, char * const *, int) arg_parse_proc = CFUNCTYPE(c_int, POINTER(c_char_p), c_int) -version = lib.netsnmp_get_version() -float_version = float(".".join(version.split(".")[:2])) -_netsnmp_str_version = tuple(str(v) for v in version.split(".")) localname = [] paramName = [] transportConfig = [] @@ -203,8 +209,6 @@ class netsnmp_trap_stats(Structure): identifier = [] fGetTaddr = [] -if float_version < 5.099: - raise ImportError("netsnmp version 5.1 or greater is required") if float_version > 5.199: localname = [("localname", c_char_p)] if float_version > 5.299: @@ -224,11 +228,13 @@ class netsnmp_container_s(Structure): transportConfig = [ ("transport_configuration", POINTER(netsnmp_container_s)) ] -if _netsnmp_str_version >= ('5','8'): - # Version >= 5.8 broke binary compatibility, adding the trap_stats member to the netsnmp_session struct - trapStats = [('trap_stats', POINTER(netsnmp_trap_stats))] - # Version >= 5.8 broke binary compatibility, adding the msgMaxSize member to the snmp_pdu struct - msgMaxSize = [('msgMaxSize', c_long)] +if _netsnmp_str_version >= ("5", "8"): + # Version >= 5.8 broke binary compatibility, adding the trap_stats + # member to the netsnmp_session struct + trapStats = [("trap_stats", POINTER(netsnmp_trap_stats))] + # Version >= 5.8 broke binary compatibility, adding the msgMaxSize + # member to the snmp_pdu struct + msgMaxSize = [("msgMaxSize", c_long)] baseTransport = [("base_transport", POINTER(netsnmp_transport))] fOpen = [("f_open", c_void_p)] fConfig = [("f_config", c_void_p)] @@ -236,7 +242,8 @@ class netsnmp_container_s(Structure): fSetupSession = [("f_setup_session", c_void_p)] identifier = [("identifier", POINTER(u_char_p))] fGetTaddr = [("f_get_taddr", c_void_p)] - # Version >= 5.8 broke binary compatibility, doubling the size of these constants used for struct sizes + # Version >= 5.8 broke binary compatibility, doubling the size of these + # constants used for struct sizes USM_AUTH_KU_LEN = 64 USM_PRIV_KU_LEN = 64 @@ -262,7 +269,7 @@ class netsnmp_container_s(Structure): ("subsession", POINTER(netsnmp_session)), ("next", POINTER(netsnmp_session)), ("peername", c_char_p), - ("remote_port", u_short), + ("remote_port", u_short), # deprecated ] + localname + [ @@ -303,7 +310,8 @@ class netsnmp_container_s(Structure): ("securityModel", c_int), ("securityLevel", c_int), ] - + paramName + trapStats + + paramName + + trapStats + [ ("securityInfo", c_void_p), ] @@ -323,6 +331,7 @@ class counter64(Structure): ("low", c_ulong), ] + # include/net-snmp/types.h class netsnmp_vardata(Union): _fields_ = [ @@ -339,6 +348,7 @@ class netsnmp_vardata(Union): class netsnmp_variable_list(Structure): pass + # include/net-snmp/types.h netsnmp_variable_list._fields_ = [ ("next_variable", POINTER(netsnmp_variable_list)), @@ -354,45 +364,49 @@ class netsnmp_variable_list(Structure): ("index", c_int), ] # include/net-snmp/types.h -netsnmp_pdu._fields_ = [ - ("version", c_long), - ("command", c_int), - ("reqid", c_long), - ("msgid", c_long), - ("transid", c_long), - ("sessid", c_long), - ("errstat", c_long), - ("errindex", c_long), - ("time", c_ulong), - ("flags", c_ulong), - ("securityModel", c_int), - ("securityLevel", c_int), - ("msgParseModel", c_int), - ] + msgMaxSize + [ - ("transport_data", c_void_p), - ("transport_data_length", c_int), - ("tDomain", POINTER(oid)), - ("tDomainLen", c_size_t), - ("variables", POINTER(netsnmp_variable_list)), - ("community", c_char_p), - ("community_len", c_size_t), - ("enterprise", POINTER(oid)), - ("enterprise_length", c_size_t), - ("trap_type", c_long), - ("specific_type", c_long), - ("agent_addr", c_ubyte * 4), - ("contextEngineID", c_char_p), - ("contextEngineIDLen", c_size_t), - ("contextName", c_char_p), - ("contextNameLen", c_size_t), - ("securityEngineID", c_char_p), - ("securityEngineIDLen", c_size_t), - ("securityName", c_char_p), - ("securityNameLen", c_size_t), - ("priority", c_int), - ("range_subid", c_int), - ("securityStateRef", c_void_p), -] +netsnmp_pdu._fields_ = ( + [ + ("version", c_long), + ("command", c_int), + ("reqid", c_long), + ("msgid", c_long), + ("transid", c_long), + ("sessid", c_long), + ("errstat", c_long), + ("errindex", c_long), + ("time", c_ulong), + ("flags", c_ulong), + ("securityModel", c_int), + ("securityLevel", c_int), + ("msgParseModel", c_int), + ] + + msgMaxSize + + [ + ("transport_data", c_void_p), + ("transport_data_length", c_int), + ("tDomain", POINTER(oid)), + ("tDomainLen", c_size_t), + ("variables", POINTER(netsnmp_variable_list)), + ("community", c_char_p), + ("community_len", c_size_t), + ("enterprise", POINTER(oid)), + ("enterprise_length", c_size_t), + ("trap_type", c_long), + ("specific_type", c_long), + ("agent_addr", c_ubyte * 4), + ("contextEngineID", c_char_p), + ("contextEngineIDLen", c_size_t), + ("contextName", c_char_p), + ("contextNameLen", c_size_t), + ("securityEngineID", c_char_p), + ("securityEngineIDLen", c_size_t), + ("securityName", c_char_p), + ("securityNameLen", c_size_t), + ("priority", c_int), + ("range_subid", c_int), + ("securityStateRef", c_void_p), + ] +) netsnmp_pdu_p = POINTER(netsnmp_pdu) @@ -404,7 +418,9 @@ class netsnmp_log_message(Structure): netsnmp_log_message_p = POINTER(netsnmp_log_message) -# callback.h typedef int (SNMPCallback) (int majorID, int minorID, void *serverarg, void *clientarg); +# callback.h +# typedef int (SNMPCallback) ( +# int majorID, int minorID, void *serverarg, void *clientarg); log_callback = CFUNCTYPE(c_int, c_int, netsnmp_log_message_p, c_void_p) # include/net-snmp/library/snmp_logging.h @@ -423,8 +439,10 @@ class netsnmp_log_message(Structure): LOG_DEBUG: logging.DEBUG, } + # snmplib/snmp_logging.c -> free(logh); -# include/net-snmp/output_api.h -> int snmp_log( int priority, const char *format, ...) +# include/net-snmp/output_api.h +# int snmp_log(int priority, const char *format, ...); # in net-snmp -> snmp_log(LOG_ERR|WARNING|INFO|DEBUG, msg) def netsnmp_logger(a, b, msg): msg = cast(msg, netsnmp_log_message_p) @@ -435,8 +453,9 @@ def netsnmp_logger(a, b, msg): netsnmp_logger = log_callback(netsnmp_logger) -# include/net-snmp/library/callback.h -> -# int snmp_register_callback(int major, int minor, SNMPCallback * new_callback, void *arg); +# include/net-snmp/library/callback.h +# int snmp_register_callback( +# int major, int minor, SNMPCallback * new_callback, void *arg); lib.snmp_register_callback( SNMP_CALLBACK_LIBRARY, SNMP_CALLBACK_LOGGING, netsnmp_logger, 0 ) @@ -445,39 +464,55 @@ def netsnmp_logger(a, b, msg): lib.snmp_open.restype = POINTER(netsnmp_session) # include/net-snmp/library/snmp_transport.h -netsnmp_transport._fields_ = [ - ("domain", POINTER(oid)), - ("domain_length", c_int), - ("local", u_char_p), - ("local_length", c_int), - ("remote", u_char_p), - ("remote_length", c_int), - ("sock", c_int), - ("flags", u_int), - ("data", c_void_p), - ("data_length", c_int), - ("msgMaxSize", c_size_t), - ] + baseTransport + [ - ("f_recv", c_void_p), - ("f_send", c_void_p), - ("f_close", c_void_p), - ] + fOpen + [ - ("f_accept", c_void_p), - ("f_fmtaddr", c_void_p), -] + fCopy + fCopy + fSetupSession + identifier + fGetTaddr - -# include/net-snmp/library/snmp_transport.h -> -# netsnmp_transport *netsnmp_tdomain_transport( const char *str, int local, const char *default_domain); +netsnmp_transport._fields_ = ( + [ + ("domain", POINTER(oid)), + ("domain_length", c_int), + ("local", u_char_p), + ("local_length", c_int), + ("remote", u_char_p), + ("remote_length", c_int), + ("sock", c_int), + ("flags", u_int), + ("data", c_void_p), + ("data_length", c_int), + ("msgMaxSize", c_size_t), + ] + + baseTransport + + [ + ("f_recv", c_void_p), + ("f_send", c_void_p), + ("f_close", c_void_p), + ] + + fOpen + + [ + ("f_accept", c_void_p), + ("f_fmtaddr", c_void_p), + ] + + fCopy + + fCopy + + fSetupSession + + identifier + + fGetTaddr +) + +# include/net-snmp/library/snmp_transport.h +# netsnmp_transport *netsnmp_tdomain_transport( +# const char *str, int local, const char *default_domain); lib.netsnmp_tdomain_transport.restype = POINTER(netsnmp_transport) -# include/net-snmp/library/snmp_api.h -> netsnmp_session *snmp_add( -# netsnmp_session *, struct netsnmp_transport_s *, -# int (*fpre_parse) (netsnmp_session *, struct netsnmp_transport_s *, void *, int), -# int (*fpost_parse) (netsnmp_session *, netsnmp_pdu *, int) -# ); +# include/net-snmp/library/snmp_api.h +# netsnmp_session *snmp_add( +# netsnmp_session *, +# struct netsnmp_transport_s *, +# int (*fpre_parse) ( +# netsnmp_session *, struct netsnmp_transport_s *, void *, int), +# int (*fpost_parse) (netsnmp_session *, netsnmp_pdu *, int) +# ); lib.snmp_add.restype = POINTER(netsnmp_session) -# include/net-snmp/session_api.h -> int snmp_add_var(netsnmp_pdu *, const oid *, size_t, char, const char *); +# include/net-snmp/session_api.h +# int snmp_add_var(netsnmp_pdu *, const oid *, size_t, char, const char *); lib.snmp_add_var.argtypes = [ netsnmp_pdu_p, POINTER(oid), @@ -488,7 +523,8 @@ def netsnmp_logger(a, b, msg): lib.get_uptime.restype = c_long -# include/net-snmp/session_api.h -> int snmp_send(netsnmp_session *, netsnmp_pdu *); +# include/net-snmp/session_api.h +# int snmp_send(netsnmp_session *, netsnmp_pdu *); lib.snmp_send.argtypes = (POINTER(netsnmp_session), netsnmp_pdu_p) lib.snmp_send.restype = c_int @@ -551,15 +587,11 @@ def decodeString(pdu): return "" -_valueToConstant = dict( - [ - (chr(getattr(CONSTANTS, k)), k) - for k in CONSTANTS.__dict__.keys() - if isinstance(getattr(CONSTANTS, k), int) - and getattr(CONSTANTS, k) >= 0 - and getattr(CONSTANTS, k) < 256 - ] -) +_valueToConstant = { + chr(_v): _k + for _k, _v in CONSTANTS.__dict__.items() + if isinstance(_v, int) and (0 <= _v < 256) +} decoder = { @@ -674,47 +706,58 @@ def initialize_session(sess, cmdLineArgs, kw): args = None kw = kw.copy() if cmdLineArgs: - cmdLine = [x for x in cmdLineArgs] - if isinstance(cmdLine[0], tuple): - result = [] - for opt, val in cmdLine: - result.append(opt) - result.append(val) - cmdLine = result - if kw.get("peername"): - cmdLine.append(kw["peername"]) - del kw["peername"] - args = parse_args(cmdLine, byref(sess)) + args = _init_from_args(sess, cmdLineArgs, kw) else: lib.snmp_sess_init(byref(sess)) for attr, value in kw.items(): pv = getattr(sess, attr, _NoAttribute) if pv is _NoAttribute: continue # Don't set invalid properties - if attr == "timeout": - # -1 means the property hasn't been set - if pv == -1: - # Converts seconds to microseconds - setattr(sess, attr, value * 1000000) - elif attr == "version": - # -1 means the property hasn't been set - if pv == -1: - setattr(sess, attr, value) - elif attr == "community": - # None means the property hasn't been set - if pv is None: - setattr(sess, attr, value) - setattr(sess, "community_len", len(value)) - elif attr == "community_len": - # Setting community_len on its own is a segfault waiting to happen - pass - else: - setattr(sess, attr, value) + _update_session(attr, value, pv, sess) return args -class Session(object): +def _init_from_args(sess, cmdLineArgs, kw): + cmdLine = list(cmdLineArgs) + if isinstance(cmdLine[0], tuple): + result = [] + for opt, val in cmdLine: + result.append(opt) + result.append(val) + cmdLine = result + if kw.get("peername"): + cmdLine.append(kw["peername"]) + del kw["peername"] + return parse_args(cmdLine, byref(sess)) + + +def _update_session(attr, value, pv, sess): + if attr == "timeout": + # -1 means 'timeout' hasn't been set + if pv == -1: + # Converts seconds to microseconds + setattr(sess, attr, value * 1000000) + elif attr == "version": + # -1 means 'version' hasn't been set + if pv == -1: + setattr(sess, attr, value) + elif attr == "community": + # None means 'community' hasn't been set + if pv is None: + setattr(sess, attr, value) + # Set 'community_len' at the same time because it's + # related to the value for the 'community' property. + sess.community_len = len(value) + elif attr == "community_len": + # Do nothing to avoid setting a 'community_len' value when no + # value has been set for 'community', otherwise, a segmentation + # fault can occur. + pass + else: + setattr(sess, attr, value) + +class Session(object): cb = None def __init__(self, cmdLineArgs=(), freeEtimelist=True, **kw): @@ -727,12 +770,14 @@ def __init__(self, cmdLineArgs=(), freeEtimelist=True, **kw): self._log = _getLogger("session") def _snmp_send(self, session, pdu): - """Allows execution of free_etimelist() after each snmp_send() call. + """ + Allows execution of free_etimelist() after each snmp_send() call. Executes lib.free_etimelist() after each lib.snmp_send() call if the - `freeEtimelist` attribute is set, or re-calls lib.snmp_send() otherwise. - This frees all the memory used by entries in the etimelist inside t he - net-snmp library, allowing the processing of devices with duplicated engineID. + `freeEtimelist` attribute is set, or re-calls lib.snmp_send() + otherwise. This frees all the memory used by entries in the + etimelist inside the net-snmp library, allowing the processing of + devices with duplicated engineID. Note: This feature is not supported by RFC. """ @@ -789,8 +834,8 @@ def awaitTraps( ) if fileno >= 0: os.dup2(fileno, transport.contents.sock) - sess = netsnmp_session() + sess = netsnmp_session() self.sess = pointer(sess) lib.snmp_sess_init(self.sess) sess.peername = SNMP_DEFAULT_PEERNAME @@ -830,7 +875,7 @@ def create_users(self, users): ) lib.usm_parse_create_usmUser("createUser", line) self._log.debug("create_users: created user: %s", user) - except StandardError as e: + except Exception as e: self._log.debug( "create_users: could not create user: %s: (%s: %s)", user, @@ -987,6 +1032,7 @@ def fdset2list(rd, n): result.append(i * 32 + j) return result + class netsnmp_large_fd_set(Structure): # This structure must be initialized by calling netsnmp_large_fd_set_init() # and must be cleaned up via netsnmp_large_fd_set_cleanup(). If this last @@ -995,7 +1041,7 @@ class netsnmp_large_fd_set(Structure): _fields_ = [ ("lfs_setsize", c_uint), ("lfs_setptr", POINTER(fdset)), - ("lfs_set", fdset) + ("lfs_set", fdset), ] @@ -1013,6 +1059,7 @@ def snmp_select_info(): t = timeout.tv_sec + timeout.tv_usec / 1e6 return fdset2list(rd, maxfd.value), t + def snmp_select_info2(): rd = netsnmp_large_fd_set() lib.netsnmp_large_fd_set_init(byref(rd), FD_SETSIZE) @@ -1022,7 +1069,9 @@ def snmp_select_info2(): timeout.tv_usec = 0 block = c_int(0) maxfd = c_int(MAXFD) - lib.snmp_select_info2(byref(maxfd), byref(rd), byref(timeout), byref(block)) + lib.snmp_select_info2( + byref(maxfd), byref(rd), byref(timeout), byref(block) + ) t = None if not block: t = timeout.tv_sec + timeout.tv_usec / 1e6 @@ -1035,11 +1084,13 @@ def snmp_select_info2(): lib.netsnmp_large_fd_set_cleanup(byref(rd)) return result, t + def snmp_read(fd): rd = fdset() rd[fd / 32] |= 1 << (fd % 32) lib.snmp_read(byref(rd)) + def snmp_read2(fd): rd = netsnmp_large_fd_set() lib.netsnmp_large_fd_set_init(byref(rd), FD_SETSIZE) @@ -1047,6 +1098,7 @@ def snmp_read2(fd): lib.snmp_read2(byref(rd)) lib.netsnmp_large_fd_set_cleanup(byref(rd)) + done = False From 89ae2dc3572f15d69321ed6db22c939d899b2941 Mon Sep 17 00:00:00 2001 From: Jason Peacock Date: Tue, 8 Oct 2024 13:27:03 -0500 Subject: [PATCH 04/17] Removed unnecessary `init_usm` call in awaitTraps. The call to `init_snmp` later in the same method calls `init_usm` for us. ZEN-35072 --- pynetsnmp/netsnmp.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pynetsnmp/netsnmp.py b/pynetsnmp/netsnmp.py index cf05b81..525434f 100644 --- a/pynetsnmp/netsnmp.py +++ b/pynetsnmp/netsnmp.py @@ -808,7 +808,6 @@ def awaitTraps( lib.netsnmp_ds_set_string( NETSNMP_DS_LIBRARY_ID, NETSNMP_DS_LIB_APPTYPE, "pynetsnmp" ) - lib.init_usm() if debug: lib.debug_register_tokens("snmp_parse") # or "ALL" for everything lib.snmp_set_do_debugging(1) From 553546f1622e4bebbcaa9c36930d819466548b04 Mon Sep 17 00:00:00 2001 From: Jason Peacock Date: Tue, 8 Oct 2024 14:49:12 -0500 Subject: [PATCH 05/17] Updated makefile to use the zenpackbuild image --- makefile | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/makefile b/makefile index e150de6..47eb470 100644 --- a/makefile +++ b/makefile @@ -1,17 +1,17 @@ -IMAGENAME = zenoss/build-tools -VERSION = 0.0.14 +IMAGENAME = zenoss/zenpackbuild +VERSION = ubuntu2204-7 TAG = $(IMAGENAME):$(VERSION) UID := $(shell id -u) GID := $(shell id -g) -DOCKER_COMMAND = docker run --rm -v $(PWD):/mnt -w /mnt -u $(UID):$(GID) $(TAG) +DOCKER_COMMAND = docker run --rm -v $(PWD):/mnt -w /mnt $(TAG) .DEFAULT_GOAL := build .PHONY: bdist bdist: - @$(DOCKER_COMMAND) bash -c "python setup.py bdist_wheel" + $(DOCKER_COMMAND) bash -c "python setup.py bdist_wheel" .PHONY: sdist sdist: From aa92d73066fe34b1beb1562789de0cb4a81eaaa7 Mon Sep 17 00:00:00 2001 From: Jason Peacock Date: Tue, 8 Oct 2024 14:54:42 -0500 Subject: [PATCH 06/17] removed obsolete 'pkg' script. --- pkg | 37 ------------------------------------- 1 file changed, 37 deletions(-) delete mode 100755 pkg diff --git a/pkg b/pkg deleted file mode 100755 index 6471d85..0000000 --- a/pkg +++ /dev/null @@ -1,37 +0,0 @@ -#! /bin/sh -# -# Script used by the Zenoss Dev team (the authors of pynetsnmp) to -# release and deploy updated versions. First, go edit version.py -# and update the VERSION value. -# -# Run this script. This will generate a tarball in the parent -# directory (..). Move this script to your zenoss inst/externallibs -# directory and remove the old tarball and add the new one: -# -# $ svn remove pynetsnmp-OLDVERSION.tar.gz -# $ svn add pynetsnmp-NEWVERSION.tar.gz -# -# Then you are set to re-release Zenoss. -# - -quit() { - echo $@ - exit 1 -} -PACKAGE=pynetsnmp -VERSION=`python -c 'import version; print version.VERSION'` -VPACKAGE=$PACKAGE-$VERSION -SVN=http://dev.zenoss.org/svnint -SVNTRUNK=$SVN/trunk/core/$PACKAGE -SVNTAG=$SVN/tags/core/$VPACKAGE -svn cp -m"making release $VERSION" $SVNTRUNK $SVNTAG || quit cannot create tag -svn export $SVNTAG /tmp/$VPACKAGE || quit cannot create export tree -OLD=`pwd` -( - cd /tmp - tar -czvf $OLD/../$VPACKAGE.tar.gz $VPACKAGE - rm -rf $VPACKAGE -) || quit cannot create tarball -echo "Remember to move ../$VPACKAGE.tar.gz the Zenoss " -echo "inst/externallibs directory and check it in." -exit 0 From 2652fc00368aa4744f4146dd10a5656ad9ce95b8 Mon Sep 17 00:00:00 2001 From: Jason Peacock Date: Tue, 8 Oct 2024 15:02:26 -0500 Subject: [PATCH 07/17] don't specify user 0 --- makefile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/makefile b/makefile index 47eb470..2956abe 100644 --- a/makefile +++ b/makefile @@ -27,7 +27,7 @@ clean: .PHONY: test HOST ?= 127.0.0.1 test: - docker run --rm -v $(PWD):/mnt -w /mnt --user 0 $(TAG) \ + docker run --rm -v $(PWD):/mnt -w /mnt $(TAG) \ bash -c "python setup.py bdist_wheel \ && pip install dist/pynetsnmp*py2-none-any.whl ipaddr Twisted==20.3.0 \ && cd test \ From 0de23b6de8290759ffbc7813d51725ee5291fb6c Mon Sep 17 00:00:00 2001 From: Jason Peacock Date: Fri, 11 Oct 2024 10:18:19 -0500 Subject: [PATCH 08/17] bump version to 0.43.0 --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index feb336c..6a1c130 100644 --- a/setup.py +++ b/setup.py @@ -4,7 +4,7 @@ setup( name="pynetsnmp", - version="0.42.1", + version="0.43.0", packages=find_packages(), install_requires=["setuptools"], include_package_data=True, From 0b22a8cd48970891a0aa6986d2bb2c2d4769c2e1 Mon Sep 17 00:00:00 2001 From: Jason Peacock Date: Thu, 10 Oct 2024 15:33:09 -0500 Subject: [PATCH 09/17] Quote arguments used for creating SNMP users. ZEN-35103 --- pynetsnmp/netsnmp.py | 49 +++++++++++++++++++++++--------------------- 1 file changed, 26 insertions(+), 23 deletions(-) diff --git a/pynetsnmp/netsnmp.py b/pynetsnmp/netsnmp.py index 525434f..4966a1a 100644 --- a/pynetsnmp/netsnmp.py +++ b/pynetsnmp/netsnmp.py @@ -858,29 +858,28 @@ def awaitTraps( def create_users(self, users): self._log.debug("create_users: Creating %s users.", len(users)) for user in users: - if user.version == 3: - try: - line = "" - if user.engine_id: - line = "-e {} ".format(user.engine_id) - line += " ".join( - [ - user.username, - user.authentication_type, # MD5 or SHA - user.authentication_passphrase, - user.privacy_protocol, # DES or AES - user.privacy_passphrase, - ] - ) - lib.usm_parse_create_usmUser("createUser", line) - self._log.debug("create_users: created user: %s", user) - except Exception as e: - self._log.debug( - "create_users: could not create user: %s: (%s: %s)", - user, - e.__class__.__name__, - e, - ) + if user.version != SNMP_VERSION_3: + continue + try: + line = "" + if user.engine_id: + line = "-e '{}' ".format(user.engine_id) + line += "'{}' '{}' '{}' '{}' '{}'".format( + _escape_char("'", user.username), + _escape_char("'", user.authentication_type), + _escape_char("'", user.authentication_passphrase), + _escape_char("'", user.privacy_protocol), + _escape_char("'", user.privacy_passphrase), + ) + lib.usm_parse_create_usmUser("createUser", line) + self._log.debug("create_users: created user: %s", user) + except Exception as e: + self._log.debug( + "create_users: could not create user: %s: (%s: %s)", + user, + e.__class__.__name__, + e, + ) def sendTrap(self, trapoid, varbinds=None): if "-v1" in self.cmdLineArgs: @@ -1009,6 +1008,10 @@ def walk(self, root): return req.contents.reqid +def _escape_char(char, text): + return text.replace(char, r"\{}".format(char)) + + MAXFD = 1024 FD_SETSIZE = MAXFD fdset = c_int32 * (MAXFD / 32) From 871b36e29e05b64e52b31ea9faf949d88aec9396 Mon Sep 17 00:00:00 2001 From: Jason Peacock Date: Mon, 14 Oct 2024 07:51:52 -0500 Subject: [PATCH 10/17] Change logger for snmp library logging callback. Using a different name will differentiate libnetsnmp's log messages from pynetsnmp's log messages. --- pynetsnmp/netsnmp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pynetsnmp/netsnmp.py b/pynetsnmp/netsnmp.py index 4966a1a..4c8bcc3 100644 --- a/pynetsnmp/netsnmp.py +++ b/pynetsnmp/netsnmp.py @@ -447,7 +447,7 @@ class netsnmp_log_message(Structure): def netsnmp_logger(a, b, msg): msg = cast(msg, netsnmp_log_message_p) priority = PRIORITY_MAP.get(msg.contents.priority, logging.DEBUG) - _getLogger("netsnmp").log(priority, str(msg.contents.msg).strip()) + _getLogger("libnetsnmp").log(priority, str(msg.contents.msg).strip()) return 0 From 8e160e03e77e5482fa88de876aabeffb52bd516e Mon Sep 17 00:00:00 2001 From: Jason Peacock Date: Mon, 28 Oct 2024 12:21:29 -0500 Subject: [PATCH 11/17] Add types for SNMP security fields. The changes made are to support clients specifying the configuration without having to know about net-snmp's command line arguments. ZEN-35109 --- pynetsnmp/SnmpSession.py | 4 +- pynetsnmp/conversions.py | 29 ++++- pynetsnmp/netsnmp.py | 8 +- pynetsnmp/security.py | 111 ++++++++++++++++++++ pynetsnmp/twistedsnmp.py | 221 ++++++++++++++++++++------------------- pynetsnmp/usm.py | 91 ++++++++++++++++ 6 files changed, 350 insertions(+), 114 deletions(-) create mode 100644 pynetsnmp/security.py create mode 100644 pynetsnmp/usm.py diff --git a/pynetsnmp/SnmpSession.py b/pynetsnmp/SnmpSession.py index 1e5b118..f9d969d 100644 --- a/pynetsnmp/SnmpSession.py +++ b/pynetsnmp/SnmpSession.py @@ -1,7 +1,7 @@ -from __future__ import absolute_import - """Backwards compatible API for SnmpSession""" +from __future__ import absolute_import + from . import netsnmp diff --git a/pynetsnmp/conversions.py b/pynetsnmp/conversions.py index beac056..0398a01 100644 --- a/pynetsnmp/conversions.py +++ b/pynetsnmp/conversions.py @@ -1,11 +1,36 @@ from __future__ import absolute_import +from ipaddr import IPAddress + def asOidStr(oid): """converts an oid int sequence to an oid string""" - return "." + ".".join([str(x) for x in oid]) + return "." + ".".join(str(x) for x in oid) def asOid(oidStr): """converts an OID string into a tuple of integers""" - return tuple([int(x) for x in oidStr.strip(".").split(".")]) + return tuple(int(x) for x in oidStr.strip(".").split(".")) + + +def asAgent(ip, port): + """take a google ipaddr object and port number and produce a net-snmp + agent specification (see the snmpcmd manpage)""" + ip, interface = ip.split("%") if "%" in ip else (ip, None) + address = IPAddress(ip) + + if address.version == 4: + return "udp:{}:{}".format(address.compressed, port) + + if address.version == 6: + if address.is_link_local: + if interface is None: + raise RuntimeError( + "Cannot create agent specification from link local " + "IPv6 address without an interface" + ) + else: + return "udp6:[{}%{}]:{}".format( + address.compressed, interface, port + ) + return "udp6:[{}]:{}".format(address.compressed, port) diff --git a/pynetsnmp/netsnmp.py b/pynetsnmp/netsnmp.py index 4c8bcc3..c5d6c39 100644 --- a/pynetsnmp/netsnmp.py +++ b/pynetsnmp/netsnmp.py @@ -684,9 +684,7 @@ def _doNothingProc(argc, argv, arg): def parse_args(args, session): - args = [ - sys.argv[0], - ] + args + args = [sys.argv[0]] + args argc = len(args) argv = (c_char_p * argc)() for i in range(argc): @@ -694,7 +692,9 @@ def parse_args(args, session): argv[i] = create_string_buffer(args[i]).raw # WARNING: Usage of snmp_parse_args call causes memory leak. if lib.snmp_parse_args(argc, argv, session, "", _doNothingProc) < 0: - raise ArgumentParseError("Unable to parse arguments", " ".join(argv)) + raise ArgumentParseError( + "Unable to parse arguments arguments='{}'".format(" ".join(argv)) + ) # keep a reference to the args for as long as sess is alive return argv diff --git a/pynetsnmp/security.py b/pynetsnmp/security.py new file mode 100644 index 0000000..302761c --- /dev/null +++ b/pynetsnmp/security.py @@ -0,0 +1,111 @@ +from __future__ import absolute_import + +from .CONSTANTS import SNMP_VERSION_1, SNMP_VERSION_2c, SNMP_VERSION_3 +from .usm import auth_protocols, priv_protocols + + +class Community(object): + """ + Provides the community based security model for SNMP v1/V2c. + """ + + def __init__(self, name, version=SNMP_VERSION_2c): + version = _version_map.get(version) + if version is None: + raise ValueError("Unsupported SNMP version '{}'".format(version)) + self.name = name + self.version = version + + def getArguments(self): + community = ("-c", str(self.name)) if self.name else () + return ("-v", self.version) + community + + +class UsmUser(object): + """ + Provides User-based Security Model configuration for SNMP v3. + """ + + def __init__(self, name, auth=None, priv=None, engine=None, context=None): + self.name = name + if not isinstance(auth, (type(None), Authentication)): + raise ValueError("invalid authentication protocol") + self.auth = auth + if not isinstance(auth, (type(None), Privacy)): + raise ValueError("invalid privacy protocol") + self.priv = priv + self.engine = engine + self.context = context + self.version = _version_map.get(SNMP_VERSION_3) + + def getArguments(self): + auth = ( + ("-a", str(self.auth.protocol), "-A", self.auth.passphrase) + if self.auth + else () + ) + if auth: + # The privacy arguments are only given if the authentication + # arguments are also provided. + priv = ( + ("-x", str(self.priv.protocol), "-X", self.priv.passphrase) + if self.priv + else () + ) + else: + priv = () + seclevel = ("-l", _sec_level.get((auth, priv), "noAuthNoPriv")) + + return ( + ("-v", self.version) + + (("-u", self.name) if self.name else ()) + + seclevel + + auth + + priv + + (("-e", self.engine) if self.engine else ()) + + (("-n", self.context) if self.context else ()) + ) + + +_sec_level = {(True, True): "authPriv", (True, False): "authNoPriv"} +_version_map = { + SNMP_VERSION_1: "1", + SNMP_VERSION_2c: "2c", + SNMP_VERSION_3: "3", + "v1": "1", + "v2c": "2c", + "v3": "3", +} + + +class Authentication(object): + """ + Provides the authentication data for UsmUser objects. + """ + + def __init__(self, protocol, passphrase): + if protocol is None: + raise ValueError( + "Invalid Authentication protocol '{}'".format(protocol) + ) + self.protocol = auth_protocols[protocol] + if not passphrase: + raise ValueError( + "authentication protocol requires an " + "authentication passphrase" + ) + self.passphrase = passphrase + + +class Privacy(object): + """ + Provides the privacy data for UsmUser objects. + """ + + def __init__(self, protocol, passphrase): + if protocol is None: + raise ValueError("Invalid Privacy protocol '{}'".format(protocol)) + self.protocol = priv_protocols[protocol] + if not passphrase: + raise ValueError("privacy protocol requires a privacy passphrase") + self.passphrase = passphrase diff --git a/pynetsnmp/twistedsnmp.py b/pynetsnmp/twistedsnmp.py index e5fae50..701e804 100644 --- a/pynetsnmp/twistedsnmp.py +++ b/pynetsnmp/twistedsnmp.py @@ -3,7 +3,6 @@ import logging import struct -from ipaddr import IPAddress from twisted.internet import defer, reactor from twisted.internet.selectreactor import SelectReactor from twisted.internet.error import TimeoutError @@ -31,7 +30,7 @@ SNMP_ERR_WRONGTYPE, SNMP_ERR_WRONGVALUE, ) -from .conversions import asOidStr, asOid +from .conversions import asAgent, asOidStr, asOid from .tableretriever import TableRetriever @@ -39,6 +38,10 @@ class Timer(object): callLater = None +DEFAULT_PORT = 161 +DEFAULT_TIMEOUT = 2 +DEFAULT_RETRIES = 6 + timer = Timer() fdMap = {} @@ -103,9 +106,12 @@ def updateReactor(): log.debug("reactor settings: %r, %r", fds, t) for fd in fds: if isSelect and fd > netsnmp.MAXFD: - log.error("fd > %d detected!!" + - " This will not work properly with the SelectReactor and is being ignored." + - " Timeouts will occur unless you switch to EPollReactor instead!") + log.error( + "fd > %d detected!! " + "This will not work properly with the SelectReactor and " + "is being ignored. Timeouts will occur unless you switch " + "to EPollReactor instead!" + ) continue if fd not in fdMap: @@ -130,34 +136,6 @@ def __init__(self, oid): Exception.__init__(self, "Bad Name", oid) -def _get_agent_spec(ipobj, interface, port): - """take a google ipaddr object and port number and produce a net-snmp - agent specification (see the snmpcmd manpage)""" - if ipobj.version == 4: - agent = "udp:%s:%s" % (ipobj.compressed, port) - elif ipobj.version == 6: - if ipobj.is_link_local: - if interface is None: - raise RuntimeError( - "Cannot create agent specification from link local " - "IPv6 address without an interface" - ) - else: - agent = "udp6:[%s%%%s]:%s" % ( - ipobj.compressed, - interface, - port, - ) - else: - agent = "udp6:[%s]:%s" % (ipobj.compressed, port) - else: - raise RuntimeError( - "Cannot create agent specification for IP address version: %s" - % ipobj.version - ) - return agent - - class SnmpError(Exception): def __init__(self, message, *args, **kwargs): self.message = message @@ -206,6 +184,30 @@ class AgentProxy(object): the SNMP query. The list is ordered correctly by the OID (i.e. it is not ordered by the OID string).""" + @classmethod + def create( + cls, + address, + security=None, + timeout=DEFAULT_TIMEOUT, + retries=DEFAULT_RETRIES, + ): + try: + ip, port = address + except ValueError: + port = DEFAULT_PORT + try: + ip = address.pop(0) + except AttributeError: + ip = address + return cls( + ip, + port=port, + security=security, + timeout=timeout, + tries=retries, + ) + def __init__( self, ip, @@ -213,15 +215,21 @@ def __init__( community="public", snmpVersion="1", protocol=None, - allowCache=False, + allowCache=False, # no longer used timeout=1.5, tries=3, cmdLineArgs=(), + security=None, ): + if security is not None: + self._security = security + self.snmpVersion = security.version + else: + self._security = None + self.snmpVersion = snmpVersion self.ip = ip self.port = port self.community = community - self.snmpVersion = snmpVersion self.timeout = timeout self.tries = tries self.cmdLineArgs = cmdLineArgs @@ -256,16 +264,59 @@ def _signSafePop(self, d, intkey): def callback(self, pdu): """netsnmp session callback""" - result = [] response = netsnmp.getResult(pdu, self._log) try: - d, oids_requested = self._signSafePop(self.defers, pdu.reqid) + d, oids_requested = self._pop_requested_oids(pdu, response) + except RuntimeError: + return + + result = tuple( + (oid, asOidStr(value) if isinstance(value, tuple) else value) + for oid, value in response + ) + + if len(result) == 1 and result[0][0] not in oids_requested: + usmStatsOidStr = asOidStr(result[0][0]) + if usmStatsOidStr in USM_STATS_OIDS: + msg = USM_STATS_OIDS.get(usmStatsOidStr) + reactor.callLater( + 0, d.errback, failure.Failure(Snmpv3Error(msg)) + ) + return + elif usmStatsOidStr == ".1.3.6.1.6.3.15.1.1.2.0": + # we may get a subsequent snmp result with the correct value + # if not the timeout will be called at some point + self.defers[pdu.reqid] = (d, oids_requested) + return + if pdu.errstat != SNMP_ERR_NOERROR: + pduError = PDU_ERRORS.get( + pdu.errstat, "Unknown error (%d)" % pdu.errstat + ) + message = "Packet for %s has error: %s" % (self.ip, pduError) + if pdu.errstat in ( + SNMP_ERR_NOACCESS, + SNMP_ERR_RESOURCEUNAVAILABLE, + SNMP_ERR_AUTHORIZATIONERROR, + ): + reactor.callLater( + 0, d.errback, failure.Failure(SnmpError(message)) + ) + return + else: + result = [] + self._log.warning(message + ". OIDS: %s", oids_requested) + + reactor.callLater(0, d.callback, result) + + def _pop_requested_oids(self, pdu, response): + try: + return self._signSafePop(self.defers, pdu.reqid) except KeyError: # We seem to end up here if we use bad credentials with authPriv. # The only reasonable thing to do is call all of the deferreds with # Snmpv3Errors. - for usmStatsOid, count in response: + for usmStatsOid, _ in response: usmStatsOidStr = asOidStr(usmStatsOid) if usmStatsOidStr == ".1.3.6.1.6.3.15.1.1.2.0": @@ -280,7 +331,7 @@ def callback(self, pdu): "devices use usmStatsNotInTimeWindows as a normal " "part of the SNMPv3 handshake." ) - return + raise RuntimeError("usmStatsNotInTimeWindows error") if usmStatsOidStr == ".1.3.6.1.2.1.1.1.0": # Some devices (Cisco Nexus/MDS) use sysDescr as a normal @@ -289,7 +340,7 @@ def callback(self, pdu): "Received sysDescr during handshake. Some devices use " "sysDescr as a normal part of the SNMPv3 handshake." ) - return + raise RuntimeError("sysDescr during handshake") default_msg = "packet dropped (OID: {0})".format( usmStatsOidStr @@ -306,44 +357,7 @@ def callback(self, pdu): 0, d.errback, failure.Failure(Snmpv3Error(message)) ) - return - - for oid, value in response: - if isinstance(value, tuple): - value = asOidStr(value) - result.append((oid, value)) - if len(result) == 1 and result[0][0] not in oids_requested: - usmStatsOidStr = asOidStr(result[0][0]) - if usmStatsOidStr in USM_STATS_OIDS: - msg = USM_STATS_OIDS.get(usmStatsOidStr) - reactor.callLater( - 0, d.errback, failure.Failure(Snmpv3Error(msg)) - ) - return - elif usmStatsOidStr == ".1.3.6.1.6.3.15.1.1.2.0": - # we may get a subsequent snmp result with the correct value - # if not the timeout will be called at some point - self.defers[pdu.reqid] = (d, oids_requested) - return - if pdu.errstat != SNMP_ERR_NOERROR: - pduError = PDU_ERRORS.get( - pdu.errstat, "Unknown error (%d)" % pdu.errstat - ) - message = "Packet for %s has error: %s" % (self.ip, pduError) - if pdu.errstat in ( - SNMP_ERR_NOACCESS, - SNMP_ERR_RESOURCEUNAVAILABLE, - SNMP_ERR_AUTHORIZATIONERROR, - ): - reactor.callLater( - 0, d.errback, failure.Failure(SnmpError(message)) - ) - return - else: - result = [] - self._log.warning(message + ". OIDS: %s", oids_requested) - - reactor.callLater(0, d.callback, result) + raise RuntimeError(message) def timeout_(self, reqid): d = self._signSafePop(self.defers, reqid)[0] @@ -357,18 +371,7 @@ def _getCmdLineArgs(self): if version == "2": version += "c" - if "%" in self.ip: - address, interface = self.ip.split("%") - else: - address = self.ip - interface = None - - self._log.debug( - "AgentProxy._getCmdLineArgs: using google ipaddr on %s", address - ) - - ipobj = IPAddress(address) - agent = _get_agent_spec(ipobj, interface, self.port) + agent = asAgent(self.ip, self.port) cmdLineArgs = list(self.cmdLineArgs) + [ "-v", @@ -388,17 +391,26 @@ def open(self): self.session.close() self.session = None - self.session = netsnmp.Session( - version=netsnmp.SNMP_VERSION_MAP.get( - self.snmpVersion, netsnmp.SNMP_VERSION_2c - ), - timeout=int(self.timeout), - retries=int(self.tries), - peername="%s:%d" % (self.ip, self.port), - community=self.community, - community_len=len(self.community), - cmdLineArgs=self._getCmdLineArgs(), - ) + if self._security: + agent = asAgent(self.ip, self.port) + cmdlineargs = self._security.getArguments() + ( + ("-t", str(self.timeout), "-r", str(self.tries), agent) + ) + self.session = netsnmp.Session( + cmdLineArgs=cmdlineargs + ) + else: + self.session = netsnmp.Session( + version=netsnmp.SNMP_VERSION_MAP.get( + self.snmpVersion, netsnmp.SNMP_VERSION_2c + ), + timeout=int(self.timeout), + retries=int(self.tries), + peername="%s:%d" % (self.ip, self.port), + community=self.community, + community_len=len(self.community), + cmdLineArgs=self._getCmdLineArgs(), + ) self.session.callback = self.callback self.session.timeout = self.timeout_ @@ -468,11 +480,8 @@ def getbulk(self, nonrepeaters, maxrepititions, oidStrs): return deferred def _convertToDict(self, result): - def strKey(item): - return asOidStr(item[0]), item[1] - - if isinstance(result, list): - return dict(map(strKey, result)) + if isinstance(result, (list, tuple)): + return {asOidStr(key): value for key, value in result} return result diff --git a/pynetsnmp/usm.py b/pynetsnmp/usm.py new file mode 100644 index 0000000..bccd675 --- /dev/null +++ b/pynetsnmp/usm.py @@ -0,0 +1,91 @@ +class _Protocol(object): + __slots__ = ("__name",) + + def __init__(self, name): + self.__name = name + + def __str__(self): + return self.__name + + def __repr__(self): + return "<{0.__module__}.{0.__name__} {1}>".format( + self.__class__, self.__name + ) + + +class _Protocols(object): + __slots__ = ("__protocols", "__kind") + + def __init__(self, protocols, kind): + self.__protocols = protocols + self.__kind = kind + + def __len__(self): + return len(self.__protocols) + + def __iter__(self): + return iter(self.__protocols) + + def __contains__(self, proto): + if proto not in self.__protocols: + return any(str(p) == proto for p in self.__protocols) + return True + + def __getitem__(self, name): + name = str(name) + proto = next((p for p in self.__protocols if str(p) == name), None) + if proto is None: + raise KeyError("No {} protocol '{}'".format(self.__kind, name)) + return proto + + def __repr__(self): + return "<{0.__module__}.{0.__name__} {1}>".format( + self.__class__, ", ".join(str(p) for p in self.__protocols) + ) + + +AUTH_MD5 = _Protocol("MD5") +AUTH_SHA = _Protocol("SHA") +AUTH_SHA_224 = _Protocol("SHA-224") +AUTH_SHA_256 = _Protocol("SHA-256") +AUTH_SHA_384 = _Protocol("SHA-384") +AUTH_SHA_512 = _Protocol("SHA-512") + +auth_protocols = _Protocols( + ( + AUTH_MD5, + AUTH_SHA, + AUTH_SHA_224, + AUTH_SHA_256, + AUTH_SHA_384, + AUTH_SHA_512, + ), + "authentication", +) + +PRIV_DES = _Protocol("DES") +PRIV_AES = _Protocol("AES") +PRIV_AES_192 = _Protocol("AES-192") +PRIV_AES_256 = _Protocol("AES-256") + +priv_protocols = _Protocols( + (PRIV_DES, PRIV_AES, PRIV_AES_192, PRIV_AES_256), "privacy" +) + +del _Protocol +del _Protocols + +__all__ = ( + "AUTH_MD5", + "AUTH_SHA", + "AUTH_SHA_224", + "AUTH_SHA_256", + "AUTH_SHA_384", + "AUTH_SHA_512", + "auth_protocols", + "PRIV_DES", + "PRIV_AES", + "PRIV_AES_192", + "PRIV_AES_256", + "priv_protocols", +) From 8aea1fcef210242f0233c0fa0247f2ebdbea94f3 Mon Sep 17 00:00:00 2001 From: Jason Peacock Date: Thu, 31 Oct 2024 09:21:46 -0500 Subject: [PATCH 12/17] Fix silly errors in the 'security' module. * Check the correct variable for whether it's a Privacy type. * Convert the inputs into bools first when performing security level lookup. --- pynetsnmp/security.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/pynetsnmp/security.py b/pynetsnmp/security.py index 302761c..5757309 100644 --- a/pynetsnmp/security.py +++ b/pynetsnmp/security.py @@ -31,7 +31,7 @@ def __init__(self, name, auth=None, priv=None, engine=None, context=None): if not isinstance(auth, (type(None), Authentication)): raise ValueError("invalid authentication protocol") self.auth = auth - if not isinstance(auth, (type(None), Privacy)): + if not isinstance(priv, (type(None), Privacy)): raise ValueError("invalid privacy protocol") self.priv = priv self.engine = engine @@ -54,7 +54,10 @@ def getArguments(self): ) else: priv = () - seclevel = ("-l", _sec_level.get((auth, priv), "noAuthNoPriv")) + seclevel = ( + "-l", + _sec_level.get((bool(auth), bool(priv)), "noAuthNoPriv"), + ) return ( ("-v", self.version) From a721a28f1fd56cd92c241eb6a1b8e4d4aea113d2 Mon Sep 17 00:00:00 2001 From: Jason Peacock Date: Fri, 1 Nov 2024 15:41:59 -0500 Subject: [PATCH 13/17] Minor edit to a couple of exception messages. --- pynetsnmp/security.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/pynetsnmp/security.py b/pynetsnmp/security.py index 5757309..b01bb9b 100644 --- a/pynetsnmp/security.py +++ b/pynetsnmp/security.py @@ -94,8 +94,7 @@ def __init__(self, protocol, passphrase): self.protocol = auth_protocols[protocol] if not passphrase: raise ValueError( - "authentication protocol requires an " - "authentication passphrase" + "Authentication protocol requires a passphrase" ) self.passphrase = passphrase @@ -110,5 +109,5 @@ def __init__(self, protocol, passphrase): raise ValueError("Invalid Privacy protocol '{}'".format(protocol)) self.protocol = priv_protocols[protocol] if not passphrase: - raise ValueError("privacy protocol requires a privacy passphrase") + raise ValueError("Privacy protocol requires a passphrase") self.passphrase = passphrase From a3d527c59efe785bac62b4e1deef5f6ba961431c Mon Sep 17 00:00:00 2001 From: Jason Peacock Date: Thu, 7 Nov 2024 16:09:34 -0600 Subject: [PATCH 14/17] Improve API regarding SNMP v3 credentials. The Authentication, Privacy, and related classes now support equality operations. The Session.create_users method now assumes it's passed objects of type UsmUser and excludes incomplete security options from the argument string. ZEN-35108 --- pynetsnmp/netsnmp.py | 30 ++++++++++++++-------- pynetsnmp/security.py | 59 +++++++++++++++++++++++++++++++++++++++---- pynetsnmp/usm.py | 17 ++++++++++--- 3 files changed, 86 insertions(+), 20 deletions(-) diff --git a/pynetsnmp/netsnmp.py b/pynetsnmp/netsnmp.py index c5d6c39..43dab9d 100644 --- a/pynetsnmp/netsnmp.py +++ b/pynetsnmp/netsnmp.py @@ -858,20 +858,28 @@ def awaitTraps( def create_users(self, users): self._log.debug("create_users: Creating %s users.", len(users)) for user in users: - if user.version != SNMP_VERSION_3: + if str(user.version) != str(SNMP_VERSION_3): + self._log.info("create_users: user is not v3 %s", user) continue try: line = "" - if user.engine_id: - line = "-e '{}' ".format(user.engine_id) - line += "'{}' '{}' '{}' '{}' '{}'".format( - _escape_char("'", user.username), - _escape_char("'", user.authentication_type), - _escape_char("'", user.authentication_passphrase), - _escape_char("'", user.privacy_protocol), - _escape_char("'", user.privacy_passphrase), - ) - lib.usm_parse_create_usmUser("createUser", line) + if user.engine: + line = "-e '{}'".format(user.engine) + if user.name: + line += " '{}'".format( + _escape_char("'", user.name), + ) + if user.auth: + line += " '{}' '{}'".format( + _escape_char("'", user.auth.protocol.name), + _escape_char("'", user.auth.passphrase), + ) + if user.priv: + line += " '{}' '{}'".format( + _escape_char("'", user.priv.protocol.name), + _escape_char("'", user.priv.passphrase), + ) + lib.usm_parse_create_usmUser("createUser", line.strip()) self._log.debug("create_users: created user: %s", user) except Exception as e: self._log.debug( diff --git a/pynetsnmp/security.py b/pynetsnmp/security.py index b01bb9b..21e4b4f 100644 --- a/pynetsnmp/security.py +++ b/pynetsnmp/security.py @@ -40,7 +40,7 @@ def __init__(self, name, auth=None, priv=None, engine=None, context=None): def getArguments(self): auth = ( - ("-a", str(self.auth.protocol), "-A", self.auth.passphrase) + ("-a", self.auth.protocol.name, "-A", self.auth.passphrase) if self.auth else () ) @@ -48,7 +48,7 @@ def getArguments(self): # The privacy arguments are only given if the authentication # arguments are also provided. priv = ( - ("-x", str(self.priv.protocol), "-X", self.priv.passphrase) + ("-x", self.priv.protocol.name, "-X", self.priv.passphrase) if self.priv else () ) @@ -69,6 +69,31 @@ def getArguments(self): + (("-n", self.context) if self.context else ()) ) + def __eq__(self, other): + return ( + self.name == other.name + and self.auth == other.auth + and self.priv == other.priv + and self.engine == other.engine + and self.context == other.context + ) + + def __str__(self): + info = ", ".join( + "{0}={1}".format(k, v) + for k, v in ( + ("name", self.name), + ("auth", self.auth), + ("priv", self.priv), + ("engine", self.engine), + ("context", self.context), + ) + if v + ) + return "{0.__class__.__name__}(version={0.version}{1}{2})".format( + self, ", " if info else "", info + ) + _sec_level = {(True, True): "authPriv", (True, False): "authNoPriv"} _version_map = { @@ -86,6 +111,8 @@ class Authentication(object): Provides the authentication data for UsmUser objects. """ + __slots__ = ("protocol", "passphrase") + def __init__(self, protocol, passphrase): if protocol is None: raise ValueError( @@ -93,17 +120,28 @@ def __init__(self, protocol, passphrase): ) self.protocol = auth_protocols[protocol] if not passphrase: - raise ValueError( - "Authentication protocol requires a passphrase" - ) + raise ValueError("Authentication protocol requires a passphrase") self.passphrase = passphrase + def __eq__(self, other): + if not isinstance(other, Authentication): + return NotImplemented + return ( + self.protocol == other.protocol + and self.passphrase == other.passphrase + ) + + def __str__(self): + return "{0.__class__.__name__}(protocol={0.protocol})".format(self) + class Privacy(object): """ Provides the privacy data for UsmUser objects. """ + __slots__ = ("protocol", "passphrase") + def __init__(self, protocol, passphrase): if protocol is None: raise ValueError("Invalid Privacy protocol '{}'".format(protocol)) @@ -111,3 +149,14 @@ def __init__(self, protocol, passphrase): if not passphrase: raise ValueError("Privacy protocol requires a passphrase") self.passphrase = passphrase + + def __eq__(self, other): + if not isinstance(other, Privacy): + return NotImplemented + return ( + self.protocol == other.protocol + and self.passphrase == other.passphrase + ) + + def __str__(self): + return "{0.__class__.__name__}(protocol={0.protocol})".format(self) diff --git a/pynetsnmp/usm.py b/pynetsnmp/usm.py index bccd675..2e9a422 100644 --- a/pynetsnmp/usm.py +++ b/pynetsnmp/usm.py @@ -1,15 +1,24 @@ +from __future__ import absolute_import + class _Protocol(object): - __slots__ = ("__name",) + """ """ + + __slots__ = ("name",) def __init__(self, name): - self.__name = name + self.name = name + + def __eq__(self, other): + if not isinstance(other, type(self)): + return NotImplemented + return self.name == other.name def __str__(self): - return self.__name + return self.name def __repr__(self): return "<{0.__module__}.{0.__name__} {1}>".format( - self.__class__, self.__name + self.__class__, self.name ) From 2fe4448a30e612bf5681a053b039654ce539dfc9 Mon Sep 17 00:00:00 2001 From: Jason Peacock Date: Thu, 14 Nov 2024 11:22:24 -0600 Subject: [PATCH 15/17] Add logging handler for USM security errors. Add no-auth and no-priv protocol objects. Add the OIDs for all the auth and priv protocol objects. ZEN-35146 --- pynetsnmp/CONSTANTS.py | 6 +++--- pynetsnmp/netsnmp.py | 8 ++++++++ pynetsnmp/twistedsnmp.py | 13 ++++--------- pynetsnmp/usm.py | 36 +++++++++++++++++++++--------------- 4 files changed, 36 insertions(+), 27 deletions(-) diff --git a/pynetsnmp/CONSTANTS.py b/pynetsnmp/CONSTANTS.py index 61d699f..2828148 100644 --- a/pynetsnmp/CONSTANTS.py +++ b/pynetsnmp/CONSTANTS.py @@ -1,5 +1,5 @@ +NULL = 0 USM_LENGTH_OID_TRANSFORM = 10 -NULL = None MAX_CALLBACK_IDS = 2 MAX_CALLBACK_SUBIDS = 16 SNMP_CALLBACK_LIBRARY = 0 @@ -306,7 +306,8 @@ NETSNMP_CALLBACK_OP_SEND_FAILED = 3 NETSNMP_CALLBACK_OP_CONNECT = 4 NETSNMP_CALLBACK_OP_DISCONNECT = 5 -snmp_init_statistics = () +NETSNMP_CALLBACK_OP_RESEND = 6 +NETSNMP_CALLBACK_OP_SEC_ERROR = 7 STAT_SNMPUNKNOWNSECURITYMODELS = 0 STAT_SNMPINVALIDMSGS = 1 STAT_SNMPUNKNOWNPDUHANDLERS = 2 @@ -377,7 +378,6 @@ MAX_STATS = NETSNMP_STAT_MAX_STATS COMMUNITY_MAX_LEN = 256 SPRINT_MAX_LEN = 2560 -NULL = 0 TRUE = 1 FALSE = 0 READ = 1 diff --git a/pynetsnmp/netsnmp.py b/pynetsnmp/netsnmp.py index 43dab9d..4240a94 100644 --- a/pynetsnmp/netsnmp.py +++ b/pynetsnmp/netsnmp.py @@ -59,6 +59,7 @@ MAX_OID_LEN, NETSNMP_CALLBACK_OP_RECEIVED_MESSAGE, NETSNMP_CALLBACK_OP_TIMED_OUT, + NETSNMP_CALLBACK_OP_SEC_ERROR, NETSNMP_DS_LIB_APPTYPE, NETSNMP_DS_LIBRARY_ID, NETSNMP_LOGHANDLER_CALLBACK, @@ -658,6 +659,13 @@ def _callback(operation, sp, reqid, pdu, magic): sess.callback(pdu.contents) elif operation == NETSNMP_CALLBACK_OP_TIMED_OUT: sess.timeout(reqid) + elif operation == NETSNMP_CALLBACK_OP_SEC_ERROR: + _getLogger("callback").error( + "peer has rejected security credentials " + "peername=%s security-name=%s", + sp.contents.peername, + sp.contents.securityName, + ) else: _getLogger("callback").error("Unknown operation: %d", operation) except Exception as ex: diff --git a/pynetsnmp/twistedsnmp.py b/pynetsnmp/twistedsnmp.py index 701e804..9a4592f 100644 --- a/pynetsnmp/twistedsnmp.py +++ b/pynetsnmp/twistedsnmp.py @@ -201,11 +201,7 @@ def create( except AttributeError: ip = address return cls( - ip, - port=port, - security=security, - timeout=timeout, - tries=retries, + ip, port=port, security=security, timeout=timeout, tries=retries ) def __init__( @@ -214,7 +210,7 @@ def __init__( port=161, community="public", snmpVersion="1", - protocol=None, + protocol=None, # no longer used allowCache=False, # no longer used timeout=1.5, tries=3, @@ -390,15 +386,14 @@ def open(self): if self.session is not None: self.session.close() self.session = None + updateReactor() if self._security: agent = asAgent(self.ip, self.port) cmdlineargs = self._security.getArguments() + ( ("-t", str(self.timeout), "-r", str(self.tries), agent) ) - self.session = netsnmp.Session( - cmdLineArgs=cmdlineargs - ) + self.session = netsnmp.Session(cmdLineArgs=cmdlineargs) else: self.session = netsnmp.Session( version=netsnmp.SNMP_VERSION_MAP.get( diff --git a/pynetsnmp/usm.py b/pynetsnmp/usm.py index 2e9a422..7455e9f 100644 --- a/pynetsnmp/usm.py +++ b/pynetsnmp/usm.py @@ -1,24 +1,26 @@ from __future__ import absolute_import + class _Protocol(object): """ """ - __slots__ = ("name",) + __slots__ = ("name", "oid") - def __init__(self, name): + def __init__(self, name, oid): self.name = name + self.oid = oid def __eq__(self, other): if not isinstance(other, type(self)): return NotImplemented - return self.name == other.name + return self.name == other.name and self.oid == other.oid def __str__(self): return self.name def __repr__(self): - return "<{0.__module__}.{0.__name__} {1}>".format( - self.__class__, self.name + return "<{0.__module__}.{0.__name__} {1} {2}>".format( + self.__class__, self.name, ".".join(str(v) for v in self.oid) ) @@ -53,12 +55,13 @@ def __repr__(self): ) -AUTH_MD5 = _Protocol("MD5") -AUTH_SHA = _Protocol("SHA") -AUTH_SHA_224 = _Protocol("SHA-224") -AUTH_SHA_256 = _Protocol("SHA-256") -AUTH_SHA_384 = _Protocol("SHA-384") -AUTH_SHA_512 = _Protocol("SHA-512") +AUTH_NOAUTH = _Protocol("NOAUTH", (1, 3, 6, 1, 6, 3, 10, 1, 1, 1)) +AUTH_MD5 = _Protocol("MD5", (1, 3, 6, 1, 6, 3, 10, 1, 1, 2)) +AUTH_SHA = _Protocol("SHA", (1, 3, 6, 1, 6, 3, 10, 1, 1, 3)) +AUTH_SHA_224 = _Protocol("SHA-224", (1, 3, 6, 1, 6, 3, 10, 1, 1, 4)) +AUTH_SHA_256 = _Protocol("SHA-256", (1, 3, 6, 1, 6, 3, 10, 1, 1, 5)) +AUTH_SHA_384 = _Protocol("SHA-384", (1, 3, 6, 1, 6, 3, 10, 1, 1, 6)) +AUTH_SHA_512 = _Protocol("SHA-512", (1, 3, 6, 1, 6, 3, 10, 1, 1, 7)) auth_protocols = _Protocols( ( @@ -72,10 +75,11 @@ def __repr__(self): "authentication", ) -PRIV_DES = _Protocol("DES") -PRIV_AES = _Protocol("AES") -PRIV_AES_192 = _Protocol("AES-192") -PRIV_AES_256 = _Protocol("AES-256") +PRIV_NOPRIV = _Protocol("NOPRIV", (1, 3, 6, 1, 6, 3, 10, 1, 2, 1)) +PRIV_DES = _Protocol("DES", (1, 3, 6, 1, 6, 3, 10, 1, 2, 2)) +PRIV_AES = _Protocol("AES", (1, 3, 6, 1, 6, 3, 10, 1, 2, 4)) +PRIV_AES_192 = _Protocol("AES-192", (1, 3, 6, 1, 4, 1, 14832, 1, 3)) +PRIV_AES_256 = _Protocol("AES-256", (1, 3, 6, 1, 4, 1, 14832, 1, 4)) priv_protocols = _Protocols( (PRIV_DES, PRIV_AES, PRIV_AES_192, PRIV_AES_256), "privacy" @@ -85,6 +89,7 @@ def __repr__(self): del _Protocols __all__ = ( + "AUTH_NOAUTH", "AUTH_MD5", "AUTH_SHA", "AUTH_SHA_224", @@ -92,6 +97,7 @@ def __repr__(self): "AUTH_SHA_384", "AUTH_SHA_512", "auth_protocols", + "PRIV_NOPRIV", "PRIV_DES", "PRIV_AES", "PRIV_AES_192", From 20422092a8317444ab6132f01b7ef76381af39dc Mon Sep 17 00:00:00 2001 From: Jason Peacock Date: Tue, 19 Nov 2024 16:07:11 -0600 Subject: [PATCH 16/17] Replace more string with types. --- makefile | 17 +- pynetsnmp/errors.py | 61 +++++++ pynetsnmp/netsnmp.py | 25 +-- pynetsnmp/oids.py | 94 ++++++++++ pynetsnmp/security.py | 118 ++++++++---- pynetsnmp/twistedsnmp.py | 277 ++++++++++++---------------- pynetsnmp/usm.py | 17 +- tests/__init__.py | 0 tests/test_security.py | 382 +++++++++++++++++++++++++++++++++++++++ tests/test_usm.py | 114 ++++++++++++ 10 files changed, 874 insertions(+), 231 deletions(-) create mode 100644 pynetsnmp/errors.py create mode 100644 pynetsnmp/oids.py create mode 100644 tests/__init__.py create mode 100644 tests/test_security.py create mode 100644 tests/test_usm.py diff --git a/makefile b/makefile index 2956abe..ef8b6ac 100644 --- a/makefile +++ b/makefile @@ -25,11 +25,14 @@ clean: rm -rf *.pyc dist build pynetsnmp.egg-info .PHONY: test -HOST ?= 127.0.0.1 test: - docker run --rm -v $(PWD):/mnt -w /mnt $(TAG) \ - bash -c "python setup.py bdist_wheel \ - && pip install dist/pynetsnmp*py2-none-any.whl ipaddr Twisted==20.3.0 \ - && cd test \ - && python test_runner.py --host $(HOST) \ - && chown -R $(UID):$(GID) /mnt" ; + @$(DOCKER_COMMAND) bash -c "pip --no-python-version-warning install -q .; cd tests; python -m unittest discover" + +# HOST ?= 127.0.0.1 +# test: +# docker run --rm -v $(PWD):/mnt -w /mnt $(TAG) \ +# bash -c "python setup.py bdist_wheel \ +# && pip install dist/pynetsnmp*py2-none-any.whl ipaddr Twisted==20.3.0 \ +# && cd test \ +# && python test_runner.py --host $(HOST) \ +# && chown -R $(UID):$(GID) /mnt" ; diff --git a/pynetsnmp/errors.py b/pynetsnmp/errors.py new file mode 100644 index 0000000..6cc73e2 --- /dev/null +++ b/pynetsnmp/errors.py @@ -0,0 +1,61 @@ +from __future__ import absolute_import + +from . import oids + + +class SnmpTimeoutError(Exception): + pass + + +class ArgumentParseError(Exception): + pass + + +class TransportError(Exception): + pass + + +class SnmpNameError(Exception): + def __init__(self, oid): + Exception.__init__(self, "Bad Name", oid) + + +class SnmpError(Exception): + def __init__(self, message, *args, **kwargs): + self.message = message + + def __str__(self): + return self.message + + def __repr__(self): + return self.message + + +class SnmpUsmError(SnmpError): + pass + + +class SnmpUsmStatsError(SnmpUsmError): + def __init__(self, mesg, oid): + super(SnmpUsmStatsError, self).__init__(mesg) + self.oid = oid + + +_stats_oid_error_map = { + oids.WrongDigest: SnmpUsmStatsError( + "unexpected authentication digest", oids.WrongDigest + ), + oids.UnknownUserName: SnmpUsmStatsError( + "unknown user", oids.UnknownUserName + ), + oids.UnknownSecurityLevel: SnmpUsmStatsError( + "unknown or unavailable security level", oids.UnknownSecurityLevel + ), + oids.DecryptionError: SnmpUsmStatsError( + "privacy decryption error", oids.DecryptionError + ), +} + + +def get_stats_error(oid): + return _stats_oid_error_map.get(oid) diff --git a/pynetsnmp/netsnmp.py b/pynetsnmp/netsnmp.py index 4240a94..f893ae5 100644 --- a/pynetsnmp/netsnmp.py +++ b/pynetsnmp/netsnmp.py @@ -87,6 +87,7 @@ USM_AUTH_KU_LEN, USM_PRIV_KU_LEN, ) +from .errors import ArgumentParseError, SnmpTimeoutError def _getLogger(name): @@ -304,7 +305,7 @@ class netsnmp_container_s(Structure): ("securityAuthLocalKeyLen", c_size_t), ("securityPrivProto", POINTER(oid)), ("securityPrivProtoLen", c_size_t), - ("securityPrivKey", c_char * USM_PRIV_KU_LEN), + ("securityPrivKey", u_char * USM_PRIV_KU_LEN), ("securityPrivKeyLen", c_size_t), ("securityPrivLocalKey", c_char_p), ("securityPrivLocalKeyLen", c_size_t), @@ -638,16 +639,12 @@ def getResult(pdu, log): return result -class SnmpError(Exception): +class NetSnmpError(Exception): def __init__(self, why): lib.snmp_perror(why) Exception.__init__(self, why) -class SnmpTimeoutError(Exception): - pass - - sessionMap = {} @@ -676,14 +673,6 @@ def _callback(operation, sp, reqid, pdu, magic): _callback = netsnmp_callback(_callback) -class ArgumentParseError(Exception): - pass - - -class TransportError(Exception): - pass - - def _doNothingProc(argc, argv, arg): return 0 @@ -807,7 +796,7 @@ def open(self): ref = byref(sess) self.sess = lib.snmp_open(ref) if not self.sess: - raise SnmpError("snmp_open") + raise NetSnmpError("snmp_open") def awaitTraps( self, peername, fileno=-1, pre_parse_callback=None, debug=False @@ -834,7 +823,7 @@ def awaitTraps( lib.setup_engineID(None, None) transport = lib.netsnmp_tdomain_transport(peername, 1, "udp") if not transport: - raise SnmpError( + raise NetSnmpError( "Unable to create transport {peername}".format( peername=peername ) @@ -861,7 +850,7 @@ def awaitTraps( sess.isAuthoritative = SNMP_SESS_UNKNOWNAUTH rc = lib.snmp_add(self.sess, transport, pre_parse_callback, None) if not rc: - raise SnmpError("snmp_add") + raise NetSnmpError("snmp_add") def create_users(self, users): self._log.debug("create_users: Creating %s users.", len(users)) @@ -991,7 +980,7 @@ def _handle_send_status(self, req, send_status, send_type): lib.snmp_free_pdu(req) if snmperr.value == SNMPERR_TIMEOUT: raise SnmpTimeoutError() - raise SnmpError(msg_fmt % msg_args) + raise NetSnmpError(msg_fmt % msg_args) def get(self, oids): req = self._create_request(SNMP_MSG_GET) diff --git a/pynetsnmp/oids.py b/pynetsnmp/oids.py new file mode 100644 index 0000000..873991b --- /dev/null +++ b/pynetsnmp/oids.py @@ -0,0 +1,94 @@ +from __future__ import absolute_import + +from .conversions import asOidStr + + +class OID(object): + __slots__ = ("oid",) + + def __init__(self, oid): + super(OID, self).__setattr__("oid", oid) + + def __setattr__(self, key, value): + if key in OID.__slots__: + raise AttributeError( + "can't set attribute '{}' on 'OID' object".format(key) + ) + super(OID, self).__setattr__(key, value) + + def __eq__(this, that): + if isinstance(that, (tuple, list)): + return this.oid == that + if isinstance(that, OID): + return this.oid == that.oid + return NotImplemented + + def __ne__(this, that): + if isinstance(that, (tuple, list)): + return this.oid != that + if isinstance(that, OID): + return this.oid != that.oid + return NotImplemented + + def __hash__(self): + return hash(self.oid) + + def __repr__(self): + return "<{0.__module__}.{0.__class__.__name__} {1}>".format( + self, asOidStr(self.oid) + ) + + def __str__(self): + return asOidStr(self.oid) + + +_base_status_oid = (1, 3, 6, 1, 6, 3, 15, 1, 1) + + +class UnknownSecurityLevel(OID): + __slots__ = () + + +UnknownSecurityLevel = UnknownSecurityLevel(_base_status_oid + (1, 0)) + + +class NotInTimeWindow(OID): + __slots__ = () + + +NotInTimeWindow = NotInTimeWindow(_base_status_oid + (2, 0)) + + +class UnknownUserName(OID): + __slots__ = () + + +UnknownUserName = UnknownUserName(_base_status_oid + (3, 0)) + + +class UnknownEngineId(OID): + __slots__ = () + + +UnknownEngineId = UnknownEngineId(_base_status_oid + (4, 0)) + + +class WrongDigest(OID): + __slots__ = () + + +WrongDigest = WrongDigest(_base_status_oid + (5, 0)) + + +class DecryptionError(OID): + __slots__ = () + + +DecryptionError = DecryptionError(_base_status_oid + (6, 0)) + + +class SysDescr(OID): + __slots__ = () + + +SysDescr = SysDescr((1, 3, 6, 1, 2, 1, 1, 1, 0)) diff --git a/pynetsnmp/security.py b/pynetsnmp/security.py index 21e4b4f..0a48b40 100644 --- a/pynetsnmp/security.py +++ b/pynetsnmp/security.py @@ -1,7 +1,9 @@ from __future__ import absolute_import from .CONSTANTS import SNMP_VERSION_1, SNMP_VERSION_2c, SNMP_VERSION_3 -from .usm import auth_protocols, priv_protocols +from .usm import AUTH_NOAUTH, auth_protocols, PRIV_NOPRIV, priv_protocols + +__all__ = ("Community", "UsmUser", "Authentication", "Privacy") class Community(object): @@ -10,11 +12,13 @@ class Community(object): """ def __init__(self, name, version=SNMP_VERSION_2c): - version = _version_map.get(version) - if version is None: - raise ValueError("Unsupported SNMP version '{}'".format(version)) + mapped = _version_map.get(version) + if mapped is None or mapped == "3": + raise ValueError( + "SNMP version '{}' not supported for Community".format(version) + ) self.name = name - self.version = version + self.version = mapped def getArguments(self): community = ("-c", str(self.name)) if self.name else () @@ -28,43 +32,44 @@ class UsmUser(object): def __init__(self, name, auth=None, priv=None, engine=None, context=None): self.name = name - if not isinstance(auth, (type(None), Authentication)): - raise ValueError("invalid authentication protocol") + if auth is None: + auth = Authentication.new_noauth() + if not isinstance(auth, Authentication): + raise ValueError("invalid authentication object") self.auth = auth - if not isinstance(priv, (type(None), Privacy)): - raise ValueError("invalid privacy protocol") + if priv is None: + priv = Privacy.new_nopriv() + if not isinstance(priv, Privacy): + raise ValueError("invalid privacy object") self.priv = priv self.engine = engine self.context = context self.version = _version_map.get(SNMP_VERSION_3) def getArguments(self): - auth = ( + auth_args = ( ("-a", self.auth.protocol.name, "-A", self.auth.passphrase) if self.auth else () ) - if auth: + if auth_args: # The privacy arguments are only given if the authentication # arguments are also provided. - priv = ( + priv_args = ( ("-x", self.priv.protocol.name, "-X", self.priv.passphrase) if self.priv else () ) else: - priv = () - seclevel = ( - "-l", - _sec_level.get((bool(auth), bool(priv)), "noAuthNoPriv"), - ) + priv_args = () + seclevel_arg = ("-l", _sec_level[(bool(self.auth), bool(self.priv))]) return ( ("-v", self.version) + (("-u", self.name) if self.name else ()) - + seclevel - + auth - + priv + + seclevel_arg + + auth_args + + priv_args + (("-e", self.engine) if self.engine else ()) + (("-n", self.context) if self.context else ()) ) @@ -95,8 +100,15 @@ def __str__(self): ) -_sec_level = {(True, True): "authPriv", (True, False): "authNoPriv"} +_sec_level = { + (True, True): "authPriv", + (True, False): "authNoPriv", + (False, False): "noAuthNoPriv", +} _version_map = { + "1": "1", + "2c": "2c", + "3": "3", SNMP_VERSION_1: "1", SNMP_VERSION_2c: "2c", SNMP_VERSION_3: "3", @@ -113,15 +125,25 @@ class Authentication(object): __slots__ = ("protocol", "passphrase") + @classmethod + def new_noauth(cls): + return cls(None, None) + def __init__(self, protocol, passphrase): - if protocol is None: - raise ValueError( - "Invalid Authentication protocol '{}'".format(protocol) - ) - self.protocol = auth_protocols[protocol] - if not passphrase: - raise ValueError("Authentication protocol requires a passphrase") - self.passphrase = passphrase + if ( + not protocol + or protocol is AUTH_NOAUTH + or protocol == "AUTH_NOAUTH" + ): + self.protocol = AUTH_NOAUTH + self.passphrase = None + else: + self.protocol = auth_protocols[protocol] + if not passphrase: + raise ValueError( + "Authentication protocol requires a passphrase" + ) + self.passphrase = passphrase def __eq__(self, other): if not isinstance(other, Authentication): @@ -131,6 +153,14 @@ def __eq__(self, other): and self.passphrase == other.passphrase ) + def __nonzero__(self): + return self.protocol is not AUTH_NOAUTH + + def __repr__(self): + return ( + "<{0.__module__}.{0.__class__.__name__} protocol={0.protocol}>" + ).format(self) + def __str__(self): return "{0.__class__.__name__}(protocol={0.protocol})".format(self) @@ -142,13 +172,23 @@ class Privacy(object): __slots__ = ("protocol", "passphrase") + @classmethod + def new_nopriv(cls): + return cls(None, None) + def __init__(self, protocol, passphrase): - if protocol is None: - raise ValueError("Invalid Privacy protocol '{}'".format(protocol)) - self.protocol = priv_protocols[protocol] - if not passphrase: - raise ValueError("Privacy protocol requires a passphrase") - self.passphrase = passphrase + if ( + not protocol + or protocol is PRIV_NOPRIV + or protocol == "PRIV_NOPRIV" + ): + self.protocol = PRIV_NOPRIV + self.passphrase = None + else: + self.protocol = priv_protocols[protocol] + if not passphrase: + raise ValueError("Privacy protocol requires a passphrase") + self.passphrase = passphrase def __eq__(self, other): if not isinstance(other, Privacy): @@ -158,5 +198,13 @@ def __eq__(self, other): and self.passphrase == other.passphrase ) + def __nonzero__(self): + return self.protocol is not PRIV_NOPRIV + + def __repr__(self): + return ( + "<{0.__module__}.{0.__class__.__name__} protocol={0.protocol}>" + ).format(self) + def __str__(self): return "{0.__class__.__name__}(protocol={0.protocol})".format(self) diff --git a/pynetsnmp/twistedsnmp.py b/pynetsnmp/twistedsnmp.py index 9a4592f..52142de 100644 --- a/pynetsnmp/twistedsnmp.py +++ b/pynetsnmp/twistedsnmp.py @@ -8,7 +8,7 @@ from twisted.internet.error import TimeoutError from twisted.python import failure -from . import netsnmp +from . import netsnmp, oids from .CONSTANTS import ( SNMP_ERR_AUTHORIZATIONERROR, SNMP_ERR_BADVALUE, @@ -31,8 +31,11 @@ SNMP_ERR_WRONGVALUE, ) from .conversions import asAgent, asOidStr, asOid +from .errors import SnmpError, SnmpUsmError, get_stats_error from .tableretriever import TableRetriever +log = netsnmp._getLogger("agentproxy") + class Timer(object): callLater = None @@ -131,47 +134,6 @@ def updateReactor(): timer.callLater = reactor.callLater(t, checkTimeouts) -class SnmpNameError(Exception): - def __init__(self, oid): - Exception.__init__(self, "Bad Name", oid) - - -class SnmpError(Exception): - def __init__(self, message, *args, **kwargs): - self.message = message - - def __str__(self): - return self.message - - def __repr__(self): - return self.message - - -class Snmpv3Error(SnmpError): - pass - - -USM_STATS_OIDS = { - # usmStatsWrongDigests - ".1.3.6.1.6.3.15.1.1.5.0": ( - "check zSnmpAuthType and zSnmpAuthPassword, " - "packet did not include the expected digest value" - ), - # usmStatsUnknownUserNames - ".1.3.6.1.6.3.15.1.1.3.0": ( - "check zSnmpSecurityName, packet referenced an unknown user" - ), - # usmStatsUnsupportedSecLevels - ".1.3.6.1.6.3.15.1.1.1.0": ( - "packet requested an unknown or unavailable security level" - ), - # usmStatsDecryptionErrors - ".1.3.6.1.6.3.15.1.1.6.0": ( - "check zSnmpPrivType, packet could not be decrypted" - ), -} - - class AgentProxy(object): """The public methods on AgentProxy (get, walk, getbulk) expect input OIDs to be strings, and the result they produce is a dictionary. The @@ -229,42 +191,53 @@ def __init__( self.timeout = timeout self.tries = tries self.cmdLineArgs = cmdLineArgs - self.defers = {} + self.defers = _DeferredMap() self.session = None - self._log = netsnmp._getLogger("agentproxy") - def _signSafePop(self, d, intkey): - """ - Attempt to pop the item at intkey from dictionary d. - Upon failure, try to convert intkey from a signed to an unsigned - integer and try to pop again. + def open(self): + if self.session is not None: + self.session.close() + self.session = None + updateReactor() - This addresses potential integer rollover issues caused by the fact - that netsnmp_pdu.reqid is a c_long and the netsnmp_callback function - pointer definition specifies it as a c_int. See ZEN-4481. - """ - try: - return d.pop(intkey) - except KeyError as ex: - if intkey < 0: - self._log.debug("Negative ID for _signSafePop: %s", intkey) - # convert to unsigned, try that key - uintkey = struct.unpack("I", struct.pack("i", intkey))[0] - try: - return d.pop(uintkey) - except KeyError: - # Nothing by the unsigned key either, - # throw the original KeyError for consistency - raise ex - raise + if self._security: + agent = asAgent(self.ip, self.port) + cmdlineargs = self._security.getArguments() + ( + ("-t", str(self.timeout), "-r", str(self.tries), agent) + ) + self.session = netsnmp.Session(cmdLineArgs=cmdlineargs) + else: + self.session = netsnmp.Session( + version=netsnmp.SNMP_VERSION_MAP.get( + self.snmpVersion, netsnmp.SNMP_VERSION_2c + ), + timeout=int(self.timeout), + retries=int(self.tries), + peername="%s:%d" % (self.ip, self.port), + community=self.community, + community_len=len(self.community), + cmdLineArgs=self._getCmdLineArgs(), + ) + + self.session.callback = self.callback + self.session.timeout = self._handle_timeout + self.session.open() + updateReactor() + + def close(self): + if self.session is not None: + self.session.close() + self.session = None + updateReactor() def callback(self, pdu): """netsnmp session callback""" - response = netsnmp.getResult(pdu, self._log) + response = netsnmp.getResult(pdu, log) try: - d, oids_requested = self._pop_requested_oids(pdu, response) - except RuntimeError: + d, oids_requested = self.defers.pop(pdu.reqid) + except KeyError: + self._handle_missing_request(response) return result = tuple( @@ -273,23 +246,21 @@ def callback(self, pdu): ) if len(result) == 1 and result[0][0] not in oids_requested: - usmStatsOidStr = asOidStr(result[0][0]) - if usmStatsOidStr in USM_STATS_OIDS: - msg = USM_STATS_OIDS.get(usmStatsOidStr) - reactor.callLater( - 0, d.errback, failure.Failure(Snmpv3Error(msg)) - ) + statsOid = result[0][0] + error = get_stats_error(statsOid) + if error: + reactor.callLater(0, d.errback, failure.Failure(error)) return - elif usmStatsOidStr == ".1.3.6.1.6.3.15.1.1.2.0": + if statsOid == oids.NotInTimeWindow: # we may get a subsequent snmp result with the correct value # if not the timeout will be called at some point self.defers[pdu.reqid] = (d, oids_requested) return if pdu.errstat != SNMP_ERR_NOERROR: pduError = PDU_ERRORS.get( - pdu.errstat, "Unknown error (%d)" % pdu.errstat + pdu.errstat, "unknown error (%d)" % pdu.errstat ) - message = "Packet for %s has error: %s" % (self.ip, pduError) + message = "packet for %s has error: %s" % (self.ip, pduError) if pdu.errstat in ( SNMP_ERR_NOACCESS, SNMP_ERR_RESOURCEUNAVAILABLE, @@ -301,63 +272,53 @@ def callback(self, pdu): return else: result = [] - self._log.warning(message + ". OIDS: %s", oids_requested) + log.warning(message + ". OIDS: %s", oids_requested) reactor.callLater(0, d.callback, result) - def _pop_requested_oids(self, pdu, response): - try: - return self._signSafePop(self.defers, pdu.reqid) - except KeyError: - # We seem to end up here if we use bad credentials with authPriv. - # The only reasonable thing to do is call all of the deferreds with - # Snmpv3Errors. - for usmStatsOid, _ in response: - usmStatsOidStr = asOidStr(usmStatsOid) - - if usmStatsOidStr == ".1.3.6.1.6.3.15.1.1.2.0": - # Some devices use usmStatsNotInTimeWindows as a normal - # part of the SNMPv3 handshake (JIRA-1565). - # net-snmp automatically retries the request with the - # previous request_id and the values for - # msgAuthoritativeEngineBoots and - # msgAuthoritativeEngineTime from this error packet. - self._log.debug( - "Received a usmStatsNotInTimeWindows error. Some " - "devices use usmStatsNotInTimeWindows as a normal " - "part of the SNMPv3 handshake." - ) - raise RuntimeError("usmStatsNotInTimeWindows error") - - if usmStatsOidStr == ".1.3.6.1.2.1.1.1.0": - # Some devices (Cisco Nexus/MDS) use sysDescr as a normal - # part of the SNMPv3 handshake (JIRA-7943) - self._log.debug( - "Received sysDescr during handshake. Some devices use " - "sysDescr as a normal part of the SNMPv3 handshake." - ) - raise RuntimeError("sysDescr during handshake") - - default_msg = "packet dropped (OID: {0})".format( - usmStatsOidStr - ) - message = USM_STATS_OIDS.get(usmStatsOidStr, default_msg) - break - else: - message = "packet dropped" + def _handle_missing_request(self, response): + usmStatsOid, _ = next(iter(response), (None, None)) + + if usmStatsOid == oids.NotInTimeWindow: + # Some devices use usmStatsNotInTimeWindows as a normal part of + # the SNMPv3 handshake (JIRA-1565). net-snmp automatically + # retries the request with the previous request_id and the + # values for msgAuthoritativeEngineBoots and + # msgAuthoritativeEngineTime from this error packet. + log.debug( + "Received a usmStatsNotInTimeWindows error. Some " + "devices use usmStatsNotInTimeWindows as a normal " + "part of the SNMPv3 handshake." + ) + return - for d in ( - d for d, rOids in self.defers.itervalues() if not d.called - ): - reactor.callLater( - 0, d.errback, failure.Failure(Snmpv3Error(message)) + if usmStatsOid == oids.SysDescr: + # Some devices (Cisco Nexus/MDS) use sysDescr as a normal + # part of the SNMPv3 handshake (JIRA-7943) + log.debug( + "Received sysDescr during handshake. Some devices use " + "sysDescr as a normal part of the SNMPv3 handshake." + ) + return + + if usmStatsOid is not None: + error = get_stats_error(usmStatsOid) + if not error: + error = SnmpUsmError( + "packet dropped (OID: {0})".format(asOidStr(usmStatsOid)) ) + else: + error = SnmpUsmError("packet dropped") - raise RuntimeError(message) + for d in (d for d, _ in self.defers.itervalues() if not d.called): + reactor.callLater(0, d.errback, failure.Failure(error)) - def timeout_(self, reqid): - d = self._signSafePop(self.defers, reqid)[0] - reactor.callLater(0, d.errback, failure.Failure(TimeoutError())) + def _handle_timeout(self, reqid): + try: + d = self.defers.pop(reqid)[0] + reactor.callLater(0, d.errback, failure.Failure(TimeoutError())) + except KeyError: + log.warning("handled timeout for unknown request") def _getCmdLineArgs(self): if not self.cmdLineArgs: @@ -382,42 +343,6 @@ def _getCmdLineArgs(self): ] return cmdLineArgs - def open(self): - if self.session is not None: - self.session.close() - self.session = None - updateReactor() - - if self._security: - agent = asAgent(self.ip, self.port) - cmdlineargs = self._security.getArguments() + ( - ("-t", str(self.timeout), "-r", str(self.tries), agent) - ) - self.session = netsnmp.Session(cmdLineArgs=cmdlineargs) - else: - self.session = netsnmp.Session( - version=netsnmp.SNMP_VERSION_MAP.get( - self.snmpVersion, netsnmp.SNMP_VERSION_2c - ), - timeout=int(self.timeout), - retries=int(self.tries), - peername="%s:%d" % (self.ip, self.port), - community=self.community, - community_len=len(self.community), - cmdLineArgs=self._getCmdLineArgs(), - ) - - self.session.callback = self.callback - self.session.timeout = self.timeout_ - self.session.open() - updateReactor() - - def close(self): - if self.session is not None: - self.session.close() - self.session = None - updateReactor() - def _get(self, oids, timeout=None, retryCount=None): d = defer.Deferred() try: @@ -488,3 +413,25 @@ def port(self): snmpprotocol = _FakeProtocol() + + +class _DeferredMap(dict): + """ + Wrap the dict type to add extra behavior. + """ + + def pop(self, key): + """ + Attempt to pop the item at key from the dictionary. + """ + # Check for negative key to address potential integer rollover issues + # caused by the fact that netsnmp_pdu.reqid is a c_long and the + # netsnmp_callback function pointer definition specifies it as a + # c_int. See ZEN-4481. + if key not in self and key < 0: + log.debug("try negative ID for deferred map: %s", key) + # convert to unsigned, try that key + uintkey = struct.unpack("I", struct.pack("i", key))[0] + if uintkey in self: + key = uintkey + return super(_DeferredMap, self).pop(key) diff --git a/pynetsnmp/usm.py b/pynetsnmp/usm.py index 7455e9f..3606eb3 100644 --- a/pynetsnmp/usm.py +++ b/pynetsnmp/usm.py @@ -38,6 +38,8 @@ def __iter__(self): return iter(self.__protocols) def __contains__(self, proto): + if not proto: + proto = self.__noargs if proto not in self.__protocols: return any(str(p) == proto for p in self.__protocols) return True @@ -46,7 +48,9 @@ def __getitem__(self, name): name = str(name) proto = next((p for p in self.__protocols if str(p) == name), None) if proto is None: - raise KeyError("No {} protocol '{}'".format(self.__kind, name)) + raise KeyError( + "unknown {} protocol '{}'".format(self.__kind, name) + ) return proto def __repr__(self): @@ -65,6 +69,7 @@ def __repr__(self): auth_protocols = _Protocols( ( + AUTH_NOAUTH, AUTH_MD5, AUTH_SHA, AUTH_SHA_224, @@ -82,25 +87,25 @@ def __repr__(self): PRIV_AES_256 = _Protocol("AES-256", (1, 3, 6, 1, 4, 1, 14832, 1, 4)) priv_protocols = _Protocols( - (PRIV_DES, PRIV_AES, PRIV_AES_192, PRIV_AES_256), "privacy" + (PRIV_NOPRIV, PRIV_DES, PRIV_AES, PRIV_AES_192, PRIV_AES_256), "privacy" ) del _Protocol del _Protocols __all__ = ( - "AUTH_NOAUTH", "AUTH_MD5", + "AUTH_NOAUTH", + "auth_protocols", "AUTH_SHA", "AUTH_SHA_224", "AUTH_SHA_256", "AUTH_SHA_384", "AUTH_SHA_512", - "auth_protocols", - "PRIV_NOPRIV", - "PRIV_DES", "PRIV_AES", "PRIV_AES_192", "PRIV_AES_256", + "PRIV_DES", + "PRIV_NOPRIV", "priv_protocols", ) diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_security.py b/tests/test_security.py new file mode 100644 index 0000000..3783922 --- /dev/null +++ b/tests/test_security.py @@ -0,0 +1,382 @@ +import unittest + +from pynetsnmp import security, usm + + +class TestCommunity(unittest.TestCase): + name = "public" + + def test_default(t): + c = security.Community(t.name) + t.assertEqual(c.name, t.name) + t.assertEqual(c.version, "2c") + expected = ("-v", "2c", "-c", t.name) + t.assertSequenceEqual(c.getArguments(), expected) + + def test_v1_constant(t): + c = security.Community(t.name, security.SNMP_VERSION_1) + t.assertEqual(c.name, t.name) + t.assertEqual(c.version, "1") + expected = ("-v", "1", "-c", t.name) + t.assertSequenceEqual(c.getArguments(), expected) + + def test_v1_v1(t): + c = security.Community(t.name, "v1") + t.assertEqual(c.name, t.name) + t.assertEqual(c.version, "1") + expected = ("-v", "1", "-c", t.name) + t.assertSequenceEqual(c.getArguments(), expected) + + def test_v2c_constant(t): + c = security.Community(t.name, security.SNMP_VERSION_2c) + t.assertEqual(c.name, t.name) + t.assertEqual(c.version, "2c") + expected = ("-v", "2c", "-c", t.name) + t.assertSequenceEqual(c.getArguments(), expected) + + def test_v2c_v2c(t): + c = security.Community(t.name, "v2c") + t.assertEqual(c.name, t.name) + t.assertEqual(c.version, "2c") + expected = ("-v", "2c", "-c", t.name) + t.assertSequenceEqual(c.getArguments(), expected) + + def test_v3_constant(t): + with t.assertRaises(ValueError): + security.Community(t.name, security.SNMP_VERSION_3) + + def test_v3_v3(t): + with t.assertRaises(ValueError): + security.Community(t.name, "v3") + + def test_none_version(t): + with t.assertRaises(ValueError): + security.Community(t.name, None) + + def test_not_a_version_str(t): + with t.assertRaises(ValueError): + security.Community(t.name, "oi") + + def test_not_a_version_number(t): + with t.assertRaises(ValueError): + security.Community(t.name, 3947) + + +class TestUsmUser(unittest.TestCase): + name = "john_doe" + passwd = "secured123" # noqa: S105 + + def test_default(t): + user = security.UsmUser(t.name) + t.assertEqual(t.name, user.name) + t.assertEqual(security.Authentication.new_noauth(), user.auth) + t.assertEqual(security.Privacy.new_nopriv(), user.priv) + t.assertIsNone(user.engine) + t.assertIsNone(user.context) + t.assertEqual(user.version, "3") + expected = ("-v", "3", "-u", t.name, "-l", "noAuthNoPriv") + t.assertSequenceEqual(expected, user.getArguments()) + + def test_engineid(t): + engineid = hex(3443489794829589283483234)[2:].strip("L") + user = security.UsmUser(t.name, engine=engineid) + t.assertEqual(t.name, user.name) + t.assertEqual(security.Authentication.new_noauth(), user.auth) + t.assertEqual(security.Privacy.new_nopriv(), user.priv) + t.assertEqual(engineid, user.engine) + t.assertIsNone(user.context) + expected = ( + "-v", + "3", + "-u", + t.name, + "-l", + "noAuthNoPriv", + "-e", + engineid, + ) + t.assertSequenceEqual(expected, user.getArguments()) + + def test_contextid(t): + contextid = hex(9084090984572743455234)[2:].strip("L") + user = security.UsmUser(t.name, context=contextid) + t.assertEqual(t.name, user.name) + t.assertEqual(security.Authentication.new_noauth(), user.auth) + t.assertEqual(security.Privacy.new_nopriv(), user.priv) + t.assertIsNone(user.engine) + t.assertEqual(contextid, user.context) + expected = ( + "-v", + "3", + "-u", + t.name, + "-l", + "noAuthNoPriv", + "-n", + contextid, + ) + t.assertSequenceEqual(expected, user.getArguments()) + + def test_auth(t): + auth = security.Authentication(usm.AUTH_SHA_224, t.passwd) + user = security.UsmUser(t.name, auth=auth) + t.assertEqual(t.name, user.name) + t.assertEqual(auth, user.auth) + t.assertEqual(security.Privacy.new_nopriv(), user.priv) + t.assertIsNone(user.engine) + t.assertIsNone(user.context) + t.assertEqual(user.version, "3") + expected = ( + "-v", + "3", + "-u", + t.name, + "-l", + "authNoPriv", + "-a", + auth.protocol.name, + "-A", + auth.passphrase, + ) + t.assertSequenceEqual(expected, user.getArguments()) + + def test_authpriv(t): + auth = security.Authentication(usm.AUTH_SHA_224, t.passwd) + priv = security.Privacy(usm.PRIV_AES_256, t.passwd) + user = security.UsmUser(t.name, auth=auth, priv=priv) + t.assertEqual(t.name, user.name) + t.assertEqual(auth, user.auth) + t.assertEqual(priv, user.priv) + t.assertIsNone(user.engine) + t.assertIsNone(user.context) + t.assertEqual(user.version, "3") + expected = ( + "-v", + "3", + "-u", + t.name, + "-l", + "authPriv", + "-a", + auth.protocol.name, + "-A", + auth.passphrase, + "-x", + priv.protocol.name, + "-X", + priv.passphrase, + ) + t.assertSequenceEqual(expected, user.getArguments()) + + def test_all_args(t): + auth = security.Authentication(usm.AUTH_SHA_224, t.passwd) + priv = security.Privacy(usm.PRIV_AES_256, t.passwd) + contextid = hex(9084090984572743455234)[2:].strip("L") + engineid = hex(3443489794829589283483234)[2:].strip("L") + user = security.UsmUser( + t.name, auth=auth, priv=priv, engine=engineid, context=contextid + ) + t.assertEqual(t.name, user.name) + t.assertEqual(auth, user.auth) + t.assertEqual(priv, user.priv) + t.assertEqual(engineid, user.engine) + t.assertEqual(contextid, user.context) + t.assertEqual(user.version, "3") + expected = ( + "-v", + "3", + "-u", + t.name, + "-l", + "authPriv", + "-a", + auth.protocol.name, + "-A", + auth.passphrase, + "-x", + priv.protocol.name, + "-X", + priv.passphrase, + "-e", + engineid, + "-n", + contextid, + ) + t.assertSequenceEqual(expected, user.getArguments()) + + def test_equality(t): + auth1 = security.Authentication(usm.AUTH_SHA_224, t.passwd) + user1 = security.UsmUser(t.name, auth=auth1) + auth2 = security.Authentication(usm.AUTH_SHA_224, t.passwd) + user2 = security.UsmUser(t.name, auth=auth2) + auth3 = security.Authentication(usm.AUTH_SHA_256, t.passwd) + priv3 = security.Privacy(usm.PRIV_AES_256, t.passwd) + user3 = security.UsmUser(t.name, auth=auth3, priv=priv3) + t.assertEqual(user1, user2) + t.assertNotEqual(user1, user3) + + +class TestAuthentication(unittest.TestCase): + passwd = "security123" # noqa: S105 + + def test_noauth_classmethod(t): + auth = security.Authentication.new_noauth() + t.assertEqual(usm.AUTH_NOAUTH, auth.protocol) + t.assertIsNone(auth.passphrase) + + def test_none_init(t): + auth = security.Authentication(None, None) + t.assertEqual(usm.AUTH_NOAUTH, auth.protocol) + t.assertIsNone(auth.passphrase) + + auth = security.Authentication(None, t.passwd) + t.assertEqual(usm.AUTH_NOAUTH, auth.protocol) + t.assertIsNone(auth.passphrase) + + def test_noauth_init(t): + auth = security.Authentication(usm.AUTH_NOAUTH, None) + t.assertEqual(usm.AUTH_NOAUTH, auth.protocol) + t.assertIsNone(auth.passphrase) + + auth = security.Authentication(usm.AUTH_NOAUTH, t.passwd) + t.assertEqual(usm.AUTH_NOAUTH, auth.protocol) + t.assertIsNone(auth.passphrase) + + def test_noauth_str_init(t): + auth = security.Authentication("AUTH_NOAUTH", None) + t.assertEqual(usm.AUTH_NOAUTH, auth.protocol) + t.assertIsNone(auth.passphrase) + + auth = security.Authentication("AUTH_NOAUTH", t.passwd) + t.assertEqual(usm.AUTH_NOAUTH, auth.protocol) + t.assertIsNone(auth.passphrase) + + def test_noauth_is_false(t): + auth = security.Authentication.new_noauth() + t.assertFalse(auth) + + def test_md5(t): + auth = security.Authentication(usm.AUTH_MD5, t.passwd) + t.assertTrue(auth) + t.assertEqual(auth.protocol, usm.AUTH_MD5) + t.assertEqual(auth.passphrase, t.passwd) + + def test_sha(t): + auth = security.Authentication(usm.AUTH_SHA, t.passwd) + t.assertTrue(auth) + t.assertEqual(auth.protocol, usm.AUTH_SHA) + t.assertEqual(auth.passphrase, t.passwd) + + def test_sha_224(t): + auth = security.Authentication(usm.AUTH_SHA_224, t.passwd) + t.assertTrue(auth) + t.assertEqual(auth.protocol, usm.AUTH_SHA_224) + t.assertEqual(auth.passphrase, t.passwd) + + def test_sha_256(t): + auth = security.Authentication(usm.AUTH_SHA_256, t.passwd) + t.assertTrue(auth) + t.assertEqual(auth.protocol, usm.AUTH_SHA_256) + t.assertEqual(auth.passphrase, t.passwd) + + def test_sha_384(t): + auth = security.Authentication(usm.AUTH_SHA_384, t.passwd) + t.assertTrue(auth) + t.assertEqual(auth.protocol, usm.AUTH_SHA_384) + t.assertEqual(auth.passphrase, t.passwd) + + def test_sha_512(t): + auth = security.Authentication(usm.AUTH_SHA_512, t.passwd) + t.assertTrue(auth) + t.assertEqual(auth.protocol, usm.AUTH_SHA_512) + t.assertEqual(auth.passphrase, t.passwd) + + def test_equal(t): + auth1 = security.Authentication(usm.AUTH_MD5, t.passwd) + auth2 = security.Authentication(usm.AUTH_MD5, t.passwd) + t.assertEqual(auth1, auth2) + + def test_not_equal(t): + auth1 = security.Authentication(usm.AUTH_MD5, t.passwd) + auth2 = security.Authentication(usm.AUTH_SHA, t.passwd) + t.assertNotEqual(auth1, auth2) + + auth3 = security.Authentication(usm.AUTH_SHA, t.passwd + "456") + t.assertNotEqual(auth2, auth3) + + +class TestPrivacy(unittest.TestCase): + passwd = "security123" # noqa: S105 + + def test_nopriv_classmethod(t): + priv = security.Privacy.new_nopriv() + t.assertEqual(usm.PRIV_NOPRIV, priv.protocol) + t.assertIsNone(priv.passphrase) + + def test_none_init(t): + priv = security.Privacy(None, None) + t.assertEqual(usm.PRIV_NOPRIV, priv.protocol) + t.assertIsNone(priv.passphrase) + + priv = security.Privacy(None, t.passwd) + t.assertEqual(usm.PRIV_NOPRIV, priv.protocol) + t.assertIsNone(priv.passphrase) + + def test_nopriv_init(t): + priv = security.Privacy(usm.PRIV_NOPRIV, None) + t.assertEqual(usm.PRIV_NOPRIV, priv.protocol) + t.assertIsNone(priv.passphrase) + + priv = security.Privacy(usm.PRIV_NOPRIV, t.passwd) + t.assertEqual(usm.PRIV_NOPRIV, priv.protocol) + t.assertIsNone(priv.passphrase) + + def test_nopriv_str_init(t): + priv = security.Privacy("PRIV_NOPRIV", None) + t.assertEqual(usm.PRIV_NOPRIV, priv.protocol) + t.assertIsNone(priv.passphrase) + + priv = security.Privacy("PRIV_NOPRIV", t.passwd) + t.assertEqual(usm.PRIV_NOPRIV, priv.protocol) + t.assertIsNone(priv.passphrase) + + def test_nopriv_is_false(t): + priv = security.Privacy.new_nopriv() + t.assertFalse(priv) + + def test_des(t): + priv = security.Privacy(usm.PRIV_DES, t.passwd) + t.assertTrue(priv) + t.assertEqual(priv.protocol, usm.PRIV_DES) + t.assertEqual(priv.passphrase, t.passwd) + + def test_aes(t): + priv = security.Privacy(usm.PRIV_AES, t.passwd) + t.assertTrue(priv) + t.assertEqual(priv.protocol, usm.PRIV_AES) + t.assertEqual(priv.passphrase, t.passwd) + + def test_aes_192(t): + priv = security.Privacy(usm.PRIV_AES_192, t.passwd) + t.assertTrue(priv) + t.assertEqual(priv.protocol, usm.PRIV_AES_192) + t.assertEqual(priv.passphrase, t.passwd) + + def test_aes_256(t): + priv = security.Privacy(usm.PRIV_AES_256, t.passwd) + t.assertTrue(priv) + t.assertEqual(priv.protocol, usm.PRIV_AES_256) + t.assertEqual(priv.passphrase, t.passwd) + + def test_equal(t): + priv1 = security.Privacy(usm.PRIV_DES, t.passwd) + priv2 = security.Privacy(usm.PRIV_DES, t.passwd) + t.assertEqual(priv1, priv2) + + def test_not_equal(t): + priv1 = security.Privacy(usm.PRIV_DES, t.passwd) + priv2 = security.Privacy(usm.PRIV_AES, t.passwd) + t.assertNotEqual(priv1, priv2) + + priv3 = security.Privacy(usm.PRIV_AES, t.passwd + "456") + t.assertNotEqual(priv2, priv3) diff --git a/tests/test_usm.py b/tests/test_usm.py new file mode 100644 index 0000000..3a6e4e9 --- /dev/null +++ b/tests/test_usm.py @@ -0,0 +1,114 @@ +import unittest + +from pynetsnmp import usm + +_sorted_auth_names = sorted( + ["NOAUTH", "MD5", "SHA", "SHA-224", "SHA-256", "SHA-384", "SHA-512"] +) + + +class TestAuthProtocols(unittest.TestCase): + def test_noauth_contained(t): + t.assertIn(usm.AUTH_NOAUTH, usm.auth_protocols) + + def test_md5_contained(t): + t.assertIn(usm.AUTH_MD5, usm.auth_protocols) + + def test_sha_contained(t): + t.assertIn(usm.AUTH_SHA, usm.auth_protocols) + + def test_sha_224_contained(t): + t.assertIn(usm.AUTH_SHA_224, usm.auth_protocols) + + def test_sha_256_contained(t): + t.assertIn(usm.AUTH_SHA_256, usm.auth_protocols) + + def test_sha_384_contained(t): + t.assertIn(usm.AUTH_SHA_384, usm.auth_protocols) + + def test_sha_512_contained(t): + t.assertIn(usm.AUTH_SHA_512, usm.auth_protocols) + + def test_length(t): + t.assertEqual(7, len(usm.auth_protocols)) + + def test_iterable(t): + names = sorted(str(p) for p in usm.auth_protocols) + t.assertEqual(7, len(names)) + t.assertListEqual(_sorted_auth_names, names) + + def test_noauth_getitem(t): + proto = usm.auth_protocols[usm.AUTH_NOAUTH.name] + t.assertEqual(usm.AUTH_NOAUTH, proto) + + def test_md5_getitem(t): + proto = usm.auth_protocols[usm.AUTH_MD5.name] + t.assertEqual(usm.AUTH_MD5, proto) + + def test_sha_getitem(t): + proto = usm.auth_protocols[usm.AUTH_SHA.name] + t.assertEqual(usm.AUTH_SHA, proto) + + def test_sha_224_getitem(t): + proto = usm.auth_protocols[usm.AUTH_SHA_224.name] + t.assertEqual(usm.AUTH_SHA_224, proto) + + def test_sha_256_getitem(t): + proto = usm.auth_protocols[usm.AUTH_SHA_256.name] + t.assertEqual(usm.AUTH_SHA_256, proto) + + def test_sha_384_getitem(t): + proto = usm.auth_protocols[usm.AUTH_SHA_384.name] + t.assertEqual(usm.AUTH_SHA_384, proto) + + def test_sha_512_getitem(t): + proto = usm.auth_protocols[usm.AUTH_SHA_512.name] + t.assertEqual(usm.AUTH_SHA_512, proto) + + +_sorted_priv_names = sorted(["NOPRIV", "DES", "AES", "AES-192", "AES-256"]) + + +class TestPrivProtocols(unittest.TestCase): + def test_nopriv_contained(t): + t.assertIn(usm.PRIV_NOPRIV, usm.priv_protocols) + + def test_des_contained(t): + t.assertIn(usm.PRIV_DES, usm.priv_protocols) + + def test_aes_contained(t): + t.assertIn(usm.PRIV_AES, usm.priv_protocols) + + def test_aes_192_contained(t): + t.assertIn(usm.PRIV_AES_192, usm.priv_protocols) + + def test_aes_256_contained(t): + t.assertIn(usm.PRIV_AES_256, usm.priv_protocols) + + def test_length(t): + t.assertEqual(5, len(usm.priv_protocols)) + + def test_iterable(t): + names = sorted(str(p) for p in usm.priv_protocols) + t.assertEqual(5, len(names)) + t.assertListEqual(_sorted_priv_names, names) + + def test_nopriv_getitem(t): + proto = usm.priv_protocols[usm.PRIV_NOPRIV.name] + t.assertEqual(usm.PRIV_NOPRIV, proto) + + def test_des_getitem(t): + proto = usm.priv_protocols[usm.PRIV_DES.name] + t.assertEqual(usm.PRIV_DES, proto) + + def test_aes_getitem(t): + proto = usm.priv_protocols[usm.PRIV_AES.name] + t.assertEqual(usm.PRIV_AES, proto) + + def test_aes_192_getitem(t): + proto = usm.priv_protocols[usm.PRIV_AES_192.name] + t.assertEqual(usm.PRIV_AES_192, proto) + + def test_aes_256_getitem(t): + proto = usm.priv_protocols[usm.PRIV_AES_256.name] + t.assertEqual(usm.PRIV_AES_256, proto) From 88b5a1b6dbb3ed174bf6ebf905e7605306eaaa8a Mon Sep 17 00:00:00 2001 From: Jason Peacock Date: Wed, 20 Nov 2024 10:25:05 -0600 Subject: [PATCH 17/17] Refactored usm.py and security.py modules into a `usm` package. --- pynetsnmp/security.py | 210 -------------- pynetsnmp/usm/__init__.py | 43 +++ pynetsnmp/usm/auth.py | 50 ++++ pynetsnmp/usm/common.py | 19 ++ pynetsnmp/usm/community.py | 23 ++ pynetsnmp/usm/priv.py | 48 ++++ pynetsnmp/{usm.py => usm/protocols.py} | 17 -- pynetsnmp/usm/user.py | 89 ++++++ tests/test_auth.py | 92 ++++++ tests/test_community.py | 62 ++++ tests/test_priv.py | 80 ++++++ tests/test_security.py | 382 ------------------------- tests/test_usmuser.py | 157 ++++++++++ 13 files changed, 663 insertions(+), 609 deletions(-) delete mode 100644 pynetsnmp/security.py create mode 100644 pynetsnmp/usm/__init__.py create mode 100644 pynetsnmp/usm/auth.py create mode 100644 pynetsnmp/usm/common.py create mode 100644 pynetsnmp/usm/community.py create mode 100644 pynetsnmp/usm/priv.py rename pynetsnmp/{usm.py => usm/protocols.py} (90%) create mode 100644 pynetsnmp/usm/user.py create mode 100644 tests/test_auth.py create mode 100644 tests/test_community.py create mode 100644 tests/test_priv.py delete mode 100644 tests/test_security.py create mode 100644 tests/test_usmuser.py diff --git a/pynetsnmp/security.py b/pynetsnmp/security.py deleted file mode 100644 index 0a48b40..0000000 --- a/pynetsnmp/security.py +++ /dev/null @@ -1,210 +0,0 @@ -from __future__ import absolute_import - -from .CONSTANTS import SNMP_VERSION_1, SNMP_VERSION_2c, SNMP_VERSION_3 -from .usm import AUTH_NOAUTH, auth_protocols, PRIV_NOPRIV, priv_protocols - -__all__ = ("Community", "UsmUser", "Authentication", "Privacy") - - -class Community(object): - """ - Provides the community based security model for SNMP v1/V2c. - """ - - def __init__(self, name, version=SNMP_VERSION_2c): - mapped = _version_map.get(version) - if mapped is None or mapped == "3": - raise ValueError( - "SNMP version '{}' not supported for Community".format(version) - ) - self.name = name - self.version = mapped - - def getArguments(self): - community = ("-c", str(self.name)) if self.name else () - return ("-v", self.version) + community - - -class UsmUser(object): - """ - Provides User-based Security Model configuration for SNMP v3. - """ - - def __init__(self, name, auth=None, priv=None, engine=None, context=None): - self.name = name - if auth is None: - auth = Authentication.new_noauth() - if not isinstance(auth, Authentication): - raise ValueError("invalid authentication object") - self.auth = auth - if priv is None: - priv = Privacy.new_nopriv() - if not isinstance(priv, Privacy): - raise ValueError("invalid privacy object") - self.priv = priv - self.engine = engine - self.context = context - self.version = _version_map.get(SNMP_VERSION_3) - - def getArguments(self): - auth_args = ( - ("-a", self.auth.protocol.name, "-A", self.auth.passphrase) - if self.auth - else () - ) - if auth_args: - # The privacy arguments are only given if the authentication - # arguments are also provided. - priv_args = ( - ("-x", self.priv.protocol.name, "-X", self.priv.passphrase) - if self.priv - else () - ) - else: - priv_args = () - seclevel_arg = ("-l", _sec_level[(bool(self.auth), bool(self.priv))]) - - return ( - ("-v", self.version) - + (("-u", self.name) if self.name else ()) - + seclevel_arg - + auth_args - + priv_args - + (("-e", self.engine) if self.engine else ()) - + (("-n", self.context) if self.context else ()) - ) - - def __eq__(self, other): - return ( - self.name == other.name - and self.auth == other.auth - and self.priv == other.priv - and self.engine == other.engine - and self.context == other.context - ) - - def __str__(self): - info = ", ".join( - "{0}={1}".format(k, v) - for k, v in ( - ("name", self.name), - ("auth", self.auth), - ("priv", self.priv), - ("engine", self.engine), - ("context", self.context), - ) - if v - ) - return "{0.__class__.__name__}(version={0.version}{1}{2})".format( - self, ", " if info else "", info - ) - - -_sec_level = { - (True, True): "authPriv", - (True, False): "authNoPriv", - (False, False): "noAuthNoPriv", -} -_version_map = { - "1": "1", - "2c": "2c", - "3": "3", - SNMP_VERSION_1: "1", - SNMP_VERSION_2c: "2c", - SNMP_VERSION_3: "3", - "v1": "1", - "v2c": "2c", - "v3": "3", -} - - -class Authentication(object): - """ - Provides the authentication data for UsmUser objects. - """ - - __slots__ = ("protocol", "passphrase") - - @classmethod - def new_noauth(cls): - return cls(None, None) - - def __init__(self, protocol, passphrase): - if ( - not protocol - or protocol is AUTH_NOAUTH - or protocol == "AUTH_NOAUTH" - ): - self.protocol = AUTH_NOAUTH - self.passphrase = None - else: - self.protocol = auth_protocols[protocol] - if not passphrase: - raise ValueError( - "Authentication protocol requires a passphrase" - ) - self.passphrase = passphrase - - def __eq__(self, other): - if not isinstance(other, Authentication): - return NotImplemented - return ( - self.protocol == other.protocol - and self.passphrase == other.passphrase - ) - - def __nonzero__(self): - return self.protocol is not AUTH_NOAUTH - - def __repr__(self): - return ( - "<{0.__module__}.{0.__class__.__name__} protocol={0.protocol}>" - ).format(self) - - def __str__(self): - return "{0.__class__.__name__}(protocol={0.protocol})".format(self) - - -class Privacy(object): - """ - Provides the privacy data for UsmUser objects. - """ - - __slots__ = ("protocol", "passphrase") - - @classmethod - def new_nopriv(cls): - return cls(None, None) - - def __init__(self, protocol, passphrase): - if ( - not protocol - or protocol is PRIV_NOPRIV - or protocol == "PRIV_NOPRIV" - ): - self.protocol = PRIV_NOPRIV - self.passphrase = None - else: - self.protocol = priv_protocols[protocol] - if not passphrase: - raise ValueError("Privacy protocol requires a passphrase") - self.passphrase = passphrase - - def __eq__(self, other): - if not isinstance(other, Privacy): - return NotImplemented - return ( - self.protocol == other.protocol - and self.passphrase == other.passphrase - ) - - def __nonzero__(self): - return self.protocol is not PRIV_NOPRIV - - def __repr__(self): - return ( - "<{0.__module__}.{0.__class__.__name__} protocol={0.protocol}>" - ).format(self) - - def __str__(self): - return "{0.__class__.__name__}(protocol={0.protocol})".format(self) diff --git a/pynetsnmp/usm/__init__.py b/pynetsnmp/usm/__init__.py new file mode 100644 index 0000000..d089315 --- /dev/null +++ b/pynetsnmp/usm/__init__.py @@ -0,0 +1,43 @@ +from __future__ import absolute_import + +from .auth import Authentication +from .community import Community +from .priv import Privacy +from .user import User +from .protocols import ( + AUTH_MD5, + AUTH_NOAUTH, + auth_protocols, + AUTH_SHA, + AUTH_SHA_224, + AUTH_SHA_256, + AUTH_SHA_384, + AUTH_SHA_512, + PRIV_AES, + PRIV_AES_192, + PRIV_AES_256, + PRIV_DES, + PRIV_NOPRIV, + priv_protocols, +) + +__all__ = ( + "Authentication", + "AUTH_MD5", + "AUTH_NOAUTH", + "auth_protocols", + "AUTH_SHA", + "AUTH_SHA_224", + "AUTH_SHA_256", + "AUTH_SHA_384", + "AUTH_SHA_512", + "Community", + "Privacy", + "PRIV_AES", + "PRIV_AES_192", + "PRIV_AES_256", + "PRIV_DES", + "PRIV_NOPRIV", + "priv_protocols", + "User", +) diff --git a/pynetsnmp/usm/auth.py b/pynetsnmp/usm/auth.py new file mode 100644 index 0000000..38e3efc --- /dev/null +++ b/pynetsnmp/usm/auth.py @@ -0,0 +1,50 @@ +from __future__ import absolute_import + +from .protocols import AUTH_NOAUTH, auth_protocols + + +class Authentication(object): + """ + Provides the authentication data for User objects. + """ + + __slots__ = ("protocol", "passphrase") + + @classmethod + def new_noauth(cls): + return cls(None, None) + + def __init__(self, protocol, passphrase): + if ( + not protocol + or protocol is AUTH_NOAUTH + or protocol == "AUTH_NOAUTH" + ): + self.protocol = AUTH_NOAUTH + self.passphrase = None + else: + self.protocol = auth_protocols[protocol] + if not passphrase: + raise ValueError( + "Authentication protocol requires a passphrase" + ) + self.passphrase = passphrase + + def __eq__(self, other): + if not isinstance(other, Authentication): + return NotImplemented + return ( + self.protocol == other.protocol + and self.passphrase == other.passphrase + ) + + def __nonzero__(self): + return self.protocol is not AUTH_NOAUTH + + def __repr__(self): + return ( + "<{0.__module__}.{0.__class__.__name__} protocol={0.protocol}>" + ).format(self) + + def __str__(self): + return "{0.__class__.__name__}(protocol={0.protocol})".format(self) diff --git a/pynetsnmp/usm/common.py b/pynetsnmp/usm/common.py new file mode 100644 index 0000000..24b0b6a --- /dev/null +++ b/pynetsnmp/usm/common.py @@ -0,0 +1,19 @@ +from __future__ import absolute_import + +from ..CONSTANTS import ( + SNMP_VERSION_1 as _V1, + SNMP_VERSION_2c as _V2C, + SNMP_VERSION_3 as _V3, +) + +version_map = { + "1": "1", + "2c": "2c", + "3": "3", + _V1: "1", + _V2C: "2c", + _V3: "3", + "v1": "1", + "v2c": "2c", + "v3": "3", +} diff --git a/pynetsnmp/usm/community.py b/pynetsnmp/usm/community.py new file mode 100644 index 0000000..44c3a73 --- /dev/null +++ b/pynetsnmp/usm/community.py @@ -0,0 +1,23 @@ +from __future__ import absolute_import + +from ..CONSTANTS import SNMP_VERSION_2c as _V2C +from .common import version_map + + +class Community(object): + """ + Provides the community based security model for SNMP v1/V2c. + """ + + def __init__(self, name, version=_V2C): + mapped = version_map.get(version) + if mapped is None or mapped == "3": + raise ValueError( + "SNMP version '{}' not supported for Community".format(version) + ) + self.name = name + self.version = mapped + + def getArguments(self): + community = ("-c", str(self.name)) if self.name else () + return ("-v", self.version) + community diff --git a/pynetsnmp/usm/priv.py b/pynetsnmp/usm/priv.py new file mode 100644 index 0000000..22aa210 --- /dev/null +++ b/pynetsnmp/usm/priv.py @@ -0,0 +1,48 @@ +from __future__ import absolute_import + +from .protocols import PRIV_NOPRIV, priv_protocols + + +class Privacy(object): + """ + Provides the privacy data for User objects. + """ + + __slots__ = ("protocol", "passphrase") + + @classmethod + def new_nopriv(cls): + return cls(None, None) + + def __init__(self, protocol, passphrase): + if ( + not protocol + or protocol is PRIV_NOPRIV + or protocol == "PRIV_NOPRIV" + ): + self.protocol = PRIV_NOPRIV + self.passphrase = None + else: + self.protocol = priv_protocols[protocol] + if not passphrase: + raise ValueError("Privacy protocol requires a passphrase") + self.passphrase = passphrase + + def __eq__(self, other): + if not isinstance(other, Privacy): + return NotImplemented + return ( + self.protocol == other.protocol + and self.passphrase == other.passphrase + ) + + def __nonzero__(self): + return self.protocol is not PRIV_NOPRIV + + def __repr__(self): + return ( + "<{0.__module__}.{0.__class__.__name__} protocol={0.protocol}>" + ).format(self) + + def __str__(self): + return "{0.__class__.__name__}(protocol={0.protocol})".format(self) diff --git a/pynetsnmp/usm.py b/pynetsnmp/usm/protocols.py similarity index 90% rename from pynetsnmp/usm.py rename to pynetsnmp/usm/protocols.py index 3606eb3..fd69de2 100644 --- a/pynetsnmp/usm.py +++ b/pynetsnmp/usm/protocols.py @@ -92,20 +92,3 @@ def __repr__(self): del _Protocol del _Protocols - -__all__ = ( - "AUTH_MD5", - "AUTH_NOAUTH", - "auth_protocols", - "AUTH_SHA", - "AUTH_SHA_224", - "AUTH_SHA_256", - "AUTH_SHA_384", - "AUTH_SHA_512", - "PRIV_AES", - "PRIV_AES_192", - "PRIV_AES_256", - "PRIV_DES", - "PRIV_NOPRIV", - "priv_protocols", -) diff --git a/pynetsnmp/usm/user.py b/pynetsnmp/usm/user.py new file mode 100644 index 0000000..eca9a2b --- /dev/null +++ b/pynetsnmp/usm/user.py @@ -0,0 +1,89 @@ +from __future__ import absolute_import + +from ..CONSTANTS import SNMP_VERSION_3 as _V3 + +from .auth import Authentication +from .common import version_map +from .priv import Privacy + + +_sec_level = { + (True, True): "authPriv", + (True, False): "authNoPriv", + (False, False): "noAuthNoPriv", +} + + +class User(object): + """ + Provides User-based Security Model configuration for SNMP v3. + """ + + def __init__(self, name, auth=None, priv=None, engine=None, context=None): + self.name = name + if auth is None: + auth = Authentication.new_noauth() + if not isinstance(auth, Authentication): + raise ValueError("invalid authentication object") + self.auth = auth + if priv is None: + priv = Privacy.new_nopriv() + if not isinstance(priv, Privacy): + raise ValueError("invalid privacy object") + self.priv = priv + self.engine = engine + self.context = context + self.version = version_map.get(_V3) + + def getArguments(self): + auth_args = ( + ("-a", self.auth.protocol.name, "-A", self.auth.passphrase) + if self.auth + else () + ) + if auth_args: + # The privacy arguments are only given if the authentication + # arguments are also provided. + priv_args = ( + ("-x", self.priv.protocol.name, "-X", self.priv.passphrase) + if self.priv + else () + ) + else: + priv_args = () + seclevel_arg = ("-l", _sec_level[(bool(self.auth), bool(self.priv))]) + + return ( + ("-v", self.version) + + (("-u", self.name) if self.name else ()) + + seclevel_arg + + auth_args + + priv_args + + (("-e", self.engine) if self.engine else ()) + + (("-n", self.context) if self.context else ()) + ) + + def __eq__(self, other): + return ( + self.name == other.name + and self.auth == other.auth + and self.priv == other.priv + and self.engine == other.engine + and self.context == other.context + ) + + def __str__(self): + info = ", ".join( + "{0}={1}".format(k, v) + for k, v in ( + ("name", self.name), + ("auth", self.auth), + ("priv", self.priv), + ("engine", self.engine), + ("context", self.context), + ) + if v + ) + return "{0.__class__.__name__}(version={0.version}{1}{2})".format( + self, ", " if info else "", info + ) diff --git a/tests/test_auth.py b/tests/test_auth.py new file mode 100644 index 0000000..a17d68f --- /dev/null +++ b/tests/test_auth.py @@ -0,0 +1,92 @@ +import unittest + +from pynetsnmp import usm + + +class TestAuthentication(unittest.TestCase): + passwd = "security123" # noqa: S105 + + def test_noauth_classmethod(t): + auth = usm.Authentication.new_noauth() + t.assertEqual(usm.AUTH_NOAUTH, auth.protocol) + t.assertIsNone(auth.passphrase) + + def test_none_init(t): + auth = usm.Authentication(None, None) + t.assertEqual(usm.AUTH_NOAUTH, auth.protocol) + t.assertIsNone(auth.passphrase) + + auth = usm.Authentication(None, t.passwd) + t.assertEqual(usm.AUTH_NOAUTH, auth.protocol) + t.assertIsNone(auth.passphrase) + + def test_noauth_init(t): + auth = usm.Authentication(usm.AUTH_NOAUTH, None) + t.assertEqual(usm.AUTH_NOAUTH, auth.protocol) + t.assertIsNone(auth.passphrase) + + auth = usm.Authentication(usm.AUTH_NOAUTH, t.passwd) + t.assertEqual(usm.AUTH_NOAUTH, auth.protocol) + t.assertIsNone(auth.passphrase) + + def test_noauth_str_init(t): + auth = usm.Authentication("AUTH_NOAUTH", None) + t.assertEqual(usm.AUTH_NOAUTH, auth.protocol) + t.assertIsNone(auth.passphrase) + + auth = usm.Authentication("AUTH_NOAUTH", t.passwd) + t.assertEqual(usm.AUTH_NOAUTH, auth.protocol) + t.assertIsNone(auth.passphrase) + + def test_noauth_is_false(t): + auth = usm.Authentication.new_noauth() + t.assertFalse(auth) + + def test_md5(t): + auth = usm.Authentication(usm.AUTH_MD5, t.passwd) + t.assertTrue(auth) + t.assertEqual(auth.protocol, usm.AUTH_MD5) + t.assertEqual(auth.passphrase, t.passwd) + + def test_sha(t): + auth = usm.Authentication(usm.AUTH_SHA, t.passwd) + t.assertTrue(auth) + t.assertEqual(auth.protocol, usm.AUTH_SHA) + t.assertEqual(auth.passphrase, t.passwd) + + def test_sha_224(t): + auth = usm.Authentication(usm.AUTH_SHA_224, t.passwd) + t.assertTrue(auth) + t.assertEqual(auth.protocol, usm.AUTH_SHA_224) + t.assertEqual(auth.passphrase, t.passwd) + + def test_sha_256(t): + auth = usm.Authentication(usm.AUTH_SHA_256, t.passwd) + t.assertTrue(auth) + t.assertEqual(auth.protocol, usm.AUTH_SHA_256) + t.assertEqual(auth.passphrase, t.passwd) + + def test_sha_384(t): + auth = usm.Authentication(usm.AUTH_SHA_384, t.passwd) + t.assertTrue(auth) + t.assertEqual(auth.protocol, usm.AUTH_SHA_384) + t.assertEqual(auth.passphrase, t.passwd) + + def test_sha_512(t): + auth = usm.Authentication(usm.AUTH_SHA_512, t.passwd) + t.assertTrue(auth) + t.assertEqual(auth.protocol, usm.AUTH_SHA_512) + t.assertEqual(auth.passphrase, t.passwd) + + def test_equal(t): + auth1 = usm.Authentication(usm.AUTH_MD5, t.passwd) + auth2 = usm.Authentication(usm.AUTH_MD5, t.passwd) + t.assertEqual(auth1, auth2) + + def test_not_equal(t): + auth1 = usm.Authentication(usm.AUTH_MD5, t.passwd) + auth2 = usm.Authentication(usm.AUTH_SHA, t.passwd) + t.assertNotEqual(auth1, auth2) + + auth3 = usm.Authentication(usm.AUTH_SHA, t.passwd + "456") + t.assertNotEqual(auth2, auth3) diff --git a/tests/test_community.py b/tests/test_community.py new file mode 100644 index 0000000..7432548 --- /dev/null +++ b/tests/test_community.py @@ -0,0 +1,62 @@ +import unittest + +from pynetsnmp import CONSTANTS, usm + + +class TestCommunity(unittest.TestCase): + name = "public" + + def test_default(t): + c = usm.Community(t.name) + t.assertEqual(c.name, t.name) + t.assertEqual(c.version, "2c") + expected = ("-v", "2c", "-c", t.name) + t.assertSequenceEqual(c.getArguments(), expected) + + def test_v1_constant(t): + c = usm.Community(t.name, CONSTANTS.SNMP_VERSION_1) + t.assertEqual(c.name, t.name) + t.assertEqual(c.version, "1") + expected = ("-v", "1", "-c", t.name) + t.assertSequenceEqual(c.getArguments(), expected) + + def test_v1_v1(t): + c = usm.Community(t.name, "v1") + t.assertEqual(c.name, t.name) + t.assertEqual(c.version, "1") + expected = ("-v", "1", "-c", t.name) + t.assertSequenceEqual(c.getArguments(), expected) + + def test_v2c_constant(t): + c = usm.Community(t.name, CONSTANTS.SNMP_VERSION_2c) + t.assertEqual(c.name, t.name) + t.assertEqual(c.version, "2c") + expected = ("-v", "2c", "-c", t.name) + t.assertSequenceEqual(c.getArguments(), expected) + + def test_v2c_v2c(t): + c = usm.Community(t.name, "v2c") + t.assertEqual(c.name, t.name) + t.assertEqual(c.version, "2c") + expected = ("-v", "2c", "-c", t.name) + t.assertSequenceEqual(c.getArguments(), expected) + + def test_v3_constant(t): + with t.assertRaises(ValueError): + usm.Community(t.name, CONSTANTS.SNMP_VERSION_3) + + def test_v3_v3(t): + with t.assertRaises(ValueError): + usm.Community(t.name, "v3") + + def test_none_version(t): + with t.assertRaises(ValueError): + usm.Community(t.name, None) + + def test_not_a_version_str(t): + with t.assertRaises(ValueError): + usm.Community(t.name, "oi") + + def test_not_a_version_number(t): + with t.assertRaises(ValueError): + usm.Community(t.name, 3947) diff --git a/tests/test_priv.py b/tests/test_priv.py new file mode 100644 index 0000000..9dc3842 --- /dev/null +++ b/tests/test_priv.py @@ -0,0 +1,80 @@ +import unittest + +from pynetsnmp import usm + + +class TestPrivacy(unittest.TestCase): + passwd = "security123" # noqa: S105 + + def test_nopriv_classmethod(t): + priv = usm.Privacy.new_nopriv() + t.assertEqual(usm.PRIV_NOPRIV, priv.protocol) + t.assertIsNone(priv.passphrase) + + def test_none_init(t): + priv = usm.Privacy(None, None) + t.assertEqual(usm.PRIV_NOPRIV, priv.protocol) + t.assertIsNone(priv.passphrase) + + priv = usm.Privacy(None, t.passwd) + t.assertEqual(usm.PRIV_NOPRIV, priv.protocol) + t.assertIsNone(priv.passphrase) + + def test_nopriv_init(t): + priv = usm.Privacy(usm.PRIV_NOPRIV, None) + t.assertEqual(usm.PRIV_NOPRIV, priv.protocol) + t.assertIsNone(priv.passphrase) + + priv = usm.Privacy(usm.PRIV_NOPRIV, t.passwd) + t.assertEqual(usm.PRIV_NOPRIV, priv.protocol) + t.assertIsNone(priv.passphrase) + + def test_nopriv_str_init(t): + priv = usm.Privacy("PRIV_NOPRIV", None) + t.assertEqual(usm.PRIV_NOPRIV, priv.protocol) + t.assertIsNone(priv.passphrase) + + priv = usm.Privacy("PRIV_NOPRIV", t.passwd) + t.assertEqual(usm.PRIV_NOPRIV, priv.protocol) + t.assertIsNone(priv.passphrase) + + def test_nopriv_is_false(t): + priv = usm.Privacy.new_nopriv() + t.assertFalse(priv) + + def test_des(t): + priv = usm.Privacy(usm.PRIV_DES, t.passwd) + t.assertTrue(priv) + t.assertEqual(priv.protocol, usm.PRIV_DES) + t.assertEqual(priv.passphrase, t.passwd) + + def test_aes(t): + priv = usm.Privacy(usm.PRIV_AES, t.passwd) + t.assertTrue(priv) + t.assertEqual(priv.protocol, usm.PRIV_AES) + t.assertEqual(priv.passphrase, t.passwd) + + def test_aes_192(t): + priv = usm.Privacy(usm.PRIV_AES_192, t.passwd) + t.assertTrue(priv) + t.assertEqual(priv.protocol, usm.PRIV_AES_192) + t.assertEqual(priv.passphrase, t.passwd) + + def test_aes_256(t): + priv = usm.Privacy(usm.PRIV_AES_256, t.passwd) + t.assertTrue(priv) + t.assertEqual(priv.protocol, usm.PRIV_AES_256) + t.assertEqual(priv.passphrase, t.passwd) + + def test_equal(t): + priv1 = usm.Privacy(usm.PRIV_DES, t.passwd) + priv2 = usm.Privacy(usm.PRIV_DES, t.passwd) + t.assertEqual(priv1, priv2) + + def test_not_equal(t): + priv1 = usm.Privacy(usm.PRIV_DES, t.passwd) + priv2 = usm.Privacy(usm.PRIV_AES, t.passwd) + t.assertNotEqual(priv1, priv2) + + priv3 = usm.Privacy(usm.PRIV_AES, t.passwd + "456") + t.assertNotEqual(priv2, priv3) diff --git a/tests/test_security.py b/tests/test_security.py deleted file mode 100644 index 3783922..0000000 --- a/tests/test_security.py +++ /dev/null @@ -1,382 +0,0 @@ -import unittest - -from pynetsnmp import security, usm - - -class TestCommunity(unittest.TestCase): - name = "public" - - def test_default(t): - c = security.Community(t.name) - t.assertEqual(c.name, t.name) - t.assertEqual(c.version, "2c") - expected = ("-v", "2c", "-c", t.name) - t.assertSequenceEqual(c.getArguments(), expected) - - def test_v1_constant(t): - c = security.Community(t.name, security.SNMP_VERSION_1) - t.assertEqual(c.name, t.name) - t.assertEqual(c.version, "1") - expected = ("-v", "1", "-c", t.name) - t.assertSequenceEqual(c.getArguments(), expected) - - def test_v1_v1(t): - c = security.Community(t.name, "v1") - t.assertEqual(c.name, t.name) - t.assertEqual(c.version, "1") - expected = ("-v", "1", "-c", t.name) - t.assertSequenceEqual(c.getArguments(), expected) - - def test_v2c_constant(t): - c = security.Community(t.name, security.SNMP_VERSION_2c) - t.assertEqual(c.name, t.name) - t.assertEqual(c.version, "2c") - expected = ("-v", "2c", "-c", t.name) - t.assertSequenceEqual(c.getArguments(), expected) - - def test_v2c_v2c(t): - c = security.Community(t.name, "v2c") - t.assertEqual(c.name, t.name) - t.assertEqual(c.version, "2c") - expected = ("-v", "2c", "-c", t.name) - t.assertSequenceEqual(c.getArguments(), expected) - - def test_v3_constant(t): - with t.assertRaises(ValueError): - security.Community(t.name, security.SNMP_VERSION_3) - - def test_v3_v3(t): - with t.assertRaises(ValueError): - security.Community(t.name, "v3") - - def test_none_version(t): - with t.assertRaises(ValueError): - security.Community(t.name, None) - - def test_not_a_version_str(t): - with t.assertRaises(ValueError): - security.Community(t.name, "oi") - - def test_not_a_version_number(t): - with t.assertRaises(ValueError): - security.Community(t.name, 3947) - - -class TestUsmUser(unittest.TestCase): - name = "john_doe" - passwd = "secured123" # noqa: S105 - - def test_default(t): - user = security.UsmUser(t.name) - t.assertEqual(t.name, user.name) - t.assertEqual(security.Authentication.new_noauth(), user.auth) - t.assertEqual(security.Privacy.new_nopriv(), user.priv) - t.assertIsNone(user.engine) - t.assertIsNone(user.context) - t.assertEqual(user.version, "3") - expected = ("-v", "3", "-u", t.name, "-l", "noAuthNoPriv") - t.assertSequenceEqual(expected, user.getArguments()) - - def test_engineid(t): - engineid = hex(3443489794829589283483234)[2:].strip("L") - user = security.UsmUser(t.name, engine=engineid) - t.assertEqual(t.name, user.name) - t.assertEqual(security.Authentication.new_noauth(), user.auth) - t.assertEqual(security.Privacy.new_nopriv(), user.priv) - t.assertEqual(engineid, user.engine) - t.assertIsNone(user.context) - expected = ( - "-v", - "3", - "-u", - t.name, - "-l", - "noAuthNoPriv", - "-e", - engineid, - ) - t.assertSequenceEqual(expected, user.getArguments()) - - def test_contextid(t): - contextid = hex(9084090984572743455234)[2:].strip("L") - user = security.UsmUser(t.name, context=contextid) - t.assertEqual(t.name, user.name) - t.assertEqual(security.Authentication.new_noauth(), user.auth) - t.assertEqual(security.Privacy.new_nopriv(), user.priv) - t.assertIsNone(user.engine) - t.assertEqual(contextid, user.context) - expected = ( - "-v", - "3", - "-u", - t.name, - "-l", - "noAuthNoPriv", - "-n", - contextid, - ) - t.assertSequenceEqual(expected, user.getArguments()) - - def test_auth(t): - auth = security.Authentication(usm.AUTH_SHA_224, t.passwd) - user = security.UsmUser(t.name, auth=auth) - t.assertEqual(t.name, user.name) - t.assertEqual(auth, user.auth) - t.assertEqual(security.Privacy.new_nopriv(), user.priv) - t.assertIsNone(user.engine) - t.assertIsNone(user.context) - t.assertEqual(user.version, "3") - expected = ( - "-v", - "3", - "-u", - t.name, - "-l", - "authNoPriv", - "-a", - auth.protocol.name, - "-A", - auth.passphrase, - ) - t.assertSequenceEqual(expected, user.getArguments()) - - def test_authpriv(t): - auth = security.Authentication(usm.AUTH_SHA_224, t.passwd) - priv = security.Privacy(usm.PRIV_AES_256, t.passwd) - user = security.UsmUser(t.name, auth=auth, priv=priv) - t.assertEqual(t.name, user.name) - t.assertEqual(auth, user.auth) - t.assertEqual(priv, user.priv) - t.assertIsNone(user.engine) - t.assertIsNone(user.context) - t.assertEqual(user.version, "3") - expected = ( - "-v", - "3", - "-u", - t.name, - "-l", - "authPriv", - "-a", - auth.protocol.name, - "-A", - auth.passphrase, - "-x", - priv.protocol.name, - "-X", - priv.passphrase, - ) - t.assertSequenceEqual(expected, user.getArguments()) - - def test_all_args(t): - auth = security.Authentication(usm.AUTH_SHA_224, t.passwd) - priv = security.Privacy(usm.PRIV_AES_256, t.passwd) - contextid = hex(9084090984572743455234)[2:].strip("L") - engineid = hex(3443489794829589283483234)[2:].strip("L") - user = security.UsmUser( - t.name, auth=auth, priv=priv, engine=engineid, context=contextid - ) - t.assertEqual(t.name, user.name) - t.assertEqual(auth, user.auth) - t.assertEqual(priv, user.priv) - t.assertEqual(engineid, user.engine) - t.assertEqual(contextid, user.context) - t.assertEqual(user.version, "3") - expected = ( - "-v", - "3", - "-u", - t.name, - "-l", - "authPriv", - "-a", - auth.protocol.name, - "-A", - auth.passphrase, - "-x", - priv.protocol.name, - "-X", - priv.passphrase, - "-e", - engineid, - "-n", - contextid, - ) - t.assertSequenceEqual(expected, user.getArguments()) - - def test_equality(t): - auth1 = security.Authentication(usm.AUTH_SHA_224, t.passwd) - user1 = security.UsmUser(t.name, auth=auth1) - auth2 = security.Authentication(usm.AUTH_SHA_224, t.passwd) - user2 = security.UsmUser(t.name, auth=auth2) - auth3 = security.Authentication(usm.AUTH_SHA_256, t.passwd) - priv3 = security.Privacy(usm.PRIV_AES_256, t.passwd) - user3 = security.UsmUser(t.name, auth=auth3, priv=priv3) - t.assertEqual(user1, user2) - t.assertNotEqual(user1, user3) - - -class TestAuthentication(unittest.TestCase): - passwd = "security123" # noqa: S105 - - def test_noauth_classmethod(t): - auth = security.Authentication.new_noauth() - t.assertEqual(usm.AUTH_NOAUTH, auth.protocol) - t.assertIsNone(auth.passphrase) - - def test_none_init(t): - auth = security.Authentication(None, None) - t.assertEqual(usm.AUTH_NOAUTH, auth.protocol) - t.assertIsNone(auth.passphrase) - - auth = security.Authentication(None, t.passwd) - t.assertEqual(usm.AUTH_NOAUTH, auth.protocol) - t.assertIsNone(auth.passphrase) - - def test_noauth_init(t): - auth = security.Authentication(usm.AUTH_NOAUTH, None) - t.assertEqual(usm.AUTH_NOAUTH, auth.protocol) - t.assertIsNone(auth.passphrase) - - auth = security.Authentication(usm.AUTH_NOAUTH, t.passwd) - t.assertEqual(usm.AUTH_NOAUTH, auth.protocol) - t.assertIsNone(auth.passphrase) - - def test_noauth_str_init(t): - auth = security.Authentication("AUTH_NOAUTH", None) - t.assertEqual(usm.AUTH_NOAUTH, auth.protocol) - t.assertIsNone(auth.passphrase) - - auth = security.Authentication("AUTH_NOAUTH", t.passwd) - t.assertEqual(usm.AUTH_NOAUTH, auth.protocol) - t.assertIsNone(auth.passphrase) - - def test_noauth_is_false(t): - auth = security.Authentication.new_noauth() - t.assertFalse(auth) - - def test_md5(t): - auth = security.Authentication(usm.AUTH_MD5, t.passwd) - t.assertTrue(auth) - t.assertEqual(auth.protocol, usm.AUTH_MD5) - t.assertEqual(auth.passphrase, t.passwd) - - def test_sha(t): - auth = security.Authentication(usm.AUTH_SHA, t.passwd) - t.assertTrue(auth) - t.assertEqual(auth.protocol, usm.AUTH_SHA) - t.assertEqual(auth.passphrase, t.passwd) - - def test_sha_224(t): - auth = security.Authentication(usm.AUTH_SHA_224, t.passwd) - t.assertTrue(auth) - t.assertEqual(auth.protocol, usm.AUTH_SHA_224) - t.assertEqual(auth.passphrase, t.passwd) - - def test_sha_256(t): - auth = security.Authentication(usm.AUTH_SHA_256, t.passwd) - t.assertTrue(auth) - t.assertEqual(auth.protocol, usm.AUTH_SHA_256) - t.assertEqual(auth.passphrase, t.passwd) - - def test_sha_384(t): - auth = security.Authentication(usm.AUTH_SHA_384, t.passwd) - t.assertTrue(auth) - t.assertEqual(auth.protocol, usm.AUTH_SHA_384) - t.assertEqual(auth.passphrase, t.passwd) - - def test_sha_512(t): - auth = security.Authentication(usm.AUTH_SHA_512, t.passwd) - t.assertTrue(auth) - t.assertEqual(auth.protocol, usm.AUTH_SHA_512) - t.assertEqual(auth.passphrase, t.passwd) - - def test_equal(t): - auth1 = security.Authentication(usm.AUTH_MD5, t.passwd) - auth2 = security.Authentication(usm.AUTH_MD5, t.passwd) - t.assertEqual(auth1, auth2) - - def test_not_equal(t): - auth1 = security.Authentication(usm.AUTH_MD5, t.passwd) - auth2 = security.Authentication(usm.AUTH_SHA, t.passwd) - t.assertNotEqual(auth1, auth2) - - auth3 = security.Authentication(usm.AUTH_SHA, t.passwd + "456") - t.assertNotEqual(auth2, auth3) - - -class TestPrivacy(unittest.TestCase): - passwd = "security123" # noqa: S105 - - def test_nopriv_classmethod(t): - priv = security.Privacy.new_nopriv() - t.assertEqual(usm.PRIV_NOPRIV, priv.protocol) - t.assertIsNone(priv.passphrase) - - def test_none_init(t): - priv = security.Privacy(None, None) - t.assertEqual(usm.PRIV_NOPRIV, priv.protocol) - t.assertIsNone(priv.passphrase) - - priv = security.Privacy(None, t.passwd) - t.assertEqual(usm.PRIV_NOPRIV, priv.protocol) - t.assertIsNone(priv.passphrase) - - def test_nopriv_init(t): - priv = security.Privacy(usm.PRIV_NOPRIV, None) - t.assertEqual(usm.PRIV_NOPRIV, priv.protocol) - t.assertIsNone(priv.passphrase) - - priv = security.Privacy(usm.PRIV_NOPRIV, t.passwd) - t.assertEqual(usm.PRIV_NOPRIV, priv.protocol) - t.assertIsNone(priv.passphrase) - - def test_nopriv_str_init(t): - priv = security.Privacy("PRIV_NOPRIV", None) - t.assertEqual(usm.PRIV_NOPRIV, priv.protocol) - t.assertIsNone(priv.passphrase) - - priv = security.Privacy("PRIV_NOPRIV", t.passwd) - t.assertEqual(usm.PRIV_NOPRIV, priv.protocol) - t.assertIsNone(priv.passphrase) - - def test_nopriv_is_false(t): - priv = security.Privacy.new_nopriv() - t.assertFalse(priv) - - def test_des(t): - priv = security.Privacy(usm.PRIV_DES, t.passwd) - t.assertTrue(priv) - t.assertEqual(priv.protocol, usm.PRIV_DES) - t.assertEqual(priv.passphrase, t.passwd) - - def test_aes(t): - priv = security.Privacy(usm.PRIV_AES, t.passwd) - t.assertTrue(priv) - t.assertEqual(priv.protocol, usm.PRIV_AES) - t.assertEqual(priv.passphrase, t.passwd) - - def test_aes_192(t): - priv = security.Privacy(usm.PRIV_AES_192, t.passwd) - t.assertTrue(priv) - t.assertEqual(priv.protocol, usm.PRIV_AES_192) - t.assertEqual(priv.passphrase, t.passwd) - - def test_aes_256(t): - priv = security.Privacy(usm.PRIV_AES_256, t.passwd) - t.assertTrue(priv) - t.assertEqual(priv.protocol, usm.PRIV_AES_256) - t.assertEqual(priv.passphrase, t.passwd) - - def test_equal(t): - priv1 = security.Privacy(usm.PRIV_DES, t.passwd) - priv2 = security.Privacy(usm.PRIV_DES, t.passwd) - t.assertEqual(priv1, priv2) - - def test_not_equal(t): - priv1 = security.Privacy(usm.PRIV_DES, t.passwd) - priv2 = security.Privacy(usm.PRIV_AES, t.passwd) - t.assertNotEqual(priv1, priv2) - - priv3 = security.Privacy(usm.PRIV_AES, t.passwd + "456") - t.assertNotEqual(priv2, priv3) diff --git a/tests/test_usmuser.py b/tests/test_usmuser.py new file mode 100644 index 0000000..d28c025 --- /dev/null +++ b/tests/test_usmuser.py @@ -0,0 +1,157 @@ +import unittest + +from pynetsnmp import usm + + +class TestUser(unittest.TestCase): + name = "john_doe" + passwd = "secured123" # noqa: S105 + + def test_default(t): + user = usm.User(t.name) + t.assertEqual(t.name, user.name) + t.assertEqual(usm.Authentication.new_noauth(), user.auth) + t.assertEqual(usm.Privacy.new_nopriv(), user.priv) + t.assertIsNone(user.engine) + t.assertIsNone(user.context) + t.assertEqual(user.version, "3") + expected = ("-v", "3", "-u", t.name, "-l", "noAuthNoPriv") + t.assertSequenceEqual(expected, user.getArguments()) + + def test_engineid(t): + engineid = hex(3443489794829589283483234)[2:].strip("L") + user = usm.User(t.name, engine=engineid) + t.assertEqual(t.name, user.name) + t.assertEqual(usm.Authentication.new_noauth(), user.auth) + t.assertEqual(usm.Privacy.new_nopriv(), user.priv) + t.assertEqual(engineid, user.engine) + t.assertIsNone(user.context) + expected = ( + "-v", + "3", + "-u", + t.name, + "-l", + "noAuthNoPriv", + "-e", + engineid, + ) + t.assertSequenceEqual(expected, user.getArguments()) + + def test_contextid(t): + contextid = hex(9084090984572743455234)[2:].strip("L") + user = usm.User(t.name, context=contextid) + t.assertEqual(t.name, user.name) + t.assertEqual(usm.Authentication.new_noauth(), user.auth) + t.assertEqual(usm.Privacy.new_nopriv(), user.priv) + t.assertIsNone(user.engine) + t.assertEqual(contextid, user.context) + expected = ( + "-v", + "3", + "-u", + t.name, + "-l", + "noAuthNoPriv", + "-n", + contextid, + ) + t.assertSequenceEqual(expected, user.getArguments()) + + def test_auth(t): + auth = usm.Authentication(usm.AUTH_SHA_224, t.passwd) + user = usm.User(t.name, auth=auth) + t.assertEqual(t.name, user.name) + t.assertEqual(auth, user.auth) + t.assertEqual(usm.Privacy.new_nopriv(), user.priv) + t.assertIsNone(user.engine) + t.assertIsNone(user.context) + t.assertEqual(user.version, "3") + expected = ( + "-v", + "3", + "-u", + t.name, + "-l", + "authNoPriv", + "-a", + auth.protocol.name, + "-A", + auth.passphrase, + ) + t.assertSequenceEqual(expected, user.getArguments()) + + def test_authpriv(t): + auth = usm.Authentication(usm.AUTH_SHA_224, t.passwd) + priv = usm.Privacy(usm.PRIV_AES_256, t.passwd) + user = usm.User(t.name, auth=auth, priv=priv) + t.assertEqual(t.name, user.name) + t.assertEqual(auth, user.auth) + t.assertEqual(priv, user.priv) + t.assertIsNone(user.engine) + t.assertIsNone(user.context) + t.assertEqual(user.version, "3") + expected = ( + "-v", + "3", + "-u", + t.name, + "-l", + "authPriv", + "-a", + auth.protocol.name, + "-A", + auth.passphrase, + "-x", + priv.protocol.name, + "-X", + priv.passphrase, + ) + t.assertSequenceEqual(expected, user.getArguments()) + + def test_all_args(t): + auth = usm.Authentication(usm.AUTH_SHA_224, t.passwd) + priv = usm.Privacy(usm.PRIV_AES_256, t.passwd) + contextid = hex(9084090984572743455234)[2:].strip("L") + engineid = hex(3443489794829589283483234)[2:].strip("L") + user = usm.User( + t.name, auth=auth, priv=priv, engine=engineid, context=contextid + ) + t.assertEqual(t.name, user.name) + t.assertEqual(auth, user.auth) + t.assertEqual(priv, user.priv) + t.assertEqual(engineid, user.engine) + t.assertEqual(contextid, user.context) + t.assertEqual(user.version, "3") + expected = ( + "-v", + "3", + "-u", + t.name, + "-l", + "authPriv", + "-a", + auth.protocol.name, + "-A", + auth.passphrase, + "-x", + priv.protocol.name, + "-X", + priv.passphrase, + "-e", + engineid, + "-n", + contextid, + ) + t.assertSequenceEqual(expected, user.getArguments()) + + def test_equality(t): + auth1 = usm.Authentication(usm.AUTH_SHA_224, t.passwd) + user1 = usm.User(t.name, auth=auth1) + auth2 = usm.Authentication(usm.AUTH_SHA_224, t.passwd) + user2 = usm.User(t.name, auth=auth2) + auth3 = usm.Authentication(usm.AUTH_SHA_256, t.passwd) + priv3 = usm.Privacy(usm.PRIV_AES_256, t.passwd) + user3 = usm.User(t.name, auth=auth3, priv=priv3) + t.assertEqual(user1, user2) + t.assertNotEqual(user1, user3)