Skip to content

Commit

Permalink
Code beautifying
Browse files Browse the repository at this point in the history
  • Loading branch information
Egoluback authored Jun 2, 2021
1 parent 91703a1 commit 1528f01
Show file tree
Hide file tree
Showing 2 changed files with 99 additions and 162 deletions.
254 changes: 98 additions & 156 deletions Train.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@
"jqW8MkU4ZyP6",
"qDEwhGaiZ3pK",
"b4z1-4YRr2_y",
"VOROjV3er6Ej",
"XdFMj4wbr8kk",
"JIAprBNPkJmK",
"2q91KWVRuUxv",
"T_ABCKzTuZbL",
"ercQmJFH9WbV",
Expand Down Expand Up @@ -102,42 +102,6 @@
}
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "VB68mww4k0Vd",
"outputId": "0aa9ecd8-8ff8-4831-c066-3ff34c1c590e"
},
"source": [
"!pip install mlxtend"
],
"execution_count": 3,
"outputs": [
{
"output_type": "stream",
"text": [
"Requirement already satisfied: mlxtend in /usr/local/lib/python3.7/dist-packages (0.14.0)\n",
"Requirement already satisfied: pandas>=0.17.1 in /usr/local/lib/python3.7/dist-packages (from mlxtend) (1.1.5)\n",
"Requirement already satisfied: scikit-learn>=0.18 in /usr/local/lib/python3.7/dist-packages (from mlxtend) (0.22.2.post1)\n",
"Requirement already satisfied: setuptools in /usr/local/lib/python3.7/dist-packages (from mlxtend) (56.1.0)\n",
"Requirement already satisfied: scipy>=0.17 in /usr/local/lib/python3.7/dist-packages (from mlxtend) (1.4.1)\n",
"Requirement already satisfied: matplotlib>=1.5.1 in /usr/local/lib/python3.7/dist-packages (from mlxtend) (3.2.2)\n",
"Requirement already satisfied: numpy>=1.10.4 in /usr/local/lib/python3.7/dist-packages (from mlxtend) (1.19.5)\n",
"Requirement already satisfied: pytz>=2017.2 in /usr/local/lib/python3.7/dist-packages (from pandas>=0.17.1->mlxtend) (2018.9)\n",
"Requirement already satisfied: python-dateutil>=2.7.3 in /usr/local/lib/python3.7/dist-packages (from pandas>=0.17.1->mlxtend) (2.8.1)\n",
"Requirement already satisfied: joblib>=0.11 in /usr/local/lib/python3.7/dist-packages (from scikit-learn>=0.18->mlxtend) (1.0.1)\n",
"Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.7/dist-packages (from matplotlib>=1.5.1->mlxtend) (0.10.0)\n",
"Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.7/dist-packages (from matplotlib>=1.5.1->mlxtend) (1.3.1)\n",
"Requirement already satisfied: pyparsing!=2.0.4,!=2.1.2,!=2.1.6,>=2.0.1 in /usr/local/lib/python3.7/dist-packages (from matplotlib>=1.5.1->mlxtend) (2.4.7)\n",
"Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.7/dist-packages (from python-dateutil>=2.7.3->pandas>=0.17.1->mlxtend) (1.15.0)\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
Expand All @@ -156,20 +120,13 @@
"from nltk.tokenize import word_tokenize\n",
"from nltk.stem import SnowballStemmer\n",
"nltk.download('punkt')\n",
"from sklearn.pipeline import Pipeline\n",
"from sklearn.metrics import precision_score, recall_score, precision_recall_curve\n",
"from matplotlib import pyplot as plt\n",
"from sklearn.metrics import plot_precision_recall_curve\n",
"from sklearn.metrics import roc_auc_score\n",
"import numpy as np\n",
"from sklearn.model_selection import GridSearchCV\n",
"\n",
"from catboost import CatBoostClassifier\n",
"from mlxtend.classifier import EnsembleVoteClassifier\n",
"from sklearn.ensemble import RandomForestClassifier\n",
"\n",
"from sklearn.linear_model import LogisticRegression\n",
"from sklearn.svm import SVC\n",
"from sklearn.feature_extraction.text import TfidfVectorizer"
],
"execution_count": 4,
Expand Down Expand Up @@ -696,19 +653,6 @@
"# Sentences vectorizing"
]
},
{
"cell_type": "code",
"metadata": {
"id": "kjLpRL7w5X0Q"
},
"source": [
"# TODO:\n",
" # vectorizing sentences method: for each word in object count word2vec * tf-idf\n",
"# https://coderoad.ru/29760935/%D0%9A%D0%B0%D0%BA-%D0%BF%D0%BE%D0%BB%D1%83%D1%87%D0%B8%D1%82%D1%8C-%D0%B2%D0%B5%D0%BA%D1%82%D0%BE%D1%80-%D0%B4%D0%BB%D1%8F-%D0%BF%D1%80%D0%B5%D0%B4%D0%BB%D0%BE%D0%B6%D0%B5%D0%BD%D0%B8%D1%8F-%D0%B8%D0%B7-word2vec-%D1%82%D0%BE%D0%BA%D0%B5%D0%BD%D0%BE%D0%B2-%D0%B2-%D0%BF%D1%80%D0%B5%D0%B4%D0%BB%D0%BE%D0%B6%D0%B5%D0%BD%D0%B8%D0%B8"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
Expand Down Expand Up @@ -876,8 +820,6 @@
" \n",
" vector = self.tokenize_word2vec(dataset.iloc[object_index])\n",
"\n",
" # print(dataset.iloc[object_index])\n",
"\n",
" vectorized.append(vector)\n",
" \n",
" return pd.DataFrame(list(vectorized))"
Expand Down Expand Up @@ -987,21 +929,6 @@
"# Train"
]
},
{
"cell_type": "code",
"metadata": {
"id": "TzfbY8-ukGrF"
},
"source": [
"# models_pipeline = {\n",
"# \"INS\": Pipeline([(\"vectorizer\", TfidfVectorizer(tokenizer=tokenize_sentence)), (\"model\", LogisticRegression(random_state=42, C=10))]),\n",
"# \"THR\": Pipeline([(\"vectorizer\", TfidfVectorizer(tokenizer=tokenize_sentence)), (\"model\", LogisticRegression(random_state=42, C=10))]), \n",
"# \"OBS\": Pipeline([(\"vectorizer\", TfidfVectorizer(tokenizer=tokenize_sentence)), (\"model\", LogisticRegression(random_state=42, C=10))])\n",
"# }"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
Expand Down Expand Up @@ -1032,7 +959,7 @@
"outputId": "891e1668-e2f6-4345-e63b-53a8468e8311"
},
"source": [
"model_lr.fit(X_train_ins_vec, y_train_ins)"
"model_lr.fit(X_train_vec, y_train_ins)"
],
"execution_count": 99,
"outputs": [
Expand Down Expand Up @@ -4280,97 +4207,28 @@
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "alAQvPn_s1hD",
"outputId": "488d151a-bb64-49d6-919f-d8419ce7b0b5"
},
"source": [
"precision_score(y_true=y_test_ins, y_pred=model_ins_cbc.predict(X_test_vec)), precision_score(y_true=y_test_thr, y_pred=model_thr_cbc.predict(X_test_vec)), precision_score(y_true=y_test_obs, y_pred=model_obs_cbc.predict(X_test_vec))"
],
"execution_count": 179,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"(0.8709055876685935, 0.8530405405405406, 0.8199566160520607)"
]
},
"metadata": {
"tags": []
},
"execution_count": 179
}
]
},
{
"cell_type": "code",
"cell_type": "markdown",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "UDbhP_eTs5MO",
"outputId": "8b343a8d-e8cb-46d6-aa91-da64ebca115a"
"id": "XdFMj4wbr8kk"
},
"source": [
"recall_score(y_true=y_test_ins, y_pred=model_ins_cbc.predict(X_test_vec)), recall_score(y_true=y_test_thr, y_pred=model_thr_cbc.predict(X_test_vec)), recall_score(y_true=y_test_obs, y_pred=model_obs_cbc.predict(X_test_vec))"
],
"execution_count": 180,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"(0.7385017563924516, 0.625619425173439, 0.5470332850940666)"
]
},
"metadata": {
"tags": []
},
"execution_count": 180
}
"## Models with old vectorization"
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "csKnrp-zsFDb",
"outputId": "3a7d35a1-4b42-42e0-b8e0-c6723d6443f7"
"id": "FK5bkQ7yiAy5"
},
"source": [
"roc_auc_score(y_test_ins, model_ins_cbc.predict_proba(X_test_vec).T[1]), roc_auc_score(y_test_thr, model_thr_cbc.predict_proba(X_test_vec).T[1]), roc_auc_score(y_test_obs, model_obs_cbc.predict_proba(X_test_vec).T[1])"
"models_pipeline = {\n",
" \"INS\": Pipeline([(\"vectorizer\", TfidfVectorizer(tokenizer=tokenize_sentence)), (\"model\", LogisticRegression(random_state=42, C=10))]),\n",
" \"THR\": Pipeline([(\"vectorizer\", TfidfVectorizer(tokenizer=tokenize_sentence)), (\"model\", LogisticRegression(random_state=42, C=10))]), \n",
" \"OBS\": Pipeline([(\"vectorizer\", TfidfVectorizer(tokenizer=tokenize_sentence)), (\"model\", LogisticRegression(random_state=42, C=10))])\n",
" }"
],
"execution_count": 182,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"(0.9670210532958152, 0.9796235421641681, 0.976706747165172)"
]
},
"metadata": {
"tags": []
},
"execution_count": 182
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "XdFMj4wbr8kk"
},
"source": [
"## Models with old vectorization"
]
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
Expand Down Expand Up @@ -4912,7 +4770,91 @@
"id": "T_ABCKzTuZbL"
},
"source": [
"# New models"
"## New models"
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "gkKCbYCRiM2A",
"outputId": "391eab91-2ce8-4d62-9c78-55ca34593d8f"
},
"source": [
"precision_score(y_true=y_test_ins, y_pred=model_ins_cbc.predict(X_test_vec)), precision_score(y_true=y_test_thr, y_pred=model_thr_cbc.predict(X_test_vec)), precision_score(y_true=y_test_obs, y_pred=model_obs_cbc.predict(X_test_vec))"
],
"execution_count": 197,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"(0.8709055876685935, 0.8530405405405406, 0.8199566160520607)"
]
},
"metadata": {
"tags": []
},
"execution_count": 197
}
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "5CMe95kdiOOz",
"outputId": "9bd2a12a-7195-4f99-e5af-0e2ba3b2cd40"
},
"source": [
"recall_score(y_true=y_test_ins, y_pred=model_ins_cbc.predict(X_test_vec)), recall_score(y_true=y_test_thr, y_pred=model_thr_cbc.predict(X_test_vec)), recall_score(y_true=y_test_obs, y_pred=model_obs_cbc.predict(X_test_vec))"
],
"execution_count": 198,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"(0.7385017563924516, 0.625619425173439, 0.5470332850940666)"
]
},
"metadata": {
"tags": []
},
"execution_count": 198
}
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "5h0TAQNniPYr",
"outputId": "f4ed848f-2796-41ef-a6da-48fd4d8440a9"
},
"source": [
"roc_auc_score(y_test_ins, model_ins_cbc.predict_proba(X_test_vec).T[1]), roc_auc_score(y_test_thr, model_thr_cbc.predict_proba(X_test_vec).T[1]), roc_auc_score(y_test_obs, model_obs_cbc.predict_proba(X_test_vec).T[1])"
],
"execution_count": 199,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"(0.9670210532958152, 0.9796235421641681, 0.976706747165172)"
]
},
"metadata": {
"tags": []
},
"execution_count": 199
}
]
},
{
Expand Down
7 changes: 1 addition & 6 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,13 @@

from Vectorizer import Vectorizer

import telebot, joblib, string

import wget, zipfile, gensim
import telebot, joblib, string, wget, zipfile, gensim

import numpy as np

from functools import lru_cache
from pymystem3 import Mystem


# snowball = SnowballStemmer(language="russian")
# russian_stop_words = stopwords.words("russian")

Expand Down Expand Up @@ -56,10 +53,8 @@ def reply(message):
print('------')

if (result_ins < 0.5 and result_thr < 0.5 and result_obs < 0.5): return
# if (result_ins < 0.5): return

bot.reply_to(message, f"Это оскорбление с вероятностью {result_ins}\nЭто угроза с вероятностью {result_thr}\nЭто домогательство с вероятностью {result_obs}\n")
# bot.reply_to(message, f"Это оскорбление с вероятностью {result_ins}")


bot.polling()

0 comments on commit 1528f01

Please sign in to comment.