Files
AIM-PIbd-32-kuznetsov-A-V/lab12/lab12.ipynb
2025-05-16 21:09:29 +04:00

5.8 MiB
Raw Blame History

Лабораторная 12

Задача: Мультиклассовая классификация изображений на 5 категорий (daisy, dandelion, rose, sunflower, tulip)

Ссылка на датасет: https://www.kaggle.com/datasets/rahmasleam/flowers-dataset

In [10]:
import os
import cv2
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import keras

os.environ["KERAS_BACKEND"] = "torch"

print(keras.__version__)
def load_images_from_folder(folder, target_size=(512, 512)):
    images = []
    labels = []
    for label in os.listdir(folder):
        if label in ['daisy', 'dandelion', 'roses', 'sunflowers', 'tulips']:
            label_folder = os.path.join(folder, label)
            if os.path.isdir(label_folder):
                for filename in os.listdir(label_folder):
                    img_path = os.path.join(label_folder, filename)
                    img = cv2.imread(img_path)
                    if img is not None:
                        img_resized = cv2.resize(img, target_size)
                        images.append(img_resized)
                        labels.append(label)
    return images, labels

folder_path = "./static/csv/dataset_flower"
images, labels = load_images_from_folder(folder_path)
num_images_to_display = min(8, len(images))

def display_images(images, labels, max_images=10):
    if not images:
        print("Нет изображений для отображения.")
        return
    
    count = min(max_images, len(images))
    cols = 4
    rows = (count + cols - 1) // cols

    plt.figure(figsize=(15, 5 * rows))
    for i in range(count):
        plt.subplot(rows, cols, i + 1)
        plt.imshow(cv2.cvtColor(images[i], cv2.COLOR_BGR2RGB))
        plt.title(labels[i])
        plt.axis('off')
    plt.tight_layout()
    plt.show()

images, labels = load_images_from_folder(folder_path)
display_images(images, labels)


# Преобразование в массивы
images = np.array(images)  
labels = np.array(labels)
3.9.2
No description has been provided for this image

Предобработка изображений

In [11]:
import cv2
import numpy as np
import matplotlib.pyplot as plt

def preprocess_images(images):
    processed_images = []
    for img in images:
        img_resized = cv2.resize(img, (128, 128))
        img_gray = cv2.cvtColor(img_resized, cv2.COLOR_BGR2GRAY)
        img_eq = cv2.equalizeHist(img_gray)
        img_bin = cv2.adaptiveThreshold(img_eq, 255,  cv2.ADAPTIVE_THRESH_GAUSSIAN_C, 
    cv2.THRESH_BINARY, 11, 2)
        processed_images.append(img_bin)
    return np.array(processed_images)

processed_images = preprocess_images(images)

def display_single_image(original, processed, index):
    plt.figure(figsize=(10, 5))
    
    plt.subplot(1, 2, 1)
    plt.imshow(cv2.cvtColor(original[index], cv2.COLOR_BGR2RGB))
    plt.title('Оригинальное изображение')
    plt.axis('off')

    plt.subplot(1, 2, 2)
    plt.imshow(processed[index], cmap='gray')
    plt.title('Обработанное изображение')
    plt.axis('off')

    plt.show()

index = 0  
display_single_image(images, processed_images, index)
No description has been provided for this image
In [12]:
from sklearn.model_selection import train_test_split

images_resized = np.array([cv2.resize(img, (28, 28)) for img in processed_images])
if images_resized.shape[-1] == 3:  # 
    images_resized = np.dot(images_resized[...,:3], [0.2989, 0.5870, 0.1140])  # Формула из OpenCV
    images_resized = np.expand_dims(images_resized, axis=-1)  # Добавляем размерность канала
images_resized = images_resized.astype("float32") / 255.0

X_train, X_valid, y_train, y_valid = train_test_split(
    images_resized, labels, test_size=0.2, stratify=labels, random_state=42
)
In [13]:
from sklearn.preprocessing import LabelEncoder
import tensorflow as tf

n_classes = 5
le = LabelEncoder()
y_train = le.fit_transform(y_train) 
y_valid = le.transform(y_valid)

X_train = X_train.reshape(-1, 28, 28, 1).astype("float32") / 255
X_valid = X_valid.reshape(-1, 28, 28, 1).astype("float32") / 255
y_train = keras.utils.to_categorical(y_train, n_classes)
y_valid = keras.utils.to_categorical(y_valid, n_classes)

