Загружает файлы, супер
This commit is contained in:
parent
93e2ffe021
commit
482daf29da
@ -1,30 +1,57 @@
|
|||||||
from typing import Optional
|
from datetime import datetime
|
||||||
|
|
||||||
from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker
|
from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker
|
||||||
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
|
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
|
||||||
|
from sqlalchemy import ForeignKey
|
||||||
|
|
||||||
engine = create_async_engine("sqlite+aiosqlite:///tasks.db")
|
engine = create_async_engine("sqlite+aiosqlite:///tasks.db")
|
||||||
new_session = async_sessionmaker(engine, expire_on_commit=False)
|
new_session = async_sessionmaker(engine, expire_on_commit=False)
|
||||||
|
|
||||||
|
class Base(DeclarativeBase):
|
||||||
class Model(DeclarativeBase):
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
class User(Base):
|
||||||
class TaskOrm(Model):
|
__tablename__ = "users"
|
||||||
__tablename__ = "tasks"
|
|
||||||
|
|
||||||
id: Mapped[int] = mapped_column(primary_key=True)
|
id: Mapped[int] = mapped_column(primary_key=True)
|
||||||
name: Mapped[str]
|
username: Mapped[str] = mapped_column(unique=True)
|
||||||
description: Mapped[Optional[str]]
|
password_hash: Mapped[str]
|
||||||
|
|
||||||
|
# Связи
|
||||||
|
csv_files = relationship("CSVFile", back_populates="user")
|
||||||
|
models = relationship("H5Model", back_populates="user")
|
||||||
|
|
||||||
|
class CSVFile(Base):
|
||||||
|
__tablename__ = "csv_files"
|
||||||
|
id: Mapped[int] = mapped_column(primary_key=True)
|
||||||
|
user_id: Mapped[int] = mapped_column(ForeignKey("users.id"))
|
||||||
|
file_path: Mapped[str]
|
||||||
|
uploaded_at: Mapped[datetime] = mapped_column(default=datetime.utcnow)
|
||||||
|
|
||||||
|
user = relationship("User", back_populates="csv_files")
|
||||||
|
|
||||||
|
class H5Model(Base):
|
||||||
|
__tablename__ = "h5_models"
|
||||||
|
id: Mapped[int] = mapped_column(primary_key=True)
|
||||||
|
user_id: Mapped[int] = mapped_column(ForeignKey("users.id"))
|
||||||
|
model_path: Mapped[str]
|
||||||
|
created_at: Mapped[datetime] = mapped_column(default=datetime.utcnow)
|
||||||
|
|
||||||
|
user = relationship("User", back_populates="models")
|
||||||
|
statistics = relationship("ModelStatistics", back_populates="model")
|
||||||
|
|
||||||
|
class ModelStatistics(Base):
|
||||||
|
__tablename__ = "model_statistics"
|
||||||
|
id: Mapped[int] = mapped_column(primary_key=True)
|
||||||
|
model_id: Mapped[int] = mapped_column(ForeignKey("h5_models.id"))
|
||||||
|
accuracy: Mapped[float]
|
||||||
|
loss: Mapped[float]
|
||||||
|
created_at: Mapped[datetime] = mapped_column(default=datetime.utcnow)
|
||||||
|
|
||||||
|
model = relationship("H5Model", back_populates="statistics")
|
||||||
|
|
||||||
async def create_tables():
|
async def create_tables():
|
||||||
# https://docs.sqlalchemy.org/en/20/orm/extensions/asyncio.html#synopsis-core
|
|
||||||
async with engine.begin() as conn:
|
async with engine.begin() as conn:
|
||||||
await conn.run_sync(Model.metadata.create_all)
|
await conn.run_sync(Base.metadata.create_all)
|
||||||
|
|
||||||
|
|
||||||
async def delete_tables():
|
async def delete_tables():
|
||||||
async with engine.begin() as conn:
|
async with engine.begin() as conn:
|
||||||
await conn.run_sync(Model.metadata.drop_all)
|
await conn.run_sync(Base.metadata.drop_all)
|
||||||
|
@ -1,26 +1,36 @@
|
|||||||
|
from database import new_session, User, CSVFile, H5Model, ModelStatistics
|
||||||
|
from schemas import UserCreate, CSVFileUpload, H5ModelCreate, ModelStatisticsCreate
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
|
|
||||||
from database import new_session, TaskOrm
|
class UserRepository:
|
||||||
from schemas import STaskAdd, STask
|
@staticmethod
|
||||||
|
async def create_user(data: UserCreate):
|
||||||
|
|
||||||
class TaskRepository:
|
|
||||||
@classmethod
|
|
||||||
async def add_one(cls, data: STaskAdd) -> int:
|
|
||||||
async with new_session() as session:
|
async with new_session() as session:
|
||||||
task_dict = data.model_dump()
|
user = User(username=data.username, password_hash=data.password)
|
||||||
|
session.add(user)
|
||||||
task = TaskOrm(**task_dict)
|
|
||||||
session.add(task)
|
|
||||||
await session.flush()
|
|
||||||
await session.commit()
|
await session.commit()
|
||||||
return task.id
|
return user.id
|
||||||
|
|
||||||
@classmethod
|
class CSVFileRepository:
|
||||||
async def find_all(cls) -> list[STask]:
|
@staticmethod
|
||||||
|
async def upload_file(user_id: int, file_path: str):
|
||||||
async with new_session() as session:
|
async with new_session() as session:
|
||||||
query = select(TaskOrm)
|
csv_file = CSVFile(user_id=user_id, file_path=file_path)
|
||||||
result = await session.execute(query)
|
session.add(csv_file)
|
||||||
task_models = result.scalars().all()
|
await session.commit()
|
||||||
task_schemas = [STask.model_validate(task_model) for task_model in task_models]
|
|
||||||
return task_schemas
|
class H5ModelRepository:
|
||||||
|
@staticmethod
|
||||||
|
async def add_model(user_id: int, model_path: str):
|
||||||
|
async with new_session() as session:
|
||||||
|
model = H5Model(user_id=user_id, model_path=model_path)
|
||||||
|
session.add(model)
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
class ModelStatisticsRepository:
|
||||||
|
@staticmethod
|
||||||
|
async def add_statistics(data: ModelStatisticsCreate):
|
||||||
|
async with new_session() as session:
|
||||||
|
stats = ModelStatistics(**data.model_dump())
|
||||||
|
session.add(stats)
|
||||||
|
await session.commit()
|
||||||
|
@ -1,25 +1,45 @@
|
|||||||
from typing import Annotated
|
from fastapi import APIRouter, UploadFile, File, Form
|
||||||
|
from repository import UserRepository, CSVFileRepository, H5ModelRepository, ModelStatisticsRepository
|
||||||
|
from schemas import UserCreate, ModelStatisticsCreate
|
||||||
|
import shutil
|
||||||
|
import os
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends
|
router = APIRouter()
|
||||||
|
|
||||||
from repository import TaskRepository
|
UPLOAD_FOLDER_CSV = "uploads/csv"
|
||||||
from schemas import STaskAdd, STask, STaskId
|
os.makedirs(UPLOAD_FOLDER_CSV, exist_ok=True)
|
||||||
|
|
||||||
router = APIRouter(
|
UPLOAD_FOLDER_MODELS = "uploads/models"
|
||||||
prefix="/tasks",
|
os.makedirs(UPLOAD_FOLDER_MODELS, exist_ok=True)
|
||||||
tags=["Таски"],
|
|
||||||
)
|
|
||||||
|
|
||||||
|
# Регистрация пользователя
|
||||||
|
@router.post("/users/")
|
||||||
|
async def create_user(user: UserCreate):
|
||||||
|
user_id = await UserRepository.create_user(user)
|
||||||
|
return {"user_id": user_id}
|
||||||
|
|
||||||
@router.post("")
|
# Загрузка CSV файла
|
||||||
async def add_task(
|
@router.post("/upload/csv/")
|
||||||
task: Annotated[STaskAdd, Depends()],
|
async def upload_csv(user_id: int = Form(...), file: UploadFile = File(...)):
|
||||||
) -> STaskId:
|
file_path = os.path.join(UPLOAD_FOLDER_CSV, file.filename)
|
||||||
task_id = await TaskRepository.add_one(task)
|
with open(file_path, "wb") as buffer:
|
||||||
return {"ok": True, "task_id": task_id}
|
shutil.copyfileobj(file.file, buffer)
|
||||||
|
|
||||||
|
await CSVFileRepository.upload_file(user_id=user_id, file_path=file_path)
|
||||||
|
return {"message": "CSV файл загружен", "file_path": file_path}
|
||||||
|
|
||||||
@router.get("")
|
# Загрузка H5 модели
|
||||||
async def get_tasks() -> list[STask]:
|
@router.post("/upload/h5/")
|
||||||
tasks = await TaskRepository.find_all()
|
async def upload_h5_model(user_id: int = Form(...), file: UploadFile = File(...)):
|
||||||
return tasks
|
file_path = os.path.join(UPLOAD_FOLDER_MODELS, file.filename)
|
||||||
|
with open(file_path, "wb") as buffer:
|
||||||
|
shutil.copyfileobj(file.file, buffer)
|
||||||
|
|
||||||
|
await H5ModelRepository.add_model(user_id=user_id, model_path=file_path)
|
||||||
|
return {"message": "H5 модель загружена", "file_path": file_path}
|
||||||
|
|
||||||
|
# Добавление статистики модели
|
||||||
|
@router.post("/models/statistics/")
|
||||||
|
async def add_model_statistics(stats: ModelStatisticsCreate):
|
||||||
|
await ModelStatisticsRepository.add_statistics(stats)
|
||||||
|
return {"message": "Статистика модели сохранена"}
|
||||||
|
@ -1,19 +1,23 @@
|
|||||||
|
from pydantic import BaseModel
|
||||||
|
from datetime import datetime
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from pydantic import BaseModel, ConfigDict
|
# UserCreate: для создания пользователя
|
||||||
|
class UserCreate(BaseModel):
|
||||||
|
username: str
|
||||||
|
password: str
|
||||||
|
|
||||||
|
# CSVFileUpload: для загрузки CSV файла
|
||||||
|
class CSVFileUpload(BaseModel):
|
||||||
|
file_path: str
|
||||||
|
|
||||||
class STaskAdd(BaseModel):
|
# H5ModelCreate: для добавления модели
|
||||||
name: str
|
class H5ModelCreate(BaseModel):
|
||||||
description: Optional[str] = None
|
model_path: str
|
||||||
|
|
||||||
|
# ModelStatisticsCreate: для сохранения статистики модели
|
||||||
class STask(STaskAdd):
|
class ModelStatisticsCreate(BaseModel):
|
||||||
id: int
|
model_id: int
|
||||||
|
accuracy: float
|
||||||
model_config = ConfigDict(from_attributes=True)
|
loss: float
|
||||||
|
created_at: Optional[datetime] = None
|
||||||
|
|
||||||
class STaskId(BaseModel):
|
|
||||||
ok: bool = True
|
|
||||||
task_id: int
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user