Skip to content

Commit

Permalink
make typing consistent
Browse files Browse the repository at this point in the history
  • Loading branch information
gijzelaerr committed Jul 4, 2024
1 parent 9e35342 commit 018ccf4
Show file tree
Hide file tree
Showing 5 changed files with 50 additions and 78 deletions.
48 changes: 22 additions & 26 deletions snap7/client/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@

import re
import logging
from ctypes import CFUNCTYPE, byref, create_string_buffer, sizeof
from ctypes import Array, _SimpleCData, c_byte, c_char_p, c_int, c_int32, c_uint16, c_ulong, c_void_p
from ctypes import CFUNCTYPE, byref, create_string_buffer, sizeof, c_int16
from ctypes import Array, c_byte, c_char_p, c_int, c_int32, c_uint16, c_ulong, c_void_p
from datetime import datetime
from typing import Any, Callable, Hashable, List, Optional, Tuple, Union, Type
from types import TracebackType
Expand All @@ -15,7 +15,7 @@
from ..types import S7SZL, Area, BlocksList, S7CpInfo, S7CpuInfo, S7DataItem, Block
from ..types import S7OrderCode, S7Protection, S7SZLList, TS7BlockInfo, WordLen
from ..types import S7Object, buffer_size, buffer_type, cpu_statuses, param_types
from ..types import RemotePort
from ..types import RemotePort, CDataArrayType

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -651,7 +651,7 @@ def ab_write(self, start: int, data: bytearray) -> int:
logger.debug(f"ab write: start: {start}: size: {size}: ")
return self._lib.Cli_ABWrite(self._s7_client, start, size, byref(cdata))

def as_ab_read(self, start: int, size: int, data: "Array[_SimpleCData[Any]]") -> int:
def as_ab_read(self, start: int, size: int, data: Union[Array[c_byte], Array[c_int16], Array[c_int32]]) -> int:
"""Reads a part of IPU area from a PLC asynchronously.
Args:
Expand Down Expand Up @@ -712,7 +712,7 @@ def as_copy_ram_to_rom(self, timeout: int = 1) -> int:
check_error(result, context="client")
return result

def as_ct_read(self, start: int, amount: int, data: "Array[_SimpleCData[Any]]") -> int:
def as_ct_read(self, start: int, amount: int, data: CDataArrayType) -> int:
"""Reads counters from a PLC asynchronously.
Args:
Expand Down Expand Up @@ -758,7 +758,7 @@ def as_db_fill(self, db_number: int, filler: int) -> int:
check_error(result, context="client")
return result

def as_db_get(self, db_number: int, _buffer: "Array[_SimpleCData[Any]]", size: "_SimpleCData[Any]") -> int:
def as_db_get(self, db_number: int, _buffer: CDataArrayType, size: int) -> int:
"""Uploads a DB from AG using DBRead.
Note:
Expand All @@ -772,11 +772,11 @@ def as_db_get(self, db_number: int, _buffer: "Array[_SimpleCData[Any]]", size: "
Returns:
Snap7 code.
"""
result = self._lib.Cli_AsDBGet(self._s7_client, db_number, byref(_buffer), byref(size))
result = self._lib.Cli_AsDBGet(self._s7_client, db_number, byref(_buffer), byref(c_int(size)))
check_error(result, context="client")
return result

def as_db_read(self, db_number: int, start: int, size: int, data: "Array[_SimpleCData[Any]]") -> int:
def as_db_read(self, db_number: int, start: int, size: int, data: CDataArrayType) -> int:
"""Reads a part of a DB from a PLC.
Args:
Expand All @@ -799,7 +799,7 @@ def as_db_read(self, db_number: int, start: int, size: int, data: "Array[_Simple
check_error(result, context="client")
return result

def as_db_write(self, db_number: int, start: int, size: int, data: "Array[_SimpleCData[Any]]") -> int:
def as_db_write(self, db_number: int, start: int, size: int, data: CDataArrayType) -> int:
"""Writes a part of a DB into a PLC.
Args:
Expand Down Expand Up @@ -992,9 +992,7 @@ def wait_as_completion(self, timeout: int) -> int:
check_error(result, context="client")
return result