display(X_train[0])
display(y_train[0])
array([[[3.7370243e-03],
        [1.3994618e-03],
        [1.0611304e-03],
        [1.9530950e-03],
        [3.0757401e-03],
        [1.9377163e-03],
        [1.3225683e-03],
        [3.0757401e-03],
        [3.9215689e-03],
        [1.0611304e-03],
        [1.9530950e-03],
        [2.8604383e-03],
        [1.3994618e-03],
        [3.2602844e-03],
        [3.7370243e-03],
        [1.9377163e-03],
        [3.9215689e-03],
        [0.0000000e+00],
        [2.7681663e-04],
        [3.6139947e-03],
        [3.9215689e-03],
        [8.3044986e-04],
        [2.2760478e-03],
        [0.0000000e+00],
        [0.0000000e+00],
        [0.0000000e+00],
        [3.9215689e-03],
        [2.4144561e-03]],

       [[3.9215689e-03],
        [1.3994618e-03],
        [0.0000000e+00],
        [3.9215689e-03],
        [3.9215689e-03],
        [1.6147635e-03],
        [3.9215689e-03],
        [1.6455210e-03],
        [3.9215689e-03],
        [9.2272203e-05],
        [3.2141483e-03],
        [2.5221070e-03],
        [2.5221070e-03],
        [5.3825456e-04],
        [3.0757402e-04],
        [3.9215689e-03],
        [0.0000000e+00],
        [2.6605153e-03],
        [2.5221070e-03],
        [1.3994618e-03],
        [2.8143022e-03],
        [3.9215689e-03],
        [0.0000000e+00],
        [0.0000000e+00],
        [0.0000000e+00],
        [3.9215689e-03],
        [0.0000000e+00],
        [3.9215689e-03]],

       [[3.9215689e-03],
        [2.6143792e-03],
        [5.3825456e-04],
        [3.9215689e-03],
        [3.3833142e-03],
        [3.9215689e-03],
        [9.9961564e-04],
        [3.1372549e-03],
        [3.8139177e-03],
        [1.5378702e-05],
        [1.9530950e-03],
        [3.6601308e-03],
        [0.0000000e+00],
        [0.0000000e+00],
        [0.0000000e+00],
        [2.3375626e-03],
        [3.6447521e-03],
        [2.0915035e-03],
        [3.9061899e-03],
        [1.3994618e-03],
        [9.9961564e-04],
        [0.0000000e+00],
        [1.8454441e-04],
        [3.9215689e-03],
        [1.9530950e-03],
        [3.3833142e-03],
        [3.8139177e-03],
        [3.6447521e-03]],

       [[4.1522493e-04],
        [0.0000000e+00],
        [3.9215689e-03],
        [2.9373318e-03],
        [3.9215689e-03],
        [1.2610535e-03],
        [3.9215689e-03],
        [3.9215689e-03],
        [0.0000000e+00],
        [3.9215689e-03],
        [9.8423695e-04],
        [1.3840832e-04],
        [1.2610535e-03],
        [4.1522493e-04],
        [3.5063438e-03],
        [3.2141483e-03],
        [0.0000000e+00],
        [2.9373318e-03],
        [0.0000000e+00],
        [0.0000000e+00],
        [3.5063438e-03],
        [3.9215689e-03],
        [1.2610535e-03],
        [0.0000000e+00],
        [0.0000000e+00],
        [0.0000000e+00],
        [0.0000000e+00],
        [3.9215689e-03]],

       [[3.9215689e-03],
        [3.9215689e-03],
        [0.0000000e+00],
        [3.9215689e-03],
        [3.6447521e-03],
        [2.7681663e-04],
        [3.7062669e-03],
        [0.0000000e+00],
        [1.5840061e-03],
        [5.2287587e-04],
        [1.9530950e-03],
        [3.9215689e-03],
        [2.6143792e-03],
        [3.1372549e-03],
        [3.8600538e-03],
        [3.9215689e-03],
        [2.7681663e-04],
        [0.0000000e+00],
        [0.0000000e+00],
        [3.6447521e-03],
        [8.4582856e-04],
        [3.1372549e-03],
        [0.0000000e+00],
        [3.6447521e-03],
        [1.8146867e-03],
        [0.0000000e+00],
        [0.0000000e+00],
        [3.9215689e-03]],

       [[3.0757402e-04],
        [9.0734335e-04],
        [3.9215689e-03],
        [1.2610535e-03],
        [3.8139177e-03],
        [2.1222609e-03],
        [0.0000000e+00],
        [8.4582856e-04],
        [3.4140716e-03],
        [2.7681663e-04],
        [0.0000000e+00],
        [2.6143792e-03],
        [3.9215689e-03],
        [3.0757401e-03],
        [0.0000000e+00],
        [2.3068052e-03],
        [3.9215689e-03],
        [0.0000000e+00],
        [3.9215689e-03],
        [3.9215689e-03],
        [8.4582856e-04],
        [3.9215689e-03],
        [2.3068052e-03],
        [0.0000000e+00],
        [0.0000000e+00],
        [3.9215689e-03],
        [3.4140716e-03],
        [3.9215689e-03]],

       [[3.9215689e-03],
        [1.9377163e-03],
        [3.9215689e-03],
        [2.3836987e-03],
        [3.0757401e-03],
        [3.9215689e-03],
        [3.9215689e-03],
        [3.9215689e-03],
        [3.9215689e-03],
        [3.1372549e-03],
        [2.3836987e-03],
        [6.1514809e-05],
        [8.4582856e-04],
        [0.0000000e+00],
        [3.7370243e-03],
        [3.0757402e-04],
        [3.9215689e-03],
        [1.9530950e-03],
        [3.7062669e-03],
        [2.8143022e-03],
        [3.9215689e-03],
        [3.2602844e-03],
        [1.9377163e-03],
        [0.0000000e+00],
        [0.0000000e+00],
        [3.9215689e-03],
        [3.9215689e-03],
        [2.5990005e-03]],

       [[1.8454441e-04],
        [1.9838526e-03],
        [2.8604383e-03],
        [0.0000000e+00],
        [2.8604383e-03],
        [3.9215689e-03],
        [3.2602844e-03],
        [0.0000000e+00],
        [1.6455210e-03],
        [1.0611304e-03],
        [4.1522493e-04],
        [3.8600538e-03],
        [1.9377163e-03],
        [6.6128414e-04],
        [0.0000000e+00],
        [3.9215689e-03],
        [3.9215689e-03],
        [0.0000000e+00],
        [2.7681663e-04],
        [1.3994618e-03],
        [3.9215689e-03],
        [8.4582856e-04],
        [1.1072665e-03],
        [9.9961564e-04],
        [1.9530950e-03],
        [2.9219531e-03],
        [3.3833142e-03],
        [3.7370243e-03]],

       [[2.8143022e-03],
        [3.0142253e-03],
        [9.2272203e-05],
        [3.9215689e-03],
        [3.9215689e-03],
        [4.9211847e-04],
        [3.6139947e-03],
        [0.0000000e+00],
        [3.9215689e-03],
        [0.0000000e+00],
        [7.0742023e-04],
        [1.3071896e-03],
        [3.9215689e-03],
        [3.6139947e-03],
        [3.0757401e-03],
        [2.5221070e-03],
        [0.0000000e+00],
        [0.0000000e+00],
        [1.5840061e-03],
        [4.9211847e-04],
        [2.5221070e-03],
        [5.3825456e-04],
        [1.3994618e-03],
        [0.0000000e+00],
        [2.6605153e-03],
        [3.9215689e-03],
        [3.9215689e-03],
        [8.4582856e-04]],

       [[2.1530181e-04],
        [2.7681663e-04],
        [0.0000000e+00],
        [0.0000000e+00],
        [3.6447521e-03],
        [3.9215689e-03],
        [8.4582856e-04],
        [2.9219531e-03],
        [2.6143792e-03],
        [3.6601308e-03],
        [3.9215689e-03],
        [0.0000000e+00],
        [1.8454441e-04],
        [3.7062669e-03],
        [3.9215689e-03],
        [3.8139177e-03],
        [0.0000000e+00],
        [0.0000000e+00],
        [2.6143793e-04],
        [3.9215689e-03],
        [7.8431371e-04],
        [7.8431371e-04],
        [1.3994618e-03],
        [0.0000000e+00],
        [1.9530950e-03],
        [3.9215689e-03],
        [3.9215689e-03],
        [3.1372549e-03]],

       [[3.9215689e-03],
        [3.9215689e-03],
        [3.7831604e-03],
        [3.9215689e-03],
        [1.9684739e-03],
        [1.9684739e-03],
        [1.5378700e-03],
        [3.5063438e-03],
        [1.2610535e-03],
        [0.0000000e+00],
        [0.0000000e+00],
        [1.3840832e-04],
        [3.2141483e-03],
        [0.0000000e+00],
        [3.9215689e-03],
        [1.2610535e-03],
        [1.3840832e-04],
        [2.9373318e-03],
        [2.1068822e-03],
        [3.9215689e-03],
        [3.9215689e-03],
        [3.9215689e-03],
        [7.0742023e-04],
        [3.6447521e-03],
        [1.9684739e-03],
        [3.9215689e-03],
        [1.2610535e-03],
        [3.9215689e-03]],

       [[3.7062669e-03],
        [3.9215689e-03],
        [3.9215689e-03],
        [1.8146867e-03],
        [3.9215689e-03],
        [3.6447521e-03],
        [2.8604383e-03],
        [0.0000000e+00],
        [2.7681663e-04],
        [3.9215689e-03],
        [3.7831604e-03],
        [3.6447521e-03],
        [2.5221070e-03],
        [8.4582856e-04],
        [3.9215689e-03],
        [0.0000000e+00],
        [3.9215689e-03],
        [3.9215689e-03],
        [2.6143793e-04],
        [1.8454441e-04],
        [8.4582856e-04],
        [3.7062669e-03],
        [1.3071896e-03],
        [3.6601308e-03],
        [1.9530950e-03],
        [3.6447521e-03],
        [1.8454441e-04],
        [2.8604383e-03]],

       [[3.9215689e-03],
        [3.4140716e-03],
        [3.9215689e-03],
        [1.3994618e-03],
        [0.0000000e+00],
        [2.5221070e-03],
        [2.8143022e-03],
        [3.0757401e-03],
        [1.6147635e-03],
        [2.7681663e-04],
        [7.0742023e-04],
        [1.8454441e-04],
        [0.0000000e+00],
        [0.0000000e+00],
        [3.6139947e-03],
        [1.6147635e-03],
        [3.9215689e-03],
        [3.9215689e-03],
        [1.8454441e-04],
        [0.0000000e+00],
        [1.3994618e-03],
        [0.0000000e+00],
        [3.9215689e-03],
        [0.0000000e+00],
        [1.9530950e-03],
        [2.6143792e-03],
        [4.9211847e-04],
        [3.9215689e-03]],

       [[3.2602844e-03],
        [1.9838526e-03],
        [0.0000000e+00],
        [3.9215689e-03],
        [3.6447521e-03],
        [2.8143022e-03],
        [0.0000000e+00],
        [8.4582856e-04],
        [3.0757401e-03],
        [2.1530181e-04],
        [2.3836987e-03],
        [0.0000000e+00],
        [1.9838526e-03],
        [6.6128414e-04],
        [6.6128414e-04],
        [3.3833142e-03],
        [2.8604383e-03],
        [3.9215689e-03],
        [2.7681663e-04],
        [3.9215689e-03],
        [3.2602844e-03],
        [3.7370243e-03],
        [2.5221070e-03],
        [3.0757401e-03],
        [3.9215689e-03],
        [7.8431371e-04],
        [3.9215689e-03],
        [0.0000000e+00]],

       [[0.0000000e+00],
        [3.3833142e-03],
        [3.9215689e-03],
        [2.3836987e-03],
        [3.9215689e-03],
        [1.9377163e-03],
        [3.9215689e-03],
        [2.5990005e-03],
        [5.3825456e-04],
        [6.1514809e-05],
        [2.3836987e-03],
        [2.7681663e-04],
        [3.0757402e-04],
        [3.9215689e-03],
        [0.0000000e+00],
        [2.5221070e-03],
        [0.0000000e+00],
        [0.0000000e+00],
        [6.1514809e-05],
        [8.4582856e-04],
        [3.9215689e-03],
        [3.9215689e-03],
        [3.0757402e-04],
        [0.0000000e+00],
        [1.5378700e-03],
        [0.0000000e+00],
        [3.6139947e-03],
        [3.0757401e-03]],

       [[1.1072665e-03],
        [9.0734335e-04],
        [3.9215689e-03],
        [0.0000000e+00],
        [2.7681663e-04],
        [2.3068052e-03],
        [3.0757401e-03],
        [0.0000000e+00],
        [0.0000000e+00],
        [3.9215689e-03],
        [7.0742023e-04],
        [9.2272203e-05],
        [3.9215689e-03],
        [0.0000000e+00],
        [1.9838526e-03],
        [3.9215689e-03],
        [3.8139177e-03],
        [3.9215689e-03],
        [2.7681663e-04],
        [3.9215689e-03],
        [5.3825456e-04],
        [1.9377163e-03],
        [2.5221070e-03],
        [1.5840061e-03],
        [1.9530950e-03],
        [0.0000000e+00],
        [2.5221070e-03],
        [1.9838526e-03]],

       [[3.0757401e-03],
        [2.7681663e-04],
        [3.9215689e-03],
        [0.0000000e+00],
        [3.9215689e-03],
        [1.5840061e-03],
        [2.8604383e-03],
        [3.8600538e-03],
        [0.0000000e+00],
        [2.6143793e-04],
        [2.7681663e-04],
        [2.7681663e-04],
        [0.0000000e+00],
        [0.0000000e+00],
        [3.9215689e-03],
        [0.0000000e+00],
        [3.9061899e-03],
        [3.9215689e-03],
        [3.9215689e-03],
        [0.0000000e+00],
        [3.9215689e-03],
        [3.1372549e-03],
        [3.7370243e-03],
        [1.5378702e-05],
        [3.7831604e-03],
        [2.6143793e-04],
        [3.9215689e-03],
        [0.0000000e+00]],

       [[0.0000000e+00],
        [1.9684739e-03],
        [3.7831604e-03],
        [3.9215689e-03],
        [2.1068822e-03],
        [0.0000000e+00],
        [1.9684739e-03],
        [0.0000000e+00],
        [1.3994618e-03],
        [2.7681663e-04],
        [2.9373318e-03],
        [1.8146867e-03],
        [0.0000000e+00],
        [3.9215689e-03],
        [3.5063438e-03],
        [1.9684739e-03],
        [0.0000000e+00],
        [0.0000000e+00],
        [3.9215689e-03],
        [2.6605153e-03],
        [0.0000000e+00],
        [4.1522493e-04],
        [0.0000000e+00],
        [2.1068822e-03],
        [3.9215689e-03],
        [2.7681663e-04],
        [2.5221070e-03],
        [1.5378700e-03]],

       [[2.1530181e-04],
        [0.0000000e+00],
        [3.9215689e-03],
        [0.0000000e+00],
        [0.0000000e+00],
        [1.3994618e-03],
        [3.0757401e-03],
        [3.9215689e-03],
        [2.3375626e-03],
        [2.7681663e-04],
        [0.0000000e+00],
        [2.7681663e-04],
        [1.3994618e-03],
        [2.7681663e-04],
        [3.1372549e-03],
        [0.0000000e+00],
        [0.0000000e+00],
        [3.7831604e-03],
        [0.0000000e+00],
        [3.7370243e-03],
        [3.6447521e-03],
        [2.1530181e-04],
        [1.5840061e-03],
        [3.6601308e-03],
        [3.7831604e-03],
        [2.6143793e-04],
        [2.6143792e-03],
        [8.4582856e-04]],

       [[3.9215689e-03],
        [3.0142253e-03],
        [0.0000000e+00],
        [1.9530950e-03],
        [2.5221070e-03],
        [3.0142253e-03],
        [3.0757402e-04],
        [0.0000000e+00],
        [0.0000000e+00],
        [1.3994618e-03],
        [3.9215689e-03],
        [3.9215689e-03],
        [1.6147635e-03],
        [3.9215689e-03],
        [2.2760478e-03],
        [3.9215689e-03],
        [3.7370243e-03],
        [7.0742023e-04],
        [2.6143792e-03],
        [2.3068052e-03],
        [1.9838526e-03],
        [3.9215689e-03],
        [3.9215689e-03],
        [2.7681663e-04],
        [1.2610535e-03],
        [1.3071896e-03],
        [3.9215689e-03],
        [3.3833142e-03]],

       [[8.3044986e-04],
        [0.0000000e+00],
        [8.4582856e-04],
        [8.4582856e-04],
        [9.9961564e-04],
        [0.0000000e+00],
        [3.7370243e-03],
        [3.9215689e-03],
        [3.3833142e-03],
        [0.0000000e+00],
        [3.9215689e-03],
        [7.8431371e-04],
        [5.3825456e-04],
        [3.7370243e-03],
        [3.9215689e-03],
        [3.3833142e-03],
        [6.1514809e-05],
        [2.3836987e-03],
        [1.0611304e-03],
        [3.0757401e-03],
        [0.0000000e+00],
        [0.0000000e+00],
        [0.0000000e+00],
        [7.8431371e-04],
        [3.9215689e-03],
        [2.1530181e-04],
        [5.3825456e-04],
        [0.0000000e+00]],

       [[6.6128414e-04],
        [0.0000000e+00],
        [7.8431371e-04],
        [2.3836987e-03],
        [3.9215689e-03],
        [5.3825456e-04],
        [2.4144561e-03],
        [8.3044986e-04],
        [0.0000000e+00],
        [0.0000000e+00],
        [3.9215689e-03],
        [2.7681663e-04],
        [3.9215689e-03],
        [3.2602844e-03],
        [1.8454441e-04],
        [3.9215689e-03],
        [2.7681663e-04],
        [3.9215689e-03],
        [2.8604383e-03],
        [3.0757402e-04],
        [3.9215689e-03],
        [3.9215689e-03],
        [5.3825456e-04],
        [3.9215689e-03],
        [0.0000000e+00],
        [2.8604383e-03],
        [1.9838526e-03],
        [3.2602844e-03]],

       [[3.9215689e-03],
        [3.9215689e-03],
        [2.7681663e-04],
        [3.9215689e-03],
        [1.5840061e-03],
        [3.0142253e-03],
        [3.9215689e-03],
        [3.3833142e-03],
        [9.0734335e-04],
        [2.3375626e-03],
        [2.5221070e-03],
        [1.3994618e-03],
        [1.7993080e-03],
        [3.0757402e-04],
        [0.0000000e+00],
        [3.9215689e-03],
        [3.9215689e-03],
        [3.9215689e-03],
        [0.0000000e+00],
        [0.0000000e+00],
        [3.0757402e-04],
        [3.9215689e-03],
        [9.0734335e-04],
        [2.3375626e-03],
        [2.5221070e-03],
        [0.0000000e+00],
        [3.9215689e-03],
        [3.9215689e-03]],

       [[3.7062669e-03],
        [0.0000000e+00],
        [3.6601308e-03],
        [3.9215689e-03],
        [2.7681663e-04],
        [1.4763552e-03],
        [3.8600538e-03],
        [1.0611304e-03],
        [3.9215689e-03],
        [2.7681663e-04],
        [0.0000000e+00],
        [1.5378702e-05],
        [3.7370243e-03],
        [3.0757401e-03],
        [0.0000000e+00],
        [3.8139177e-03],
        [2.7681663e-04],
        [0.0000000e+00],
        [3.3833142e-03],
        [3.9215689e-03],
        [0.0000000e+00],
        [3.9215689e-03],
        [3.9215689e-03],
        [1.5378702e-05],
        [1.9530950e-03],
        [1.5378702e-05],
        [1.3994618e-03],
        [3.9215689e-03]],

       [[1.9684739e-03],
        [1.9684739e-03],
        [0.0000000e+00],
        [0.0000000e+00],
        [1.9684739e-03],
        [3.9215689e-03],
        [1.5378700e-03],
        [1.9684739e-03],
        [1.2610535e-03],
        [1.3840832e-04],
        [3.9215689e-03],
        [1.8146867e-03],
        [7.0742023e-04],
        [1.9684739e-03],
        [0.0000000e+00],
        [3.2141483e-03],
        [3.6447521e-03],
        [1.9684739e-03],
        [1.8146867e-03],
        [1.9530950e-03],
        [0.0000000e+00],
        [2.3836987e-03],
        [1.3994618e-03],
        [3.6447521e-03],
        [9.8423695e-04],
        [3.9215689e-03],
        [0.0000000e+00],
        [2.3836987e-03]],

       [[9.9961564e-04],
        [3.9215689e-03],
        [1.5378702e-05],
        [1.8146867e-03],
        [0.0000000e+00],
        [2.5221070e-03],
        [2.8604383e-03],
        [2.9219531e-03],
        [1.3071896e-03],
        [0.0000000e+00],
        [0.0000000e+00],
        [2.6143793e-04],
        [2.5221070e-03],
        [3.6447521e-03],
        [1.0611304e-03],
        [3.9215689e-03],
        [0.0000000e+00],
        [1.9530950e-03],
        [3.3833142e-03],
        [1.3994618e-03],
        [2.1530181e-04],
        [3.9215689e-03],
        [2.6143792e-03],
        [0.0000000e+00],
        [3.7831604e-03],
        [3.3833142e-03],
        [3.9215689e-03],
        [2.9219531e-03]],

       [[3.0757402e-04],
        [0.0000000e+00],
        [2.7681663e-04],
        [1.9530950e-03],
        [9.2272203e-05],
        [1.3994618e-03],
        [2.8143022e-03],
        [0.0000000e+00],
        [4.9211847e-04],
        [0.0000000e+00],
        [3.9215689e-03],
        [2.7681663e-04],
        [1.3994618e-03],
        [3.6139947e-03],
        [3.0757402e-04],
        [2.5221070e-03],
        [3.9215689e-03],
        [1.9530950e-03],
        [2.3375626e-03],
        [3.4140716e-03],
        [1.9377163e-03],
        [1.6455210e-03],
        [1.6147635e-03],
        [2.3375626e-03],
        [3.9215689e-03],
        [0.0000000e+00],
        [3.4140716e-03],
        [2.8143022e-03]],

       [[0.0000000e+00],
        [3.0757401e-03],
        [8.4582856e-04],
        [0.0000000e+00],
        [2.1530181e-04],
        [1.3994618e-03],
        [3.9215689e-03],
        [1.8454441e-04],
        [3.0757402e-04],
        [6.1514809e-05],
        [0.0000000e+00],
        [3.0757401e-03],
        [2.8143022e-03],
        [6.6128414e-04],
        [6.6128414e-04],
        [3.0757401e-03],
        [7.8431371e-04],
        [1.9530950e-03],
        [3.6447521e-03],
        [2.8143022e-03],
        [6.6128414e-04],
        [3.0757401e-03],
        [0.0000000e+00],
        [3.9215689e-03],
        [1.5378700e-03],
        [1.0611304e-03],
        [3.9215689e-03],
        [2.4144561e-03]]], dtype=float32)
