Commit c483098

resolve conflict
Signed-off-by: wiseaidev <business@wiseai.dev>
wiseaidev committed Aug 10, 2022
2 parents 3b07e3f + a00a68b commit c483098
Showing 15 changed files with 496 additions and 713 deletions.
1 change: 1 addition & 0 deletions aredis_om/__init__.py
@@ -1,3 +1,4 @@
+from .async_redis import redis  # isort:skip
 from .checks import has_redis_json, has_redisearch
 from .connections import get_redis_connection
 from .model.migrations.migrator import MigrationError, Migrator
1 change: 1 addition & 0 deletions aredis_om/async_redis.py
@@ -0,0 +1 @@
+from redis import asyncio as redis
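
Note: this one-line shim is the heart of the migration. The standalone aioredis package was merged into redis-py (as of redis-py 4.2) under the redis.asyncio namespace, so re-exporting it under the name redis lets every other module keep a single import path for the client and exception types. A minimal sketch of what the alias provides (it assumes a local Redis server for the ping call):

import asyncio

from redis import asyncio as redis  # the same alias as aredis_om.async_redis


async def main() -> None:
    # redis.Redis is redis.asyncio.Redis, the async client that replaces
    # aioredis.Redis; redis.ResponseError likewise replaces the old
    # `from aioredis import ResponseError`.
    client = redis.Redis(decode_responses=True)
    print(await client.ping())  # True if a server answers on localhost:6379


asyncio.run(main())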
8 changes: 4 additions & 4 deletions aredis_om/connections.py
@@ -1,19 +1,19 @@
 import os
 
-import aioredis
+from . import redis
 
 
 URL = os.environ.get("REDIS_OM_URL", None)
 
 
-def get_redis_connection(**kwargs) -> aioredis.Redis:
+def get_redis_connection(**kwargs) -> redis.Redis:
     # If someone passed in a 'url' parameter, or specified a REDIS_OM_URL
     # environment variable, we'll create the Redis client from the URL.
     url = kwargs.pop("url", URL)
     if url:
-        return aioredis.Redis.from_url(url, **kwargs)
+        return redis.Redis.from_url(url, **kwargs)
 
     # Decode from UTF-8 by default
     if "decode_responses" not in kwargs:
         kwargs["decode_responses"] = True
-    return aioredis.Redis(**kwargs)
+    return redis.Redis(**kwargs)
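
The factory keeps its old contract: an explicit url kwarg wins, then REDIS_OM_URL (read once, at module import), then plain keyword arguments with decode_responses=True applied by default. A usage sketch (the URL below is an assumed local instance):

import os

# REDIS_OM_URL is read once at import time, so set it before the first import.
os.environ.setdefault("REDIS_OM_URL", "redis://localhost:6379/0")

from aredis_om import get_redis_connection

# URL path: builds the client with redis.Redis.from_url(...).
db = get_redis_connection()

# Keyword path (taken when no url kwarg and no REDIS_OM_URL are set): plain
# redis.Redis(...) with decode_responses defaulting to True.
# db = get_redis_connection(host="localhost", port=6379)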
39 changes: 20 additions & 19 deletions aredis_om/model/migrations/migrator.py
@@ -4,7 +4,7 @@
 from enum import Enum
 from typing import List, Optional
 
-from aioredis import Redis, ResponseError
+from ... import redis
 
 
 log = logging.getLogger(__name__)
@@ -39,18 +39,19 @@ def schema_hash_key(index_name):
     return f"{index_name}:hash"
 
 
-async def create_index(redis: Redis, index_name, schema, current_hash):
-    db_number = redis.connection_pool.connection_kwargs.get("db")
+async def create_index(conn: redis.Redis, index_name, schema, current_hash):
+    db_number = conn.connection_pool.connection_kwargs.get("db")
     if db_number and db_number > 0:
         raise MigrationError(
             "Creating search indexes is only supported in database 0. "
             f"You attempted to create an index in database {db_number}"
         )
     try:
-        await redis.execute_command(f"ft.info {index_name}")
-    except ResponseError:
-        await redis.execute_command(f"ft.create {index_name} {schema}")
-        await redis.set(schema_hash_key(index_name), current_hash)
+        await conn.execute_command(f"ft.info {index_name}")
+    except redis.ResponseError:
+        await conn.execute_command(f"ft.create {index_name} {schema}")
+        # TODO: remove "type: ignore" when type stubs will be fixed
+        await conn.set(schema_hash_key(index_name), current_hash)  # type: ignore
     else:
         log.info("Index already exists, skipping. Index hash: %s", index_name)
 
@@ -67,7 +68,7 @@ class IndexMigration:
     schema: str
     hash: str
     action: MigrationAction
-    redis: Redis
+    conn: redis.Redis
     previous_hash: Optional[str] = None
 
     async def run(self):
@@ -78,14 +79,14 @@ async def run(self):
 
     async def create(self):
         try:
-            await create_index(self.redis, self.index_name, self.schema, self.hash)
-        except ResponseError:
+            await create_index(self.conn, self.index_name, self.schema, self.hash)
+        except redis.ResponseError:
             log.info("Index already exists: %s", self.index_name)
 
     async def drop(self):
         try:
-            await self.redis.execute_command(f"FT.DROPINDEX {self.index_name}")
-        except ResponseError:
+            await self.conn.execute_command(f"FT.DROPINDEX {self.index_name}")
+        except redis.ResponseError:
             log.info("Index does not exist: %s", self.index_name)
 
@@ -105,7 +106,7 @@ async def detect_migrations(self):
 
         for name, cls in model_registry.items():
             hash_key = schema_hash_key(cls.Meta.index_name)
-            redis = cls.db()
+            conn = cls.db()
             try:
                 schema = cls.redisearch_schema()
             except NotImplementedError:
@@ -114,21 +115,21 @@ async def detect_migrations(self):
             current_hash = hashlib.sha1(schema.encode("utf-8")).hexdigest()  # nosec
 
             try:
-                await redis.execute_command("ft.info", cls.Meta.index_name)
-            except ResponseError:
+                await conn.execute_command("ft.info", cls.Meta.index_name)
+            except redis.ResponseError:
                 self.migrations.append(
                     IndexMigration(
                         name,
                         cls.Meta.index_name,
                         schema,
                         current_hash,
                         MigrationAction.CREATE,
-                        redis,
+                        conn,
                     )
                 )
                 continue
 
-            stored_hash = await redis.get(hash_key)
+            stored_hash = await conn.get(hash_key)
             schema_out_of_date = current_hash != stored_hash
 
             if schema_out_of_date:
@@ -140,7 +141,7 @@ async def detect_migrations(self):
                         schema,
                         current_hash,
                         MigrationAction.DROP,
-                        redis,
+                        conn,
                         stored_hash,
                     )
                 )
@@ -151,7 +152,7 @@ async def detect_migrations(self):
                         schema,
                         current_hash,
                         MigrationAction.CREATE,
-                        redis,
+                        conn,
                         stored_hash,
                     )
                 )
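
The detection logic itself is untouched by the rename: each model's RediSearch schema string is SHA-1-hashed, compared with the hash stored at <index_name>:hash, and a stale index is scheduled as a DROP followed by a CREATE. Driving a run by hand looks roughly like this (a sketch; it assumes your models are already imported so they appear in the model registry):

import asyncio

from aredis_om import Migrator


async def migrate() -> None:
    # run() detects pending migrations from the model registry and then
    # executes each one (FT.DROPINDEX and/or FT.CREATE as needed).
    await Migrator().run()


asyncio.run(migrate())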
95 changes: 60 additions & 35 deletions aredis_om/model/model.py
@@ -24,8 +24,7 @@
     no_type_check,
 )
 
-import aioredis
-from aioredis.client import Pipeline
+from more_itertools import ichunked
 from pydantic import BaseModel, validator
 from pydantic.fields import FieldInfo as PydanticFieldInfo
 from pydantic.fields import ModelField, Undefined, UndefinedType
@@ -35,9 +34,10 @@
 from typing_extensions import Protocol, get_args, get_origin
 from ulid import ULID
 
+from .. import redis
 from ..checks import has_redis_json, has_redisearch
 from ..connections import get_redis_connection
-from ..unasync_util import ASYNC_MODE
+from ..util import ASYNC_MODE
 from .encoders import jsonable_encoder
 from .render_tree import render_tree
 from .token_escaper import TokenEscaper
@@ -760,6 +760,9 @@ async def all(self, batch_size=DEFAULT_PAGE_SIZE):
             return await query.execute()
         return await self.execute()
 
+    async def page(self, offset=0, limit=10):
+        return await self.copy(offset=offset, limit=limit).execute()
+
     def sort_by(self, *fields: str):
         if not fields:
             return self
@@ -975,7 +978,7 @@ class BaseMeta(Protocol):
     global_key_prefix: str
     model_key_prefix: str
     primary_key_pattern: str
-    database: aioredis.Redis
+    database: redis.Redis
     primary_key: PrimaryKey
     primary_key_creator_cls: Type[PrimaryKeyCreator]
     index_name: str
@@ -994,7 +997,7 @@ class DefaultMeta:
     global_key_prefix: Optional[str] = None
     model_key_prefix: Optional[str] = None
     primary_key_pattern: Optional[str] = None
-    database: Optional[aioredis.Redis] = None
+    database: Optional[redis.Redis] = None
     primary_key: Optional[PrimaryKey] = None
     primary_key_creator_cls: Optional[Type[PrimaryKeyCreator]] = None
    index_name: Optional[str] = None
@@ -1102,6 +1105,7 @@ class Config:
         extra = "allow"
 
     def __init__(__pydantic_self__, **data: Any) -> None:
+        data = {key: val for key, val in data.items() if val}
         super().__init__(**data)
         __pydantic_self__.validate_primary_key()
 
@@ -1115,9 +1119,17 @@ def key(self):
         return self.make_primary_key(pk)
 
     @classmethod
-    async def delete(cls, pk: Any) -> int:
+    async def _delete(cls, db, *pks):
+        return await db.delete(*pks)
+
+    @classmethod
+    async def delete(
+        cls, pk: Any, pipeline: Optional[redis.client.Pipeline] = None
+    ) -> int:
         """Delete data at this key."""
-        return await cls.db().delete(cls.make_primary_key(pk))
+        db = cls._get_db(pipeline)
+
+        return await cls._delete(db, cls.make_primary_key(pk))
 
     @classmethod
     async def get(cls, pk: Any) -> "RedisModel":
@@ -1127,14 +1139,15 @@ async def update(self, **field_values):
         """Update this model instance with the specified key-value pairs."""
         raise NotImplementedError
 
-    async def save(self, pipeline: Optional[Pipeline] = None) -> "RedisModel":
+    async def save(
+        self, pipeline: Optional[redis.client.Pipeline] = None
+    ) -> "RedisModel":
         raise NotImplementedError
 
-    async def expire(self, num_seconds: int, pipeline: Optional[Pipeline] = None):
-        if pipeline is None:
-            db = self.db()
-        else:
-            db = pipeline
+    async def expire(
+        self, num_seconds: int, pipeline: Optional[redis.client.Pipeline] = None
+    ):
+        db = self._get_db(pipeline)
 
         # TODO: Wrap any Redis response errors in a custom exception?
         await db.expire(self.make_primary_key(self.pk), num_seconds)
@@ -1223,19 +1236,10 @@ def get_annotations(cls):
     async def add(
         cls,
         models: Sequence["RedisModel"],
-        pipeline: Optional[Pipeline] = None,
+        pipeline: Optional[redis.client.Pipeline] = None,
         pipeline_verifier: Callable[..., Any] = verify_pipeline_response,
     ) -> Sequence["RedisModel"]:
-        if pipeline is None:
-            # By default, send commands in a pipeline. Saving each model will
-            # be atomic, but Redis may process other commands in between
-            # these saves.
-            db = cls.db().pipeline(transaction=False)
-        else:
-            # If the user gave us a pipeline, add our commands to that. The user
-            # will be responsible for executing the pipeline after they've accumulated
-            # the commands they want to send.
-            db = pipeline
+        db = cls._get_db(pipeline, bulk=True)
 
         for model in models:
             # save() just returns the model, we don't need that here.
@@ -1249,6 +1253,31 @@ async def add(
 
         return models
 
+    @classmethod
+    def _get_db(
+        self, pipeline: Optional[redis.client.Pipeline] = None, bulk: bool = False
+    ):
+        if pipeline is not None:
+            return pipeline
+        elif bulk:
+            return self.db().pipeline(transaction=False)
+        else:
+            return self.db()
+
+    @classmethod
+    async def delete_many(
+        cls,
+        models: Sequence["RedisModel"],
+        pipeline: Optional[redis.client.Pipeline] = None,
+    ) -> int:
+        db = cls._get_db(pipeline)
+
+        for chunk in ichunked(models, 100):
+            pks = [cls.make_primary_key(model.pk) for model in chunk]
+            await cls._delete(db, *pks)
+
+        return len(models)
+
     @classmethod
     def redisearch_schema(cls):
         raise NotImplementedError
@@ -1283,17 +1312,13 @@ def __init_subclass__(cls, **kwargs):
                     f"HashModels cannot index dataclass fields. Field: {name}"
                 )
 
-    def dict(self) -> Dict[str, Any]:
-        # restore none values
-        return dict(self)
-
-    async def save(self, pipeline: Optional[Pipeline] = None) -> "HashModel":
+    async def save(self, pipeline: Optional[redis.client.Pipeline] = None) -> "HashModel":
         self.check()
         if pipeline is None:
             db = self.db()
         else:
            db = pipeline
-        document = jsonable_encoder({key: val if val else "0" for key, val in self.dict().items()})
+        document = jsonable_encoder(self.dict())
         # TODO: Wrap any Redis response errors in a custom exception?
         await db.hset(self.key(), mapping=document)
         return self
@@ -1461,12 +1486,12 @@ def __init__(self, *args, **kwargs):
         )
         super().__init__(*args, **kwargs)
 
-    async def save(self, pipeline: Optional[Pipeline] = None) -> "JsonModel":
+    async def save(
+        self, pipeline: Optional[redis.client.Pipeline] = None
+    ) -> "JsonModel":
         self.check()
-        if pipeline is None:
-            db = self.db()
-        else:
-            db = pipeline
+        db = self._get_db(pipeline)
 
         # TODO: Wrap response errors in a custom exception?
         await db.execute_command("JSON.SET", self.key(), ".", self.json())
         return self
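
Beyond the import swap, this file gains three conveniences: FindQuery.page() for offset/limit pagination, a _get_db() helper that centralizes the pipeline-or-connection choice (handing back a non-transactional pipeline when bulk=True), and delete_many(), which deletes models in chunks of 100 keys per DEL via more_itertools.ichunked. A sketch of the new surface (the Customer model is illustrative, and a running Redis Stack instance is assumed):

import asyncio

from aredis_om import Field, HashModel, Migrator


class Customer(HashModel):
    first_name: str = Field(index=True)
    last_name: str


async def main() -> None:
    await Migrator().run()  # make sure the index exists before searching

    await Customer(first_name="Ada", last_name="Lovelace").save()

    # page(): one offset/limit slice per call, unlike .all()'s batching.
    first_page = await Customer.find().page(offset=0, limit=10)

    # delete_many(): issues DEL in chunks of up to 100 primary keys.
    deleted = await Customer.delete_many(first_page)
    print(deleted)


asyncio.run(main())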
1 change: 1 addition & 0 deletions aredis_om/sync_redis.py
@@ -0,0 +1 @@
+import redis
41 changes: 0 additions & 41 deletions aredis_om/unasync_util.py

This file was deleted.

12 changes: 12 additions & 0 deletions aredis_om/util.py
@@ -0,0 +1,12 @@
+import inspect
+
+
+def is_async_mode():
+    async def f():
+        """Unasync transforms async functions in sync functions"""
+        return None
+
+    return inspect.iscoroutinefunction(f)
+
+
+ASYNC_MODE = is_async_mode()
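
This module replaces the deleted unasync_util.py with a single runtime probe: in this async source tree, f is a coroutine function, so ASYNC_MODE is True; in the synchronous package that unasync generates from it, async def is rewritten to def and the same probe yields False. For context, the generated sync variant would read roughly like this (a sketch of unasync's output, not a file in this commit):

import inspect


def is_async_mode():
    def f():  # unasync strips the `async` keyword in the sync build
        """Unasync transforms async functions in sync functions"""
        return None

    return inspect.iscoroutinefunction(f)  # False for a plain function


ASYNC_MODE = is_async_mode()  # False, so the sync code paths are taken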
