2023-11-23 16:12:33 +04:00

104 lines
3.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import math
from functools import reduce
from LabWork01.LabWork6.ConvertorDataFrame import CovertorDataFrame
# Dataset: call the loader once so that the training and test frames are
# guaranteed to come from the same load (and the data is not read twice).
_frames = CovertorDataFrame()
dfMain = _frames[0]
dfTest = _frames[1]


def cstr(s):
    """Return "value:count" strings for a pandas Series, sorted by value."""
    return [f"{k}:{v}" for k, v in sorted(s.value_counts().items())]


def entropy(s):
    """Return the Shannon entropy (base 2) of the values of a pandas Series.

    An empty Series has entropy 0.0 instead of raising.
    """
    total = len(s)
    if total == 0:
        return 0.0
    return -sum((n / total) * math.log2(n / total) for n in s.value_counts())


# Decision-tree data structure (the root node).
tree = {
    # name: label of this node
    "name": f"decision tree {dfMain.columns[-1]} {cstr(dfMain.iloc[:, -1])}",
    # df: the data associated with this node
    "df": dfMain,
    # edges: child nodes (branches) of this node; stays empty for a leaf
    "edges": [],
}

# Nodes that may still be split. Named ``open_nodes`` rather than ``open`` so
# the builtin ``open`` is not shadowed for the rest of the module.
open_nodes = [tree]

# Keep splitting until there is no node left to examine.
while open_nodes:
    n = open_nodes.pop(0)
    df_n = n["df"]
    # A node whose label column has zero entropy is pure: it stays a leaf.
    if entropy(df_n.iloc[:, -1]) == 0:
        continue
    # For every remaining attribute, collect the weighted entropy of the
    # split together with the resulting sub-frames and their values.
    attrs = {}
    for attr in df_n.columns[:-1]:
        attrs[attr] = {"entropy": 0, "dfs": [], "values": []}
        for value in sorted(set(df_n[attr])):
            # Boolean mask instead of a string-built query(): works for
            # non-string values, values with quotes and arbitrary column names.
            df_m = df_n[df_n[attr] == value]
            attrs[attr]["entropy"] += entropy(df_m.iloc[:, -1]) * df_m.shape[0] / df_n.shape[0]
            attrs[attr]["dfs"].append(df_m)
            attrs[attr]["values"].append(value)
    # No attribute left to split on: the node stays a (mixed) leaf.
    if not attrs:
        continue
    # Split on the attribute with the lowest weighted entropy.
    attr = min(attrs, key=lambda x: attrs[x]["entropy"])
    for d, v in zip(attrs[attr]["dfs"], attrs[attr]["values"]):
        m = {"name": f"{attr}={v}", "edges": [], "df": d.drop(columns=attr)}
        n["edges"].append(m)
        open_nodes.append(m)

# Print the data set.
print(dfMain, "\n-------------")
# Evaluation on the test data
def _node_matches(node, target) -> bool:
    """Tell whether *target* satisfies the branch condition stored on *node*.

    Nodes that carry explicit ``"attr"``/``"value"`` keys are compared as is.
    Otherwise the condition is taken from the node's ``"name"``, which is
    expected to look like ``"<attr>=<value>"`` (split on the first ``=``);
    such values are text, so the target value is compared via ``str()``.
    An attribute that is absent from *target* never matches.
    """
    if "attr" in node and "value" in node:
        attr, value = node["attr"], node["value"]
        return attr in target and bool(value == target[attr])
    attr, sep, value = node.get("name", "").partition("=")
    return bool(sep) and attr in target and value == str(target[attr])


def predict_bp(nodes, target) -> int:
    """Walk the tree from *nodes* down to a leaf and return its prediction.

    Args:
        nodes: non-empty list of sibling node dicts (``"edges"``, ``"df"`` and
            either ``"attr"``/``"value"`` or an ``"attr=value"`` ``"name"``).
        target: mapping of attribute name to the value of the sample.

    Returns:
        The mean of the leaf's label column truncated to ``int``. The column
        ``"StoreSales"`` is used when present; otherwise the last column of
        the leaf's frame is used.

    Raises:
        IndexError: if *nodes* is empty.
    """
    if not nodes:
        raise IndexError("predict_bp() needs at least one node to choose from")
    while True:
        # First matching sibling wins; with no match fall back to the last one.
        overlap = next((node for node in nodes if _node_matches(node, target)), nodes[-1])
        if not overlap["edges"]:
            frame = overlap["df"]
            column = "StoreSales" if "StoreSales" in frame.columns else frame.columns[-1]
            return int(frame[column].mean())
        nodes = overlap["edges"]
def predict_str(count: int):
    """Return an HTML snippet comparing predictions with facts for test rows.

    For each of the first *count* rows of ``dfTest`` the snippet contains the
    ``Age``/``BMI`` features, the value predicted by ``predict_bp`` and the
    row's ``BloodPressure``; entries are separated by ``<br/>``.

    Args:
        count: number of test rows to report. It is clamped to the number of
            rows available, so a too-large value no longer raises IndexError.

    Returns:
        The joined string (empty when no rows are reported).
    """
    predictions = []
    for i in range(min(count, len(dfTest))):
        row = dfTest.iloc[i]
        # Build the feature mapping once; it is both displayed and predicted on.
        features = {'Age': row['Age'], 'BMI': row['BMI']}
        predicted = predict_bp(tree['edges'], features)
        predictions.append(
            f"{features}<br/>predict {predicted} / fact {row['BloodPressure']}"
        )
    return '<br/>'.join(predictions)
def tstr(tree, indent=""):
    """Render *tree* and all of its descendants as indented text.

    Each node occupies one line: the indent, the node name and, for leaves
    only, the ``cstr`` summary of the last column of the node's frame.
    Children are rendered one level deeper (a tab is prepended and a space is
    appended to the indent).
    """
    children = tree["edges"]
    # Only leaves show the class distribution of their data.
    summary = "" if children else str(cstr(tree["df"].iloc[:, -1]))
    child_indent = "\t" + indent + " "
    header = indent + tree["name"] + summary + "\n"
    return header + "".join(tstr(child, child_indent) for child in children)
def getStringTree():
    """Return the text rendering of the module-level decision tree."""
    rendered = tstr(tree)
    return rendered
# Print the tree in its textual representation.
print(getStringTree())