array([1., 0., 0., 0., 0.])

Проектирование архитектуры LeNet-5

In [14]:
from keras.api.models import Sequential
from keras.api.layers import InputLayer, Conv2D, MaxPooling2D, Dropout, Flatten, Dense

lenet_model = Sequential()

# Входной слой
lenet_model.add(InputLayer(shape=(28, 28, 1)))

# Первый скрытый слой
lenet_model.add(Conv2D(32, kernel_size=(3, 3), activation="relu"))

# Второй скрытый слой
lenet_model.add(Conv2D(64, kernel_size=(3, 3), activation="relu"))
lenet_model.add(MaxPooling2D(pool_size=(2, 2)))


# Третий скрытый слой
lenet_model.add(Flatten())
lenet_model.add(Dense(128, activation="relu"))
lenet_model.add(Dropout(0.5))

# Выходной слой
lenet_model.add(Dense(n_classes, activation="softmax"))

lenet_model.summary()
Model: "sequential_1"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓
┃ Layer (type)                     Output Shape                  Param # ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩
│ conv2d_2 (Conv2D)               │ (None, 26, 26, 32)     │           320 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ conv2d_3 (Conv2D)               │ (None, 24, 24, 64)     │        18,496 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ max_pooling2d_1 (MaxPooling2D)  │ (None, 12, 12, 64)     │             0 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ flatten_1 (Flatten)             │ (None, 9216)           │             0 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ dense_2 (Dense)                 │ (None, 128)            │     1,179,776 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ dropout_1 (Dropout)             │ (None, 128)            │             0 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ dense_3 (Dense)                 │ (None, 5)              │           645 │
└─────────────────────────────────┴────────────────────────┴───────────────┘
 Total params: 1,199,237 (4.57 MB)
 Trainable params: 1,199,237 (4.57 MB)
 Non-trainable params: 0 (0.00 B)

