Skip to content

Commit

Permalink
Improve typing
Browse files Browse the repository at this point in the history
  • Loading branch information
nikteliy committed May 13, 2024
1 parent 8612423 commit 7c6f21e
Show file tree
Hide file tree
Showing 12 changed files with 356 additions and 155 deletions.
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ repos:
- id: mypy
additional_dependencies: [types-setuptools]
files: ^snap7
args: [--strict, --exclude=snap7/util/db.py]

- repo: https://github.com/astral-sh/ruff-pre-commit
rev: 'v0.4.2'
Expand Down
83 changes: 45 additions & 38 deletions snap7/client/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,24 +5,29 @@
import re
import logging
from ctypes import CFUNCTYPE, byref, create_string_buffer, sizeof
from ctypes import Array, c_byte, c_char_p, c_int, c_int32, c_uint16, c_ulong, c_void_p
from ctypes import Array, _SimpleCData, c_byte, c_char_p, c_int, c_int32, c_uint16, c_ulong, c_void_p, _CArgObject
from datetime import datetime
from typing import Any, Callable, List, Optional, Tuple, Union
from typing import Any, Callable, List, Optional, Tuple, Union, ParamSpec, TypeVar, Type
from types import TracebackType

from ..common import check_error, ipv4, load_library
from ..protocols import Snap7CliProtocol
from ..types import S7SZL, Areas, BlocksList, S7CpInfo, S7CpuInfo, S7DataItem
from ..types import S7OrderCode, S7Protection, S7SZLList, TS7BlockInfo, WordLen
from ..types import S7Object, buffer_size, buffer_type, cpu_statuses, param_types
from ..types import RemotePort, wordlen_to_ctypes, block_types

logger = logging.getLogger(__name__)

Param = ParamSpec("Param")
RetType = TypeVar("RetType")

def error_wrap(func):

def error_wrap(func: Callable[Param, RetType]) -> Callable[Param, None]:
"""Parses a s7 error code returned the decorated function."""

def f(*args, **kw):
code = func(*args, **kw)
def f(*args: Param.args, **kwargs: Param.kwargs) -> None:
code = func(*args, **kwargs)
check_error(code, context="client")

return f
Expand All @@ -47,10 +52,10 @@ class Client:
>>> client.db_write(1, 0, data)
"""

_lib: Any # since this is dynamically loaded from a DLL we don't have the type signature.
_lib: Snap7CliProtocol
_read_callback = None
_callback = None
_s7_client: Optional[S7Object] = None
_s7_client: S7Object

def __init__(self, lib_location: Optional[str] = None):
"""Creates a new `Client` instance.
Expand All @@ -66,23 +71,25 @@ def __init__(self, lib_location: Optional[str] = None):
<snap7.client.Client object at 0x0000028B257128E0>
"""

self._lib = load_library(lib_location)
self._lib: Snap7CliProtocol = load_library(lib_location)
self.create()

def __enter__(self):
def __enter__(self) -> "Client":
return self

def __exit__(self, exc_type, exc_val, exc_tb):
def __exit__(
self, exc_type: Optional[Type[BaseException]], exc_val: Optional[BaseException], exc_tb: Optional[TracebackType]
) -> None:
self.destroy()

def __del__(self):
def __del__(self) -> None:
self.destroy()

def create(self):
def create(self) -> None:
"""Creates a SNAP7 client."""
logger.info("creating snap7 client")
self._lib.Cli_Create.restype = S7Object
self._s7_client = S7Object(self._lib.Cli_Create())
self._lib.Cli_Create.restype = S7Object # type: ignore[attr-defined]
self._s7_client = self._lib.Cli_Create()

def destroy(self) -> Optional[int]:
"""Destroys the Client object.
Expand All @@ -97,7 +104,7 @@ def destroy(self) -> Optional[int]:
logger.info("destroying snap7 client")
if self._lib and self._s7_client is not None:
return self._lib.Cli_Destroy(byref(self._s7_client))
self._s7_client = None
self._s7_client = None # type: ignore[assignment]
return None