def as_read_area(
self, area: Area, db_number: int, start: int, size: int, word_len: WordLen, data: "Array[_SimpleCData[Any]]"
) -> int:
def as_read_area(self, area: Area, db_number: int, start: int, size: int, word_len: WordLen, data: CDataArrayType) -> int:
"""Reads a data area from a PLC asynchronously.
With this you can read DB, Inputs, Outputs, Merkers, Timers and Counters.
Expand All @@ -1017,9 +1015,7 @@ def as_read_area(
check_error(result, context="client")
return result

def as_write_area(
self, area: Area, db_number: int, start: int, size: int, word_len: WordLen, data: "Array[_SimpleCData[Any]]"
) -> int:
def as_write_area(self, area: Area, db_number: int, start: int, size: int, word_len: WordLen, data: CDataArrayType) -> int:
"""Writes a data area into a PLC asynchronously.
Args:
Expand All @@ -1042,7 +1038,7 @@ def as_write_area(
check_error(res, context="client")
return res

def as_eb_read(self, start: int, size: int, data: "Array[_SimpleCData[Any]]") -> int:
def as_eb_read(self, start: int, size: int, data: CDataArrayType) -> int:
"""Reads a part of IPI area from a PLC asynchronously.
Args:
Expand Down Expand Up @@ -1093,7 +1089,7 @@ def as_full_upload(self, block_type: Block, block_num: int) -> int:
check_error(result, context="client")
return result

def as_list_blocks_of_type(self, block_type: Block, data: "Array[_SimpleCData[Any]]", count: "_SimpleCData[Any]") -> int:
def as_list_blocks_of_type(self, block_type: Block, data: CDataArrayType, count: int) -> int:
"""Returns the AG blocks list of a given type.
Args:
Expand All @@ -1104,11 +1100,11 @@ def as_list_blocks_of_type(self, block_type: Block, data: "Array[_SimpleCData[An
Returns:
Snap7 code.
"""
result = self._lib.Cli_AsListBlocksOfType(self._s7_client, block_type.ctype, byref(data), byref(count))
result = self._lib.Cli_AsListBlocksOfType(self._s7_client, block_type.ctype, byref(data), byref(c_int(count)))
check_error(result, context="client")
return result

def as_mb_read(self, start: int, size: int, data: "Array[_SimpleCData[Any]]") -> int:
def as_mb_read(self, start: int, size: int, data: CDataArrayType) -> int:
"""Reads a part of Merkers area from a PLC.
Args:
Expand Down Expand Up @@ -1140,7 +1136,7 @@ def as_mb_write(self, start: int, size: int, data: bytearray) -> int:
check_error(result, context="client")
return result

def as_read_szl(self, ssl_id: int, index: int, s7_szl: S7SZL, size: "_SimpleCData[Any]") -> int:
def as_read_szl(self, ssl_id: int, index: int, s7_szl: S7SZL, size: int) -> int:
"""Reads a partial list of given ID and Index.
Args:
Expand All @@ -1152,11 +1148,11 @@ def as_read_szl(self, ssl_id: int, index: int, s7_szl: S7SZL, size: "_SimpleCDat
Returns:
Snap7 code.
"""
result = self._lib.Cli_AsReadSZL(self._s7_client, ssl_id, index, byref(s7_szl), byref(size))
result = self._lib.Cli_AsReadSZL(self._s7_client, ssl_id, index, byref(s7_szl), byref(c_int(size)))
check_error(result, context="client")
return result

def as_read_szl_list(self, szl_list: S7SZLList, items_count: "_SimpleCData[Any]") -> int:
def as_read_szl_list(self, szl_list: S7SZLList, items_count: int) -> int:
"""Reads the list of partial lists available in the CPU.
Args:
Expand All @@ -1166,11 +1162,11 @@ def as_read_szl_list(self, szl_list: S7SZLList, items_count: "_SimpleCData[Any]"
Returns:
Snap7 code.
"""
result = self._lib.Cli_AsReadSZLList(self._s7_client, byref(szl_list), byref(items_count))
result = self._lib.Cli_AsReadSZLList(self._s7_client, byref(szl_list), byref(c_int(items_count)))
check_error(result, context="client")
return result

def as_tm_read(self, start: int, amount: int, data: "Array[_SimpleCData[Any]]") -> int:
def as_tm_read(self, start: int, amount: int, data: CDataArrayType) -> int:
"""Reads timers from a PLC.
Args:
Expand Down Expand Up @@ -1202,7 +1198,7 @@ def as_tm_write(self, start: int, amount: int, data: bytearray) -> int:
check_error(result)
return result

def as_upload(self, block_num: int, _buffer: "Array[_SimpleCData[Any]]", size: "_SimpleCData[Any]") -> int:
def as_upload(self, block_num: int, _buffer: CDataArrayType, size: int) -> int:
"""Uploads a block from AG.
Note:
Expand All @@ -1216,7 +1212,7 @@ def as_upload(self, block_num: int, _buffer: "Array[_SimpleCData[Any]]", size: "
Returns:
Snap7 code.
"""
result = self._lib.Cli_AsUpload(self._s7_client, Block.DB.ctype, block_num, byref(_buffer), byref(size))
result = self._lib.Cli_AsUpload(self._s7_client, Block.DB.ctype, block_num, byref(_buffer), byref(c_int(size)))
check_error(result, context="client")
return result

Expand Down
18 changes: 7 additions & 11 deletions snap7/server/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,20 +14,16 @@
c_void_p,
CFUNCTYPE,
POINTER,
Array,
_SimpleCData,
c_byte,
c_int16,
)
from _ctypes import CFuncPtr
import struct
import logging
from typing import Any, Callable, Hashable, Optional, Tuple, cast, Type, Union
from typing import Any, Callable, Hashable, Optional, Tuple, cast, Type
from types import TracebackType

from ..common import ipv4, check_error, load_library
from ..protocol import Snap7CliProtocol
from ..types import SrvEvent, LocalPort, cpu_statuses, server_statuses, SrvArea, longword, WordLen, S7Object
from ..types import SrvEvent, LocalPort, cpu_statuses, server_statuses, SrvArea, longword, WordLen, S7Object, CDataArrayType

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -99,7 +95,7 @@ def create(self) -> None:
self._s7_server = S7Object(self._lib.Srv_Create())

@error_wrap
def register_area(self, area: SrvArea, index: int, userdata: "Array[_SimpleCData[int]]") -> int:
def register_area(self, area: SrvArea, index: int, userdata: CDataArrayType) -> int:
"""Shares a memory area with the server. That memory block will be
visible by the clients.
Expand Down Expand Up @@ -400,10 +396,10 @@ def mainloop(tcpport: int = 1102, init_standard_values: bool = False) -> None:

server = Server()
size = 100
DBdata: Union[Array[c_byte], Array[c_int16], Array[c_int32]] = (WordLen.Byte.ctype * size)()
PAdata: Union[Array[c_byte], Array[c_int16], Array[c_int32]] = (WordLen.Byte.ctype * size)()
TMdata: Union[Array[c_byte], Array[c_int16], Array[c_int32]] = (WordLen.Byte.ctype * size)()
CTdata: Union[Array[c_byte], Array[c_int16], Array[c_int32]] = (WordLen.Byte.ctype * size)()
DBdata: CDataArrayType = (WordLen.Byte.ctype * size)()
PAdata: CDataArrayType = (WordLen.Byte.ctype * size)()
TMdata: CDataArrayType = (WordLen.Byte.ctype * size)()
CTdata: CDataArrayType = (WordLen.Byte.ctype * size)()
server.register_area(SrvArea.DB, 1, DBdata)
server.register_area(SrvArea.PA, 1, PAdata)
server.register_area(SrvArea.TM, 1, TMdata)
Expand Down
10 changes: 7 additions & 3 deletions snap7/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
Python equivalent for snap7 specific types.
"""

from _ctypes import Array
from ctypes import (
c_int16,
c_int8,
Expand All @@ -19,7 +20,10 @@
c_uint8,
)
from enum import Enum
from typing import Type, Dict
from typing import Dict, Union

CDataArrayType = Union[Array[c_byte], Array[c_int], Array[c_int16], Array[c_int32]]
CDataType = Union[type[c_int8], type[c_int16], type[c_int32]]

S7Object = c_void_p
buffer_size = 65536
Expand Down Expand Up @@ -84,8 +88,8 @@ class WordLen(Enum):
Timer = 0x1D

@property
def ctype(self) -> Type[c_int8 | c_int16 | c_int32]:
map_: Dict[WordLen, Type[c_int8 | c_int16 | c_int32]] = {
def ctype(self) -> CDataType:
map_: Dict[WordLen, CDataType] = {
WordLen.Bit: c_int16,
WordLen.Byte: c_int8,
WordLen.Word: c_int16,
Expand Down
18 changes: 1 addition & 17 deletions snap7/util/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,9 +84,7 @@
"""

import re
import time
from typing import Any, Union
from datetime import date, datetime
from typing import Any
from collections import OrderedDict

from .setters import (
Expand Down Expand Up @@ -174,20 +172,6 @@
]


def utc2local(utc: Union[date, datetime]) -> Union[datetime, date]:
"""Returns the local datetime
Args:
utc: UTC type date or datetime.
Returns:
Local datetime.
"""
epoch = time.mktime(utc.timetuple())
offset = datetime.fromtimestamp(epoch) - datetime.utcfromtimestamp(epoch)
return utc + offset


def parse_specification(db_specification: str) -> OrderedDict[str, Any]:
"""Create a db specification derived from a
dataview of a db in which the byte layout
Expand Down
34 changes: 13 additions & 21 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import logging
import struct
import time
from _ctypes import _SimpleCData
from typing import Tuple, Union

import pytest
Expand All @@ -28,7 +27,6 @@

from snap7.util.getters import get_real, get_int
from snap7.util.setters import set_int
from snap7.util import utc2local
from snap7.common import check_error
from snap7.server import mainloop
from snap7.client import Client
Expand Down Expand Up @@ -67,7 +65,7 @@
slot = 1


def _prepare_as_read_area(area: Area, size: int) -> Tuple[WordLen, "Array[_SimpleCData[int]]"]:
def _prepare_as_read_area(area: Area, size: int) -> Tuple[WordLen, Union[Array[c_byte], Array[c_int16], Array[c_int32]]]:
wordlen = area.wordlen()
usrdata = (wordlen.ctype * size)()
return wordlen, usrdata
Expand Down Expand Up @@ -222,8 +220,8 @@ def test_upload(self) -> None:
self.assertRaises(RuntimeError, self.client.upload, db_number)

def test_as_upload(self) -> None:
_buffer = typing_cast("Array[_SimpleCData[int]]", buffer_type())
size = c_int(sizeof(_buffer))
_buffer = typing_cast(Array[c_int32], buffer_type())
size = sizeof(_buffer)
self.client.as_upload(1, _buffer, size)
self.assertRaises(RuntimeError, self.client.wait_as_completion, 500)

Expand Down Expand Up @@ -436,12 +434,11 @@ def test_as_db_fill(self) -> None:
self.assertEqual(expected, self.client.db_read(1, 0, 100))

def test_as_db_get(self) -> None:
_buffer = typing_cast("Array[_SimpleCData[int]]", buffer_type())
size = c_int(buffer_size)
self.client.as_db_get(db_number, _buffer, size)
_buffer = typing_cast(Array[c_int], buffer_type())
self.client.as_db_get(db_number, _buffer, buffer_size)
self.client.wait_as_completion(500)
result = bytearray(_buffer)[: size.value]
self.assertEqual(100, len(result))
result = bytearray(_buffer)[:buffer_size]
self.assertEqual(buffer_size, len(result))

def test_as_db_read(self) -> None:
size = 40
Expand Down Expand Up @@ -776,9 +773,8 @@ def test_as_full_upload(self) -> None:
self.assertRaises(RuntimeError, self.client.wait_as_completion, 500)

def test_as_list_blocks_of_type(self) -> None:
data = typing_cast("Array[_SimpleCData[int]]", (c_uint16 * 10)())
count = c_int()
self.client.as_list_blocks_of_type(Block.DB, data, count)
data = typing_cast(Array[c_int], (c_uint16 * 10)())
self.client.as_list_blocks_of_type(Block.DB, data, 0)
self.assertRaises(RuntimeError, self.client.wait_as_completion, 500)

def test_as_mb_read(self) -> None:
Expand All @@ -801,24 +797,21 @@ def test_as_read_szl(self) -> None:
ssl_id = 0x011C
index = 0x0005
s7_szl = S7SZL()
size = c_int(sizeof(s7_szl))
self.client.as_read_szl(ssl_id, index, s7_szl, size)
self.client.as_read_szl(ssl_id, index, s7_szl, sizeof(s7_szl))
self.client.wait_as_completion(100)
result = bytes(s7_szl.Data)[2:26]
self.assertEqual(expected, result)

def test_as_read_szl_list(self) -> None:
# Cli_AsReadSZLList
expected = b"\x00\x00\x00\x0f\x02\x00\x11\x00\x11\x01\x11\x0f\x12\x00\x12\x01"
szl_list = S7SZLList()
items_count = c_int(sizeof(szl_list))
items_count = sizeof(szl_list)
self.client.as_read_szl_list(szl_list, items_count)
self.client.wait_as_completion(500)
result = bytearray(szl_list.List)[:16]
self.assertEqual(expected, result)

def test_as_tm_read(self) -> None:
# Cli_AsMBRead
expected = b"\x10\x01"
self.client.tm_write(0, 1, bytearray(expected))
type_ = WordLen.Timer.ctype
Expand All @@ -828,7 +821,6 @@ def test_as_tm_read(self) -> None:
self.assertEqual(expected, bytearray(buffer))

def test_as_tm_write(self) -> None:
# Cli_AsMBWrite
data = b"\x10\x01"
response = self.client.as_tm_write(0, 1, bytearray(data))
result = self.client.wait_as_completion(500)
Expand Down Expand Up @@ -927,8 +919,8 @@ def test_get_pg_block_info(self) -> None:
self.assertEqual(10, block_info.BlkType)
self.assertEqual(99, block_info.BlkNumber)
self.assertEqual(2752512, block_info.SBBLength)
self.assertEqual(bytes((utc2local(date(2019, 6, 27)).strftime("%Y/%m/%d")), encoding="utf-8"), block_info.CodeDate)
self.assertEqual(bytes((utc2local(date(2019, 6, 27)).strftime("%Y/%m/%d")), encoding="utf-8"), block_info.IntfDate)
self.assertEqual(bytes((date(2019, 6, 27).strftime("%Y/%m/%d")), encoding="utf-8"), block_info.CodeDate)
self.assertEqual(bytes((date(2019, 6, 27).strftime("%Y/%m/%d")), encoding="utf-8"), block_info.IntfDate)

def test_iso_exchange_buffer(self) -> None:
# Cli_IsoExchangeBuffer
Expand Down

0 comments on commit 018ccf4

Please sign in to comment.