Обучение глубокой модели

In [15]:
lenet_model.compile(
    loss="categorical_crossentropy",
    optimizer="adam",
    metrics=["accuracy"],
)

lenet_model.fit(
    X_train,
    y_train,
    batch_size=128,
    epochs=10,
    validation_data=(X_valid, y_valid),
)
Epoch 1/10
8/8 ━━━━━━━━━━━━━━━━━━━━ 2s 118ms/step - accuracy: 0.5089 - loss: 1.4596 - val_accuracy: 0.5652 - val_loss: 0.7946
Epoch 2/10
8/8 ━━━━━━━━━━━━━━━━━━━━ 1s 91ms/step - accuracy: 0.5371 - loss: 0.7967 - val_accuracy: 0.4348 - val_loss: 0.7080
Epoch 3/10
8/8 ━━━━━━━━━━━━━━━━━━━━ 1s 90ms/step - accuracy: 0.5210 - loss: 0.7743 - val_accuracy: 0.5652 - val_loss: 0.6919
Epoch 4/10
8/8 ━━━━━━━━━━━━━━━━━━━━ 1s 89ms/step - accuracy: 0.5461 - loss: 0.7277 - val_accuracy: 0.5652 - val_loss: 0.6983
Epoch 5/10
8/8 ━━━━━━━━━━━━━━━━━━━━ 1s 86ms/step - accuracy: 0.5538 - loss: 0.7081 - val_accuracy: 0.5652 - val_loss: 0.6918
Epoch 6/10
8/8 ━━━━━━━━━━━━━━━━━━━━ 1s 86ms/step - accuracy: 0.5112 - loss: 0.7309 - val_accuracy: 0.5652 - val_loss: 0.6887
Epoch 7/10
8/8 ━━━━━━━━━━━━━━━━━━━━ 1s 90ms/step - accuracy: 0.5446 - loss: 0.7038 - val_accuracy: 0.5652 - val_loss: 0.6871
Epoch 8/10
8/8 ━━━━━━━━━━━━━━━━━━━━ 1s 102ms/step - accuracy: 0.5310 - loss: 0.7137 - val_accuracy: 0.5652 - val_loss: 0.6866
Epoch 9/10
8/8 ━━━━━━━━━━━━━━━━━━━━ 1s 89ms/step - accuracy: 0.5198 - loss: 0.7142 - val_accuracy: 0.5652 - val_loss: 0.6878
Epoch 10/10
8/8 ━━━━━━━━━━━━━━━━━━━━ 1s 90ms/step - accuracy: 0.5262 - loss: 0.7058 - val_accuracy: 0.5652 - val_loss: 0.6872
Out[15]:
<keras.src.callbacks.history.History at 0x29089f104a0>