def plc_stop(self) -> int:
Expand Down Expand Up @@ -199,7 +206,7 @@ def connect(self, address: str, rack: int, slot: int, tcpport: int = 102) -> int
"""
logger.info(f"connecting to {address}:{tcpport} rack {rack} slot {slot}")

self.set_param(RemotePort, tcpport)
self.set_param(number=RemotePort, value=tcpport)
return self._lib.Cli_ConnectTo(self._s7_client, c_char_p(address.encode()), c_int(rack), c_int(slot))

def db_read(self, db_number: int, start: int, size: int) -> bytearray:
Expand Down Expand Up @@ -441,7 +448,7 @@ def write_area(self, area: Areas, dbnumber: int, start: int, data: bytearray) ->
cdata = (type_ * len(data)).from_buffer_copy(data)
return self._lib.Cli_WriteArea(self._s7_client, area.value, dbnumber, start, size, wordlen.value, byref(cdata))

def read_multi_vars(self, items) -> Tuple[int, S7DataItem]:
def read_multi_vars(self, items: Array[S7DataItem]) -> Tuple[int, Array[S7DataItem]]:
"""Reads different kind of variables from a PLC simultaneously.
Args:
Expand Down Expand Up @@ -472,7 +479,7 @@ def list_blocks(self) -> BlocksList:
logger.debug(f"blocks: {blocksList}")
return blocksList

def list_blocks_of_type(self, blocktype: str, size: int) -> Union[int, Array]:
def list_blocks_of_type(self, blocktype: str, size: int) -> Union[int, Array[c_uint16]]:
"""This function returns the AG list of a specified block type.
Args:
Expand Down Expand Up @@ -592,11 +599,11 @@ def set_connection_params(self, address: str, local_tsap: int, remote_tsap: int)
"""
if not re.match(ipv4, address):
raise ValueError(f"{address} is invalid ipv4")
result = self._lib.Cli_SetConnectionParams(self._s7_client, address, c_uint16(local_tsap), c_uint16(remote_tsap))
result = self._lib.Cli_SetConnectionParams(self._s7_client, address.encode(), c_uint16(local_tsap), c_uint16(remote_tsap))
if result != 0:
raise ValueError("The parameter was invalid")

def set_connection_type(self, connection_type: int):
def set_connection_type(self, connection_type: int) -> None:
"""Sets the connection resource type, i.e the way in which the Clients connects to a PLC.
Args:
Expand Down Expand Up @@ -659,7 +666,7 @@ def ab_write(self, start: int, data: bytearray) -> int:
logger.debug(f"ab write: start: {start}: size: {size}: ")
return self._lib.Cli_ABWrite(self._s7_client, start, size, byref(cdata))

def as_ab_read(self, start: int, size: int, data) -> int:
def as_ab_read(self, start: int, size: int, data: _SimpleCData[Any]) -> int:
"""Reads a part of IPU area from a PLC asynchronously.
Args:
Expand Down Expand Up @@ -720,7 +727,7 @@ def as_copy_ram_to_rom(self, timeout: int = 1) -> int:
check_error(result, context="client")
return result

def as_ct_read(self, start: int, amount: int, data) -> int:
def as_ct_read(self, start: int, amount: int, data: _SimpleCData[Any]) -> int:
"""Reads counters from a PLC asynchronously.
Args:
Expand Down Expand Up @@ -752,7 +759,7 @@ def as_ct_write(self, start: int, amount: int, data: bytearray) -> int:
check_error(result, context="client")
return result

def as_db_fill(self, db_number: int, filler) -> int:
def as_db_fill(self, db_number: int, filler: int) -> int:
"""Fills a DB in AG with a given byte.
Args:
Expand All @@ -766,7 +773,7 @@ def as_db_fill(self, db_number: int, filler) -> int:
check_error(result, context="client")
return result

def as_db_get(self, db_number: int, _buffer, size) -> bytearray:
def as_db_get(self, db_number: int, _buffer: _SimpleCData[Any], size: _SimpleCData[Any]) -> int:
"""Uploads a DB from AG using DBRead.
Note:
Expand All @@ -784,7 +791,7 @@ def as_db_get(self, db_number: int, _buffer, size) -> bytearray:
check_error(result, context="client")
return result

def as_db_read(self, db_number: int, start: int, size: int, data) -> Array:
def as_db_read(self, db_number: int, start: int, size: int, data: _SimpleCData[Any]) -> int:
"""Reads a part of a DB from a PLC.
Args:
Expand All @@ -807,7 +814,7 @@ def as_db_read(self, db_number: int, start: int, size: int, data) -> Array:
check_error(result, context="client")
return result

