52 lines
1.9 KiB
Python
52 lines
1.9 KiB
Python
|
from typing import Type, TypeVar, Sequence, Optional, Dict, Any
|
||
|
|
||
|
from sqlalchemy import update as update_, delete as delete_
|
||
|
from sqlalchemy.future import select
|
||
|
from db.models import ExperimentCategory, ExperimentData, ExperimentParameters, LoadParameters, RecyclingParameters
|
||
|
|
||
|
from db.postgres_db_connection import async_session_postgres
|
||
|
|
||
|
# Type variable constrained to the model classes imported from db.models;
# the CRUD helpers in this module use it to tie their return type to the
# class that was passed in.
T = TypeVar(
    "T",
    ExperimentCategory,
    ExperimentData,
    ExperimentParameters,
    LoadParameters,
    RecyclingParameters,
)
|
||
|
|
||
|
|
||
|
async def get_all(model_class: Type[T]) -> Sequence[T]:
    """Return every row of ``model_class``.

    Args:
        model_class: Mapped class to select from.

    Returns:
        All instances produced by ``select(model_class)``.
    """
    query = select(model_class)
    async with async_session_postgres() as db:
        rows = await db.execute(query)
        return rows.scalars().all()
|
||
|
|
||
|
|
||
|
async def get_by_id(model_class: Type[T], id: int) -> Optional[T]:
    """Look up a single row of ``model_class`` by its ``id`` column.

    Args:
        model_class: Mapped class to select from; must expose an ``id``
            attribute.
        id: Value compared against ``model_class.id``.

    Returns:
        The matching instance, or ``None`` when no row matches.
    """
    query = select(model_class).where(model_class.id == id)
    async with async_session_postgres() as db:
        found = await db.execute(query)
        return found.scalar_one_or_none()
|
||
|
|
||
|
|
||
|
async def create(model_class: Type[T], **kwargs) -> T:
    """Insert a new row of ``model_class`` and return the persisted instance.

    Args:
        model_class: Mapped class to instantiate.
        **kwargs: Passed unchanged to the ``model_class`` constructor.

    Returns:
        The newly created instance, refreshed from the database after the
        commit.
    """
    async with async_session_postgres() as db:
        instance = model_class(**kwargs)
        db.add(instance)
        await db.commit()
        # Reload the instance after the commit so the object handed back to
        # the caller has its attributes populated from the database.
        await db.refresh(instance)
        return instance
|
||
|
|
||
|
|
||
|
async def update(model_class: Type[T], id: int, updated_data: Dict[str, Any]) -> Optional[T]:
    """Update the row of ``model_class`` whose ``id`` matches and return it.

    Args:
        model_class: Mapped class whose table is updated; must expose an
            ``id`` attribute.
        id: Value compared against ``model_class.id``.
        updated_data: Column-name to new-value pairs forwarded to
            ``.values()``. When empty, no UPDATE statement is issued.

    Returns:
        The row as re-read by ``get_by_id`` after the commit, or ``None``
        when no row has that ``id``.
    """
    # With nothing to set there is no meaningful UPDATE to send, so the
    # write is skipped and the current row is simply returned.
    if updated_data:
        stmt = (
            update_(model_class)
            .where(model_class.id == id)
            .values(**updated_data)
            .execution_options(synchronize_session="fetch")
        )
        async with async_session_postgres() as session:
            await session.execute(stmt)
            await session.commit()

    # The re-read happens only after the write session has been closed, so
    # this function never keeps two sessions open at the same time.
    return await get_by_id(model_class, id)
|
||
|
|
||
|
|
||
|
async def delete(model_class: Type[T], id: int) -> bool:
    """Delete the row of ``model_class`` whose ``id`` matches.

    Args:
        model_class: Mapped class whose table is targeted; must expose an
            ``id`` attribute.
        id: Value compared against ``model_class.id``.

    Returns:
        ``True`` when the statement reported at least one affected row,
        ``False`` otherwise.
    """
    removal = delete_(model_class).where(model_class.id == id)
    async with async_session_postgres() as db:
        outcome = await db.execute(removal)
        await db.commit()
        was_deleted = outcome.rowcount > 0
        return was_deleted
|