Загружает файлы, супер

This commit is contained in:
maksim 2024-12-17 01:01:44 +04:00
parent 93e2ffe021
commit 482daf29da
4 changed files with 128 additions and 67 deletions

View File

@ -1,30 +1,57 @@
from typing import Optional
from datetime import datetime
from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
from sqlalchemy import ForeignKey
engine = create_async_engine("sqlite+aiosqlite:///tasks.db")
new_session = async_sessionmaker(engine, expire_on_commit=False)
class Model(DeclarativeBase):
class Base(DeclarativeBase):
pass
class TaskOrm(Model):
__tablename__ = "tasks"
class User(Base):
__tablename__ = "users"
id: Mapped[int] = mapped_column(primary_key=True)
name: Mapped[str]
description: Mapped[Optional[str]]
username: Mapped[str] = mapped_column(unique=True)
password_hash: Mapped[str]
# Связи
csv_files = relationship("CSVFile", back_populates="user")
models = relationship("H5Model", back_populates="user")
class CSVFile(Base):
__tablename__ = "csv_files"
id: Mapped[int] = mapped_column(primary_key=True)
user_id: Mapped[int] = mapped_column(ForeignKey("users.id"))
file_path: Mapped[str]
uploaded_at: Mapped[datetime] = mapped_column(default=datetime.utcnow)
user = relationship("User", back_populates="csv_files")
class H5Model(Base):
__tablename__ = "h5_models"
id: Mapped[int] = mapped_column(primary_key=True)
user_id: Mapped[int] = mapped_column(ForeignKey("users.id"))
model_path: Mapped[str]
created_at: Mapped[datetime] = mapped_column(default=datetime.utcnow)
user = relationship("User", back_populates="models")
statistics = relationship("ModelStatistics", back_populates="model")
class ModelStatistics(Base):
__tablename__ = "model_statistics"
id: Mapped[int] = mapped_column(primary_key=True)
model_id: Mapped[int] = mapped_column(ForeignKey("h5_models.id"))
accuracy: Mapped[float]
loss: Mapped[float]
created_at: Mapped[datetime] = mapped_column(default=datetime.utcnow)
model = relationship("H5Model", back_populates="statistics")
async def create_tables():
# https://docs.sqlalchemy.org/en/20/orm/extensions/asyncio.html#synopsis-core
async with engine.begin() as conn:
await conn.run_sync(Model.metadata.create_all)
await conn.run_sync(Base.metadata.create_all)
async def delete_tables():
async with engine.begin() as conn:
await conn.run_sync(Model.metadata.drop_all)
await conn.run_sync(Base.metadata.drop_all)

View File

@ -1,26 +1,36 @@
from database import new_session, User, CSVFile, H5Model, ModelStatistics
from schemas import UserCreate, CSVFileUpload, H5ModelCreate, ModelStatisticsCreate
from sqlalchemy import select
from database import new_session, TaskOrm
from schemas import STaskAdd, STask
class TaskRepository:
@classmethod
async def add_one(cls, data: STaskAdd) -> int:
class UserRepository:
@staticmethod
async def create_user(data: UserCreate):
async with new_session() as session:
task_dict = data.model_dump()
task = TaskOrm(**task_dict)
session.add(task)
await session.flush()
user = User(username=data.username, password_hash=data.password)
session.add(user)
await session.commit()
return task.id
return user.id
@classmethod
async def find_all(cls) -> list[STask]:
class CSVFileRepository:
@staticmethod
async def upload_file(user_id: int, file_path: str):
async with new_session() as session:
query = select(TaskOrm)
result = await session.execute(query)
task_models = result.scalars().all()
task_schemas = [STask.model_validate(task_model) for task_model in task_models]
return task_schemas
csv_file = CSVFile(user_id=user_id, file_path=file_path)
session.add(csv_file)
await session.commit()
class H5ModelRepository:
@staticmethod
async def add_model(user_id: int, model_path: str):
async with new_session() as session:
model = H5Model(user_id=user_id, model_path=model_path)
session.add(model)
await session.commit()
class ModelStatisticsRepository:
@staticmethod
async def add_statistics(data: ModelStatisticsCreate):
async with new_session() as session:
stats = ModelStatistics(**data.model_dump())
session.add(stats)
await session.commit()

View File

@ -1,25 +1,45 @@
from typing import Annotated
from fastapi import APIRouter, UploadFile, File, Form
from repository import UserRepository, CSVFileRepository, H5ModelRepository, ModelStatisticsRepository
from schemas import UserCreate, ModelStatisticsCreate
import shutil
import os
from fastapi import APIRouter, Depends
router = APIRouter()
from repository import TaskRepository
from schemas import STaskAdd, STask, STaskId
UPLOAD_FOLDER_CSV = "uploads/csv"
os.makedirs(UPLOAD_FOLDER_CSV, exist_ok=True)
router = APIRouter(
prefix="/tasks",
tags=["Таски"],
)
UPLOAD_FOLDER_MODELS = "uploads/models"
os.makedirs(UPLOAD_FOLDER_MODELS, exist_ok=True)
# Регистрация пользователя
@router.post("/users/")
async def create_user(user: UserCreate):
user_id = await UserRepository.create_user(user)
return {"user_id": user_id}
@router.post("")
async def add_task(
task: Annotated[STaskAdd, Depends()],
) -> STaskId:
task_id = await TaskRepository.add_one(task)
return {"ok": True, "task_id": task_id}
# Загрузка CSV файла
@router.post("/upload/csv/")
async def upload_csv(user_id: int = Form(...), file: UploadFile = File(...)):
file_path = os.path.join(UPLOAD_FOLDER_CSV, file.filename)
with open(file_path, "wb") as buffer:
shutil.copyfileobj(file.file, buffer)
await CSVFileRepository.upload_file(user_id=user_id, file_path=file_path)
return {"message": "CSV файл загружен", "file_path": file_path}
@router.get("")
async def get_tasks() -> list[STask]:
tasks = await TaskRepository.find_all()
return tasks
# Загрузка H5 модели
@router.post("/upload/h5/")
async def upload_h5_model(user_id: int = Form(...), file: UploadFile = File(...)):
file_path = os.path.join(UPLOAD_FOLDER_MODELS, file.filename)
with open(file_path, "wb") as buffer:
shutil.copyfileobj(file.file, buffer)
await H5ModelRepository.add_model(user_id=user_id, model_path=file_path)
return {"message": "H5 модель загружена", "file_path": file_path}
# Добавление статистики модели
@router.post("/models/statistics/")
async def add_model_statistics(stats: ModelStatisticsCreate):
await ModelStatisticsRepository.add_statistics(stats)
return {"message": "Статистика модели сохранена"}

View File

@ -1,19 +1,23 @@
from pydantic import BaseModel
from datetime import datetime
from typing import Optional
from pydantic import BaseModel, ConfigDict
# UserCreate: для создания пользователя
class UserCreate(BaseModel):
username: str
password: str
# CSVFileUpload: для загрузки CSV файла
class CSVFileUpload(BaseModel):
file_path: str
class STaskAdd(BaseModel):
name: str
description: Optional[str] = None
# H5ModelCreate: для добавления модели
class H5ModelCreate(BaseModel):
model_path: str
class STask(STaskAdd):
id: int
model_config = ConfigDict(from_attributes=True)
class STaskId(BaseModel):
ok: bool = True
task_id: int
# ModelStatisticsCreate: для сохранения статистики модели
class ModelStatisticsCreate(BaseModel):
model_id: int
accuracy: float
loss: float
created_at: Optional[datetime] = None