def as_db_write(self, db_number: int, start: int, size: int, data) -> int:
def as_db_write(self, db_number: int, start: int, size: int, data: _SimpleCData[Any]) -> int:
"""Writes a part of a DB into a PLC.
Args:
Expand Down Expand Up @@ -943,7 +950,7 @@ def set_plc_datetime(self, dt: datetime) -> int:

return self._lib.Cli_SetPlcDateTime(self._s7_client, byref(buffer))

def check_as_completion(self, p_value) -> int:
def check_as_completion(self, p_value: c_int) -> int:
"""Method to check Status of an async request. Result contains if the check was successful, not the data value itself
Args:
Expand Down Expand Up @@ -1000,7 +1007,7 @@ def wait_as_completion(self, timeout: int) -> int:
check_error(result, context="client")
return result

def _prepare_as_read_area(self, area: Areas, size: int) -> Tuple[WordLen, Array]:
def _prepare_as_read_area(self, area: Areas, size: int) -> Tuple[WordLen, Array[_SimpleCData[int]]]:
if area not in Areas:
raise ValueError(f"{area} is not implemented in types")
elif area == Areas.TM:
Expand All @@ -1013,7 +1020,7 @@ def _prepare_as_read_area(self, area: Areas, size: int) -> Tuple[WordLen, Array]
usrdata = (type_ * size)()
return wordlen, usrdata