Оценка качества модели

In [16]:
lenet_model.evaluate(X_valid, y_valid)
8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 9ms/step - accuracy: 0.6047 - loss: 0.6794 
Out[16]:
[0.6871750354766846, 0.5652173757553101]

Классификация текстов

In [17]:
import pandas as pd
import os
import win32com.client as win32
from tqdm import tqdm

def read_doc(file_path):
    try:
        if not os.path.exists(file_path):
            raise FileNotFoundError(f"Файл {file_path} не существует")
        if not file_path.lower().endswith('.doc'):
            raise ValueError(f"Неверное расширение файла: {file_path}")
        word = win32.Dispatch("Word.Application")
        word.Visible = False  
        try:
            doc = word.Documents.Open(os.path.abspath(file_path))
            text = doc.Content.Text
            doc.Close(False)  
            return text.strip()
        except Exception as e:
            print(f"Ошибка при чтении {file_path}: {e}")
            return None
        finally:
            word.Quit()
    except Exception as e:
        print(f"Критическая ошибка для {file_path}: {e}")
        return None

def load_docs(dataset_path):
    dataset_path = os.path.abspath(dataset_path)
    if not os.path.exists(dataset_path):
        raise FileNotFoundError(f"Директория {dataset_path} не существует")
    df = pd.DataFrame(columns=["doc", "text"])
    files = [f for f in os.listdir(dataset_path) 
             if f.lower().endswith('.doc') and not f.startswith('~$')]
    for file_name in tqdm(files, desc="Обработка документов"):
        full_path = os.path.join(dataset_path, file_name)
        text = read_doc(full_path)
        if text:  
            df.loc[len(df)] = {
                "doc": file_name,
                "text": text
            }
    return df
try:
    base_path = ".//static//csv//wsw"
    df = load_docs(base_path)

    df["type"] = df["doc"].str.lower().str.startswith("tz_").astype(int)
    df.sort_values("doc", inplace=True)

    display(df.head())
    print(f"Успешно обработано {len(df)} из {len(files)} файлов")

except Exception as e:
    print(f"Ошибка в основном потоке: {e}")
Обработка документов: 100%|██████████| 41/41 [03:37<00:00,  5.30s/it]
<style scoped=""> .dataframe tbody tr th:only-of-type { vertical-align: middle; } .dataframe tbody tr th { vertical-align: top; } .dataframe thead th { text-align: right; } </style>
doc text type
0 tz_01.doc 2.2 ТЕХНИЧЕСКОЕ ЗАДАНИЕ\r2.2.1 Общие сведения\... 1
1 tz_02.doc 2.2 ТЕХНИЧЕСКОЕ ЗАДАНИЕ\r2.2.1 Общие сведения\... 1
2 tz_03.doc 2.2. ТЕХНИЧЕСКОЕ ЗАДАНИЕ\rОбщие сведения:\rВ д... 1
3 tz_04.doc ТЕХНИЧЕСКОЕ ЗАДАНИЕ\r2.2.1 ОБЩИЕ СВЕДЕНИЯ\rИнт... 1
4 tz_05.doc 2.2 ТЕХНИЧЕСКОЕ ЗАДАНИЕ\r2.2.1 Общие сведения.... 1
Ошибка в основном потоке: name 'files' is not defined

Предобработка текста

In [18]:
from gensim.models.phrases import Phraser, Phrases
import spacy

sp = spacy.load("ru_core_news_lg")

def prep_text(text):
    doc = sp(text)
    lower_sents = []
    for sent in doc.sents:
        lower_sents.append([word.lemma_.lower() for word in sent if not word.is_punct and not word.is_stop and not word.is_space])
    lower_bigram = Phraser(Phrases(lower_sents))
    clean_sents = []
    for sent in lower_sents:
        clean_sents.append(lower_bigram[sent])
    return clean_sents