def as_read_area(self, area: Areas, dbnumber: int, start: int, size: int, wordlen: WordLen, pusrdata) -> int:
def as_read_area(self, area: Areas, dbnumber: int, start: int, size: int, wordlen: WordLen, pusrdata: _CArgObject) -> int:
"""Reads a data area from a PLC asynchronously.
With it you can read DB, Inputs, Outputs, Merkers, Timers and Counters.
Expand All @@ -1036,7 +1043,7 @@ def as_read_area(self, area: Areas, dbnumber: int, start: int, size: int, wordle
check_error(result, context="client")
return result

def _prepare_as_write_area(self, area: Areas, data: bytearray) -> Tuple[WordLen, Array]:
def _prepare_as_write_area(self, area: Areas, data: bytearray) -> Tuple[WordLen, Array[Any]]:
if area not in Areas:
raise ValueError(f"{area} is not implemented in types")
elif area == Areas.TM:
Expand All @@ -1049,7 +1056,7 @@ def _prepare_as_write_area(self, area: Areas, data: bytearray) -> Tuple[WordLen,
cdata = (type_ * len(data)).from_buffer_copy(data)
return wordlen, cdata

def as_write_area(self, area: Areas, dbnumber: int, start: int, size: int, wordlen: WordLen, pusrdata) -> int:
def as_write_area(self, area: Areas, dbnumber: int, start: int, size: int, wordlen: WordLen, pusrdata: bytearray) -> int:
"""Writes a data area into a PLC asynchronously.
Args:
Expand All @@ -1072,7 +1079,7 @@ def as_write_area(self, area: Areas, dbnumber: int, start: int, size: int, wordl
check_error(res, context="client")
return res

def as_eb_read(self, start: int, size: int, data) -> int:
def as_eb_read(self, start: int, size: int, data: _SimpleCData[Any]) -> int:
"""Reads a part of IPI area from a PLC asynchronously.
Args:
Expand Down Expand Up @@ -1124,7 +1131,7 @@ def as_full_upload(self, _type: str, block_num: int) -> int:
check_error(result, context="client")
return result

def as_list_blocks_of_type(self, blocktype: str, data, count) -> int:
def as_list_blocks_of_type(self, blocktype: str, data: _SimpleCData[Any], count: _SimpleCData[Any]) -> int:
"""Returns the AG blocks list of a given type.
Args:
Expand All @@ -1145,7 +1152,7 @@ def as_list_blocks_of_type(self, blocktype: str, data, count) -> int:
check_error(result, context="client")
return result

def as_mb_read(self, start: int, size: int, data) -> int:
def as_mb_read(self, start: int, size: int, data: _SimpleCData[Any]) -> int:
"""Reads a part of Merkers area from a PLC.
Args:
Expand Down Expand Up @@ -1177,7 +1184,7 @@ def as_mb_write(self, start: int, size: int, data: bytearray) -> int:
check_error(result, context="client")
return result

def as_read_szl(self, ssl_id: int, index: int, s7_szl: S7SZL, size) -> int:
def as_read_szl(self, ssl_id: int, index: int, s7_szl: S7SZL, size: _SimpleCData[Any]) -> int:
"""Reads a partial list of given ID and Index.
Args:
Expand All @@ -1193,7 +1200,7 @@ def as_read_szl(self, ssl_id: int, index: int, s7_szl: S7SZL, size) -> int:
check_error(result, context="client")
return result

def as_read_szl_list(self, szl_list, items_count) -> int:
def as_read_szl_list(self, szl_list: _SimpleCData[Any], items_count: _SimpleCData[Any]) -> int:
"""Reads the list of partial lists available in the CPU.
Args:
Expand All @@ -1207,7 +1214,7 @@ def as_read_szl_list(self, szl_list, items_count) -> int:
check_error(result, context="client")
return result

def as_tm_read(self, start: int, amount: int, data) -> bytearray:
def as_tm_read(self, start: int, amount: int, data: _SimpleCData[Any]) -> int:
"""Reads timers from a PLC.
Args:
Expand Down Expand Up @@ -1239,7 +1246,7 @@ def as_tm_write(self, start: int, amount: int, data: bytearray) -> int:
check_error(result)
return result

def as_upload(self, block_num: int, _buffer, size) -> int:
def as_upload(self, block_num: int, _buffer: _SimpleCData[Any], size: _SimpleCData[Any]) -> int:
"""Uploads a block from AG.
Note:
Expand Down
14 changes: 8 additions & 6 deletions snap7/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,11 @@
import platform
from pathlib import Path
from ctypes import c_char
from typing import Any, Literal, Optional
from typing import Literal, NoReturn, Optional, cast
from ctypes.util import find_library
from functools import cache
from .protocols import Snap7CliProtocol


if platform.system() == "Windows":
from ctypes import windll as cdll # type: ignore
Expand All @@ -19,7 +21,7 @@
ipv4 = r"^(([0-9]|[1-9][0-9]|1[0-9]{2}|2[0-4][0-9]|25[0-5])\.){3}([0-9]|[1-9][0-9]|1[0-9]{2}|2[0-4][0-9]|25[0-5])$"


def _raise_error():
def _raise_error() -> NoReturn:
error = f"""can't find snap7 shared library.
This probably means you are installing python-snap7 from source. When no binary wheel is found for you architecture, pip
Expand Down Expand Up @@ -72,7 +74,7 @@ def _find_in_package() -> Optional[str]:


@cache
def load_library(lib_location: Optional[str] = None) -> Any:
def load_library(lib_location: Optional[str] = None) -> Snap7CliProtocol:
"""Loads the `snap7.dll` library.
Returns:
cdll: a ctypes cdll object with the snap7 shared library loaded.
Expand All @@ -83,7 +85,7 @@ def load_library(lib_location: Optional[str] = None) -> Any:
if not lib_location:
_raise_error()

return cdll.LoadLibrary(lib_location)
return cast(Snap7CliProtocol, cdll.LoadLibrary(lib_location))


Context = Literal["client", "server", "partner"]
Expand All @@ -107,7 +109,7 @@ def check_error(code: int, context: Context = "client") -> None:
raise RuntimeError(error)


def error_text(error, context: Context = "client") -> bytes:
def error_text(error: int, context: Context = "client") -> bytes:
"""Returns a textual explanation of a given error number
Args:
Expand All @@ -127,6 +129,6 @@ def error_text(error, context: Context = "client") -> bytes:
text_type = c_char * len_
text = text_type()
library = load_library()
map_ = {"client": library.Cli_ErrorText, "server": library.Srv_ErrorText, "partner": library.Par_ErrorText}
map_ = {"client": library.Cli_ErrorText, "server": library.Srv_ErrorText, "partner": library.Par_ErrorText} # type: ignore[attr-defined]
map_[context](error, text, len_)
return text.value
Loading

0 comments on commit 7c6f21e

Please sign in to comment.