df["prep_text"] = df.apply(lambda row: prep_text(row["text"]), axis=1)
df
Out[18]:
<style scoped=""> .dataframe tbody tr th:only-of-type { vertical-align: middle; } .dataframe tbody tr th { vertical-align: top; } .dataframe thead th { text-align: right; } </style>
doc text type prep_text
0 tz_01.doc 2.2 ТЕХНИЧЕСКОЕ ЗАДАНИЕ\r2.2.1 Общие сведения\... 1 [[2.2, техническое, задание, 2.2.1, общий, све...
1 tz_02.doc 2.2 ТЕХНИЧЕСКОЕ ЗАДАНИЕ\r2.2.1 Общие сведения\... 1 [[2.2, техническое, задание, 2.2.1, общий, све...
2 tz_03.doc 2.2. ТЕХНИЧЕСКОЕ ЗАДАНИЕ\rОбщие сведения:\rВ д... 1 [[2.2], [техническое, задание, общий, сведение...
3 tz_04.doc ТЕХНИЧЕСКОЕ ЗАДАНИЕ\r2.2.1 ОБЩИЕ СВЕДЕНИЯ\rИнт... 1 [[техническое, задание, 2.2.1, общие, сведение...
4 tz_05.doc 2.2 ТЕХНИЧЕСКОЕ ЗАДАНИЕ\r2.2.1 Общие сведения.... 1 [[2.2, техническое, задание, 2.2.1, общий, све...
5 tz_06.doc 2.2 ТЕХНИЧЕСКОЕ ЗАДАНИЕ\t\r1.Общие сведения\rП... 1 [[2.2, техническое, задание, 1.общие, сведение...
6 tz_07.doc ТЕХНИЧЕСКОЕ ЗАДАНИЕ\rОбщие сведения\rВ данном ... 1 [[техническое, задание, общий, сведение, разде...
7 tz_08.doc Техническое задание\r1 Общие сведения\r1.1 Пол... 1 [[технический, задание, 1, общий, сведение, 1....
8 tz_09.doc 2.2. ТЕХНИЧЕСКОЕ ЗАДАНИЕ\r2.2.1.\rОбщие сведен... 1 [[2.2], [техническое, задание, 2.2.1], [общий,...
9 tz_10.doc 2.2. ТЕХНИЧЕСКОЕ ЗАДАНИЕ\r2.2.1. Общие сведени... 1 [[2.2], [техническое, задание, 2.2.1], [общий,...
10 tz_11.doc 2.2.\tТЕХНИЧЕСКОЕ ЗАДАНИЕ\r2.2.1.\tОбщие сведе... 1 [[2.2], [техническое, задание, 2.2.1], [общий,...
11 tz_12.doc 2.2\tТЕХНИЧЕСКОЕ ЗАДАНИЕ\r2.2.1\tОбщие сведени... 1 [[2.2, техническое, задание, 2.2.1, общий, све...
12 tz_13.doc 2.2. ТЕХНИЧЕСКОЕ ЗАДАНИЕ\r2.2.1. Общие сведени... 1 [[2.2], [техническое, задание, 2.2.1], [общий,...
13 tz_14.doc ТЕХНИЧЕСКОЕ ЗАДАНИЕ\rРассмотрев общие требован... 1 [[техническое, задание, рассмотреть, общий, тр...
14 tz_15.doc ТЕХНИЧЕСКОЕ ЗАДАНИЕ\r2.2.1. Общие сведения\rПо... 1 [[техническое, задание, 2.2.1], [общий, сведен...
15 tz_16.doc 2.2\tТехническое задание\r2.2.1\tОбщие сведени... 1 [[2.2, технический, задание, 2.2.1, общий, све...
16 tz_17.doc 2.2 ТЕХНИЧЕСКОЕ ЗАДАНИЕ.\r2.2.1 Общие сведения... 1 [[2.2, техническое, задание], [2.2.1, общий, с...
17 tz_18.doc 2.2. Техническое задание\rОбщие сведения:\rПол... 1 [[2.2], [технический, задание, общий, сведение...
18 tz_19.doc 2.2. ТЕХНИЧЕСКОЕ ЗАДАНИЕ\r2.2.1. Наименование ... 1 [[2.2], [техническое, задание, 2.2.1], [наимен...
19 tz_20.doc 2.2. ТЕХНИЧЕСКОЕ ЗАДАНИЕ\r2.2.1. Общие сведени... 1 [[2.2], [техническое, задание, 2.2.1], [общий,...
20 Архитектура, управляемая модель.doc Архитектура, управляемая модель\rАббревиатура ... 0 [[архитектура, управляемая, модель, аббревиату...
21 Введение в проектирование ИС.doc 1. ВВЕДЕНИЕ В ПРОЕКТИРОВАНИЕ ИНФОРМАЦИОННЫХ СИ... 0 [[1], [введение, проектирование, информационны...
22 Встроенные операторы SQL.doc Встроенные операторы SQL. \rКак было отмечено ... 0 [[встроить, оператор, sql], [отметить, sql, st...
23 Методологии разработки программного обеспечени... Методологии разработки программного обеспечени... 0 [[методология, разработка, программный_обеспеч...
24 Методологии разработки программного обеспечени... Методологии разработки программного обеспечени... 0 [[методология, разработка, программный, обеспе...
25 Методы композиции и декомпозиции.doc Методы композиции и декомпозиции исполняемых U... 0 [[метод, композиция_декомпозиция, исполнять, u...
26 Модели представления данных в СУБД.doc 2.3.1. Исследование моделей информационного пр... 0 [[2.3.1], [исследование, модель, информационны...
27 Некоторые особенности проектирования.doc Некоторые особенности проектирования под конкр... 0 [[особенность, проектирование, конкретный, арх...
28 Непроцедурный доступ к данным.doc 2.3.2.3 Непроцедурный доступ к данным (SQL).\r... 0 [[2.3.2.3, непроцедурный, доступ, данным, sql]...
29 Процедурное расширение языка SQL.doc Процедурное расширение языка SQL - PL/SQL.\rOr... 0 [[процедурный, расширение, язык, sql, pl_sql],...
30 Системные объекты базы данных.doc Системные объекты базы данных.\rСловарь данных... 0 [[системный, объект, база], [словарь], [первый...
31 Технология создания распр ИС.doc 2. ТЕХНОЛОГИИ СОЗДАНИЯ РАСПРЕДЕЛЕННЫХ ИНФОРМАЦ... 0 [[2], [технологии, создания, распределенных, и...
32 Требования к проекту.doc Требования к проекту\rВведение\rОтносительно с... 0 [[требование, проект, введение, относительно, ...
33 Условия целостности БД.doc 2.1.1.3.1. Ограничительные условия, поддержива... 0 [[2.1.1.3.1], [ограничительный, условие, подде...
34 Характеристики СУБД.doc 2.2.2 Сравнительные характеристики SQL СУБД.\r... 0 [[2.2.2, сравнительные, характеристика, sql, с...
35 Этапы разработки проекта1.doc Этапы разработки проекта: заключительные стади... 0 [[этап, разработка, проект, заключительные, ст...
36 Этапы разработки проекта2.doc Этапы разработки проекта: заключительные стади... 0 [[этап, разработка, проект, заключительные, ст...
37 Этапы разработки проекта3.doc Этапы разработки проекта: определение стратеги... 0 [[этап, разработка, проект, определение, страт...
38 Этапы разработки проекта4.doc Этапы разработки проекта: реализация, тестиров... 0 [[этап_разработка, проект, реализация, тестиро...
39 Этапы разработки проекта5.doc Этапы разработки проекта: стратегия и анализ\r... 0 [[этап, разработка_проект, стратегия, анализ, ...
40 Язык манипуляции данными.doc 2.1.3. Язык манипуляции данными (ЯМД)\rЯзык ма... 0 [[2.1.3], [язык, манипуляция, данными, ямд, яз...

Инициализация Keras

In [19]:
import os

os.environ["KERAS_BACKEND"] = "jax"
import keras

print(keras.__version__)
3.9.2

Загрузка данных для классификации с помощью глубоких сетей

In [20]:
from keras.api.datasets import imdb
import os
from sklearn.model_selection import train_test_split
from tensorflow.keras.preprocessing.text import Tokenizer # type: ignore
from keras.api.preprocessing.sequence import pad_sequences
import numpy as np

unique_words = 10000
max_length = 200
X_texts = [' '.join([word for sent in doc for word in sent]) for doc in df['prep_text']]
tokenizer = Tokenizer(num_words=unique_words)
tokenizer.fit_on_texts(X_texts)
sequences = tokenizer.texts_to_sequences(X_texts)
X_padded = pad_sequences(sequences, maxlen=max_length, padding='post', truncating='post')



output_dir = "tmp"
if not os.path.exists(output_dir):
    os.makedirs(output_dir)


X_train, X_valid, y_train, y_valid = train_test_split(X_padded, df["type"].values, test_size=0.2, random_state=42)

Приведение отзывов к длине max_length (100)`

In [21]:
from keras.api.preprocessing.sequence import pad_sequences
from scipy.sparse import issparse

X_train = pad_sequences(X_train, maxlen=max_length, padding='pre', truncating='pre', value=0)
X_valid = pad_sequences(X_valid, maxlen=max_length, padding="pre", truncating="pre", value=0)

Формирование архитектуры глубокой рекуррентной сети

In [22]:
from keras.api.models import Sequential
from keras.api.layers import InputLayer, Embedding, SpatialDropout1D, SimpleRNN, Dense

rnn_model = Sequential()
rnn_model.add(InputLayer(shape=(max_length,), dtype="float32"))
rnn_model.add(Embedding(unique_words, 64))
rnn_model.add(SpatialDropout1D(0.2))
rnn_model.add(SimpleRNN(256, dropout=0.2))
rnn_model.add(Dense(1, activation="sigmoid"))

rnn_model.summary()
Model: "sequential_2"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓
┃ Layer (type)                     Output Shape                  Param # ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩
│ embedding (Embedding)           │ (None, 200, 64)        │       640,000 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ spatial_dropout1d               │ (None, 200, 64)        │             0 │
│ (SpatialDropout1D)              │                        │               │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ simple_rnn (SimpleRNN)          │ (None, 256)            │        82,176 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ dense_4 (Dense)                 │ (None, 1)              │           257 │
└─────────────────────────────────┴────────────────────────┴───────────────┘
 Total params: 722,433 (2.76 MB)
 Trainable params: 722,433 (2.76 MB)
 Non-trainable params: 0 (0.00 B)

Обучение модели

In [23]:
from keras.api.callbacks import ModelCheckpoint

rnn_model.compile(
    loss="binary_crossentropy",
    optimizer="adam",
    metrics=["accuracy"],
)

rnn_model.fit(
    X_train,
    y_train,
    batch_size=128,
    epochs=16,
    validation_data=(X_valid, y_valid),
    callbacks=[ModelCheckpoint(filepath=output_dir + "/rnn_weights.{epoch:02d}.keras")],
)
Epoch 1/16
1/1 ━━━━━━━━━━━━━━━━━━━━ 3s 3s/step - accuracy: 0.5312 - loss: 0.6881 - val_accuracy: 0.5556 - val_loss: 0.6938
Epoch 2/16
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 282ms/step - accuracy: 0.9688 - loss: 0.4678 - val_accuracy: 0.8889 - val_loss: 0.3995
Epoch 3/16
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 482ms/step - accuracy: 0.7188 - loss: 0.5565 - val_accuracy: 0.4444 - val_loss: 0.7295
Epoch 4/16
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 275ms/step - accuracy: 0.7188 - loss: 0.5845 - val_accuracy: 0.5556 - val_loss: 0.7028
Epoch 5/16
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 283ms/step - accuracy: 0.8438 - loss: 0.5840 - val_accuracy: 0.5556 - val_loss: 0.7054
Epoch 6/16
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 438ms/step - accuracy: 0.6562 - loss: 0.5577 - val_accuracy: 0.5556 - val_loss: 0.7249
Epoch 7/16
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 374ms/step - accuracy: 0.5625 - loss: 0.5709 - val_accuracy: 0.4444 - val_loss: 0.8666
Epoch 8/16
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 406ms/step - accuracy: 0.6875 - loss: 0.4976 - val_accuracy: 0.2222 - val_loss: 0.9157
Epoch 9/16
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 421ms/step - accuracy: 0.7812 - loss: 0.6008 - val_accuracy: 0.5556 - val_loss: 0.7275
Epoch 10/16
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 389ms/step - accuracy: 0.5312 - loss: 0.5579 - val_accuracy: 0.5556 - val_loss: 0.7214
Epoch 11/16
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 379ms/step - accuracy: 0.5312 - loss: 0.5610 - val_accuracy: 0.5556 - val_loss: 0.7141
Epoch 12/16
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 367ms/step - accuracy: 0.8125 - loss: 0.5421 - val_accuracy: 0.5556 - val_loss: 0.7136
Epoch 13/16
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 314ms/step - accuracy: 1.0000 - loss: 0.5058 - val_accuracy: 0.0000e+00 - val_loss: 0.7242
Epoch 14/16
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 391ms/step - accuracy: 1.0000 - loss: 0.4982 - val_accuracy: 0.4444 - val_loss: 0.7418
Epoch 15/16
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 373ms/step - accuracy: 1.0000 - loss: 0.4669 - val_accuracy: 0.4444 - val_loss: 0.7595
Epoch 16/16
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 393ms/step - accuracy: 0.9688 - loss: 0.4163 - val_accuracy: 0.4444 - val_loss: 0.7732
Out[23]:
<keras.src.callbacks.history.History at 0x290591efad0>
In [24]:
rnn_model.load_weights(output_dir + "/rnn_weights.15.keras")
rnn_model.evaluate(X_valid, y_valid)
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 109ms/step - accuracy: 0.4444 - loss: 0.7595
Out[24]:
[0.7595400810241699, 0.4444444477558136]

Визуализация распределения вероятностей результатов модели на валидационной выборке

In [25]:
import matplotlib.pyplot as plt

plt.hist(rnn_model.predict(X_valid))
_ = plt.axvline(x=0.5, color="red")
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 296ms/step
No description has been provided for this image

Формирование архитектуры глубокой сверточной сети

In [26]:
from keras.api.models import Sequential
from keras.api.layers import InputLayer, Embedding, SpatialDropout1D, Conv1D, GlobalMaxPooling1D, Dense, Dropout

conv_model = Sequential()
conv_model.add(InputLayer(shape=(max_length,), dtype="float32"))
conv_model.add(Embedding(unique_words, 64))
conv_model.add(SpatialDropout1D(0.2))

# сверточный слой
conv_model.add(Conv1D(256, 3, activation="relu"))

conv_model.add(GlobalMaxPooling1D())

# полносвязанный слой
conv_model.add(Dense(256, activation="relu"))
conv_model.add(Dropout(0.2))

# выходной слой
conv_model.add(Dense(1, activation="sigmoid"))

conv_model.summary()
Model: "sequential_3"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓
┃ Layer (type)                     Output Shape                  Param # ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩
│ embedding_1 (Embedding)         │ (None, 200, 64)        │       640,000 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ spatial_dropout1d_1             │ (None, 200, 64)        │             0 │
│ (SpatialDropout1D)              │                        │               │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ conv1d (Conv1D)                 │ (None, 198, 256)       │        49,408 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ global_max_pooling1d            │ (None, 256)            │             0 │
│ (GlobalMaxPooling1D)            │                        │               │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ dense_5 (Dense)                 │ (None, 256)            │        65,792 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ dropout_2 (Dropout)             │ (None, 256)            │             0 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ dense_6 (Dense)                 │ (None, 1)              │           257 │
└─────────────────────────────────┴────────────────────────┴───────────────┘
 Total params: 755,457 (2.88 MB)
 Trainable params: 755,457 (2.88 MB)
 Non-trainable params: 0 (0.00 B)

Обучение модели

In [27]:
from keras.api.callbacks import ModelCheckpoint

conv_model.compile(
    loss="binary_crossentropy",
    optimizer="adam",
    metrics=["accuracy"],
)

conv_model.fit(
    X_train,
    y_train,
    batch_size=128,
    epochs=10,
    validation_data=(X_valid, y_valid),
    callbacks=[ModelCheckpoint(filepath=output_dir + "/conv_weights.{epoch:02d}.keras")],
)
Epoch 1/10
1/1 ━━━━━━━━━━━━━━━━━━━━ 2s 2s/step - accuracy: 0.4688 - loss: 0.6933 - val_accuracy: 0.6667 - val_loss: 0.6881
Epoch 2/10
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 322ms/step - accuracy: 0.6875 - loss: 0.6844 - val_accuracy: 1.0000 - val_loss: 0.6829
Epoch 3/10
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 288ms/step - accuracy: 0.6562 - loss: 0.6829 - val_accuracy: 1.0000 - val_loss: 0.6775
Epoch 4/10
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 300ms/step - accuracy: 0.8750 - loss: 0.6697 - val_accuracy: 1.0000 - val_loss: 0.6716
Epoch 5/10
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 314ms/step - accuracy: 1.0000 - loss: 0.6571 - val_accuracy: 1.0000 - val_loss: 0.6651
Epoch 6/10
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 341ms/step - accuracy: 1.0000 - loss: 0.6408 - val_accuracy: 1.0000 - val_loss: 0.6568
Epoch 7/10
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 303ms/step - accuracy: 1.0000 - loss: 0.6304 - val_accuracy: 1.0000 - val_loss: 0.6466
Epoch 8/10
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 330ms/step - accuracy: 1.0000 - loss: 0.6211 - val_accuracy: 1.0000 - val_loss: 0.6341
Epoch 9/10
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 311ms/step - accuracy: 1.0000 - loss: 0.6057 - val_accuracy: 1.0000 - val_loss: 0.6198
Epoch 10/10
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 288ms/step - accuracy: 1.0000 - loss: 0.5722 - val_accuracy: 1.0000 - val_loss: 0.6035
Out[27]:
<keras.src.callbacks.history.History at 0x29035960350>

Загрузка лучшей модели и оценка ее качества

In [28]:
conv_model.load_weights(output_dir + "/conv_weights.10.keras")
conv_model.evaluate(X_valid, y_valid)
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 76ms/step - accuracy: 1.0000 - loss: 0.6035
Out[28]:
[0.60353022813797, 1.0]

Визуализация распределения вероятностей результатов модели на валидационной выборке

In [29]:
import matplotlib.pyplot as plt

plt.hist(conv_model.predict(X_valid))
_ = plt.axvline(x=0.5, color="red")
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 164ms/step
No description has been provided for this image

Формирование архитектуры глубокой полносвязанной сети

In [30]:
from keras.api.models import Sequential
from keras.api.layers import Dense, Flatten, Dropout, Embedding, InputLayer

simple_model = Sequential()
simple_model.add(InputLayer(shape=(max_length,), dtype="float32"))
simple_model.add(Embedding(unique_words, 64))
simple_model.add(Flatten())
simple_model.add(Dense(64, activation="relu"))
simple_model.add(Dropout(0.5))
simple_model.add(Dense(1, activation="sigmoid"))

simple_model.summary()
Model: "sequential_4"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓
┃ Layer (type)                     Output Shape                  Param # ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩
│ embedding_2 (Embedding)         │ (None, 200, 64)        │       640,000 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ flatten_2 (Flatten)             │ (None, 12800)          │             0 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ dense_7 (Dense)                 │ (None, 64)             │       819,264 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ dropout_3 (Dropout)             │ (None, 64)             │             0 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ dense_8 (Dense)                 │ (None, 1)              │            65 │
└─────────────────────────────────┴────────────────────────┴───────────────┘
 Total params: 1,459,329 (5.57 MB)
 Trainable params: 1,459,329 (5.57 MB)
 Non-trainable params: 0 (0.00 B)

Обучение модели

In [31]:
from keras.api.callbacks import ModelCheckpoint

simple_model.compile(
    loss="binary_crossentropy",
    optimizer="adam",
    metrics=["accuracy"],
)

simple_model.fit(
    X_train,
    y_train,
    batch_size=128,
    epochs=10,
    validation_data=(X_valid, y_valid),
    callbacks=[ModelCheckpoint(filepath=output_dir + "/simple_weights.{epoch:02d}.keras")],
)
Epoch 1/10
1/1 ━━━━━━━━━━━━━━━━━━━━ 2s 2s/step - accuracy: 0.4062 - loss: 0.6990 - val_accuracy: 0.5556 - val_loss: 0.6789
Epoch 2/10
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 407ms/step - accuracy: 0.9688 - loss: 0.6011 - val_accuracy: 0.6667 - val_loss: 0.6662
Epoch 3/10
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 332ms/step - accuracy: 1.0000 - loss: 0.5178 - val_accuracy: 0.6667 - val_loss: 0.6511
Epoch 4/10
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 434ms/step - accuracy: 1.0000 - loss: 0.4240 - val_accuracy: 0.6667 - val_loss: 0.6344
Epoch 5/10
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 422ms/step - accuracy: 1.0000 - loss: 0.3631 - val_accuracy: 0.6667 - val_loss: 0.6170
Epoch 6/10
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 311ms/step - accuracy: 1.0000 - loss: 0.3023 - val_accuracy: 0.6667 - val_loss: 0.5986
Epoch 7/10
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 280ms/step - accuracy: 1.0000 - loss: 0.2429 - val_accuracy: 0.6667 - val_loss: 0.5801
Epoch 8/10
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 300ms/step - accuracy: 1.0000 - loss: 0.1860 - val_accuracy: 0.6667 - val_loss: 0.5621
Epoch 9/10
1/1 ━━━━━━━━━━━━━━━━━━━━ 1s 520ms/step - accuracy: 1.0000 - loss: 0.1621 - val_accuracy: 0.6667 - val_loss: 0.5452
Epoch 10/10
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 287ms/step - accuracy: 1.0000 - loss: 0.1185 - val_accuracy: 0.6667 - val_loss: 0.5295
Out[31]:
<keras.src.callbacks.history.History at 0x2907b080350>

Загрузка лучшей модели и оценка ее качества

In [32]:
simple_model.load_weights(output_dir + "/simple_weights.10.keras")
simple_model.evaluate(X_valid, y_valid)
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 55ms/step - accuracy: 0.6667 - loss: 0.5295
Out[32]:
[0.5294919610023499, 0.6666666865348816]

Визуализация распределения вероятностей результатов модели на валидационной выборке

In [33]:
import matplotlib.pyplot as plt

plt.hist(simple_model.predict(X_valid))
_ = plt.axvline(x=0.5, color="red")
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 86ms/step
No description has been provided for this image