RError.com

Сергей Андреев
Asked: 2020-12-16 18:14:10 +0000 UTC

Text preprocessing for training a classification model in Keras


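I have written a script for a neural network, or more precisely the part that prepares the input data. I am not sure I have done everything correctly so that the model can learn properly, and I would really appreciate the opinion of someone who knows the subject.

Code: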

import sklearn
import numpy as np
from collections import Counter
from keras.models import model_from_json
from keras.preprocessing import sequence
from sklearn.model_selection import train_test_split as tts


labels_lexicon = ['_label_0', '_label_1', '_label_2'] # list of categories

def get_data_from_the_file():
  labels, descriptions, lexicon, lexicon_base = [], [], [], []
  for i,  line in enumerate(open('testtext.txt', 'r', encoding='utf8', errors='ignore')):
    content = line.split()
    labels.append([content[0]])
    descriptions.append(content[1:])
    lexicon_base += content[1:]

  count_lexicon = Counter(lexicon_base).most_common(5000)
  for count_item in count_lexicon:
   lexicon.append(count_item[0])

  return labels, descriptions, lexicon

labels, descriptions, lexicon = get_data_from_the_file()

def get_descriptions_to_index(lexicon):
    cache = {}
    word2index = {}
    for i,word in enumerate(lexicon):
        if cache.get(word) == None:
            cache[word] = i
            word2index[word] = i
    return word2index
word2index = get_descriptions_to_index(lexicon)


def get_labels_to_index(labels_lexicon):
    cache = {}
    labels2index = {}
    for i,word in enumerate(labels_lexicon):
        if cache.get(word) == None:
            cache[word] = i
            labels2index[word] = i
    return labels2index
labels2index = get_labels_to_index(labels_lexicon)

list_of_tokenize_descriptions = []
list_of_tokenize_labels = []


for descriptions_arrays in descriptions:
    prepare_list_of_tokenize_descriptions = []
    for descriptions_piece in descriptions_arrays:
        if word2index.get(descriptions_piece) != None:
            prepare_list_of_tokenize_descriptions.append(word2index[descriptions_piece])
    list_of_tokenize_descriptions.append(prepare_list_of_tokenize_descriptions)


for labels_arrays in labels:
    prepare_list_of_tokenize_labels = []
    for labels_piece in labels_arrays:
        if labels2index.get(labels_piece) != None:
            prepare_list_of_tokenize_labels.append(labels2index[labels_piece])
    list_of_tokenize_labels.append(prepare_list_of_tokenize_labels)

x_matrix_list = []
y_matrix_list = []

for i in range(len(list_of_tokenize_descriptions)):
  matrix_i = np.zeros((len(lexicon)),dtype=int)
  line =  list_of_tokenize_descriptions[i]
  for index in line:
    matrix_i[index] = 1
  x_matrix_list.append(matrix_i)


for i in range(len(list_of_tokenize_labels)):
  matrix_i = np.zeros((len(labels_lexicon)),dtype=int)
  line =  list_of_tokenize_labels[i]
  for index in line:
    matrix_i[index] = 1
  y_matrix_list.append(matrix_i)


x_train, x_test, y_train, y_test = tts(np.array(x_matrix_list), np.array(y_matrix_list),  test_size=0.3)

Here is a link to the dataset.

python

1 Answer

  1. Best Answer
    MaxU - stop genocide of UA
    2020-12-16T23:16:00Z

    In this case I would use keras.preprocessing.text.Tokenizer:

    from pathlib import Path
    import pandas as pd
    from keras.models import Sequential
    from keras.layers import Embedding, LSTM, Dense, Dropout, Activation
    from keras.preprocessing.text import Tokenizer, text_to_word_sequence
    from keras.preprocessing.sequence import pad_sequences
    from keras import optimizers
    from keras.callbacks import ModelCheckpoint
    from keras.models import load_model
    from sklearn.model_selection import train_test_split
    from sklearn.preprocessing import LabelBinarizer
    
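    def get_data(filename, num_words=5000, frac=1.0):
        # read every line into a single 'text' column (this relies on '~'
        # never occurring in the file), then shuffle the rows
        data = (pd.read_csv(filename, header=None, names=['text'], sep='~')
                  .sample(frac=frac))
        # the first token of each line is the label, the rest is the description
        data[['label','text']] = data.pop('text').str.split(n=1, expand=True)
        # drop lines without a description and lines with a malformed label
        data = data.dropna()
        data = data.loc[data['label'].str.contains(r'^_label')]
    
        # build vocabulary
        tok = Tokenizer(num_words=num_words)
        tok.fit_on_texts(data['text'])
        # convert texts to sequences
        X = tok.texts_to_sequences(data['text'])
        # one-hot encode the labels, one column per class
        lb = LabelBinarizer()
        Y = pd.DataFrame(lb.fit_transform(data['label']), 
                         columns=lb.classes_, index=data.index)
        # pad every sequence to the same length (here maxlen equals num_words)
        return (pad_sequences(X, maxlen=num_words), Y, tok)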
    
    
    path = Path(r'D:\temp\.data')            
    filename = path / 'testtext.txt'
    num_words = 1000
    
    X, Y, tok = get_data(filename, num_words=num_words)
    
    # split data set to train / dev
    X_train, X_dev, Y_train, Y_dev = \
        train_test_split(X, Y, test_size=0.2, random_state=123, stratify=Y)
    print('X_train.shape:\t{}\t\tY_train.shape:\t{}'.format(X_train.shape, Y_train.shape))
    print('X_dev.shape:\t{}\t\tY_dev.shape:\t{}'.format(X_dev.shape, Y_dev.shape))
    

    Output:

    X_train.shape:  (26850, 1000)           Y_train.shape:  (26850, 3)
    X_dev.shape:    (6713, 1000)            Y_dev.shape:    (6713, 3)
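
The split sizes printed above follow from the row count: with 33563 samples and test_size=0.2, the dev share is rounded up and the remainder goes to the training set. A minimal arithmetic check under that rounding assumption (the helper name split_sizes is made up for illustration and is not part of scikit-learn):

```python
import math

def split_sizes(n_samples, test_size):
    """Return (n_train, n_test) for a fractional test_size.

    Assumption: the test share is rounded up and the rest goes to training.
    """
    n_test = math.ceil(n_samples * test_size)
    return n_samples - n_test, n_test

n_train, n_dev = split_sizes(33563, 0.2)
print(n_train, n_dev)
```

The two numbers should match the first dimension of X_train and X_dev.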
    

    What we get:

    In [4]: X
    Out[4]:
    array([[  0,   0,   0, ..., 250, 154,  16],
           [  0,   0,   0, ..., 112, 121,  84],
           [  0,   0,   0, ...,  72,  49,  44],
           ...,
           [  0,   0,   0, ...,   5, 109,  99],
           [  0,   0,   0, ..., 158, 513,  78],
           [  0,   0,   0, ...,   0,   0, 138]])
    
    In [5]: X.shape
    Out[5]: (33563, 1000)
    
    In [6]: Y
    Out[6]:
           _label_0  _label_1  _label_2
    30455         1         0         0
    19423         1         0         0
    29907         0         1         0
    12779         0         1         0
    28342         0         0         1
    27583         0         0         1
    28096         0         1         0
    21411         1         0         0
    23425         1         0         0
    33227         1         0         0
    ...         ...       ...       ...
    28788         1         0         0
    17329         0         1         0
    5339          0         1         0
    9461          0         0         1
    31315         0         0         1
    23199         1         0         0
    6752          0         1         0
    164           1         0         0
    24283         0         0         1
    25055         0         0         1
    
    [33563 rows x 3 columns]
    
    In [7]: Y.shape
    Out[7]: (33563, 3)
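
Each row of Y has exactly one 1, in the column of that sample's class. The dependency-free sketch below imitates that one-hot layout for the three-class case. It is an illustration only, not LabelBinarizer's implementation, and the function name one_hot_labels is hypothetical:

```python
def one_hot_labels(labels):
    """Return (classes, rows): sorted unique classes and one 0/1 row per label."""
    classes = sorted(set(labels))
    column = {name: i for i, name in enumerate(classes)}
    rows = []
    for label in labels:
        row = [0] * len(classes)
        row[column[label]] = 1  # mark the column that belongs to this label
        rows.append(row)
    return classes, rows

classes, rows = one_hot_labels(['_label_1', '_label_0', '_label_2', '_label_0'])
for label_row in rows:
    print(label_row)
```

Note that with only two classes LabelBinarizer returns a single 0/1 column rather than two columns, which this sketch does not imitate.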
    

    In [10]: tok.index_word[250]
    Out[10]: 'реабилитац'
    
    In [11]: tok.index_word[154]
    Out[11]: 'инвалид'
    
    In [12]: tok.index_word[16]
    Out[12]: 'год'
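
The zeros on the left of every row of X and the small integers on the right come from two steps: words are numbered by frequency rank (1 is the most frequent, 0 is reserved for padding), and shorter sequences are padded at the front. The sketch below imitates those two steps in plain Python. It is a simplified illustration, not the Keras code: it only lowercases and splits on whitespace, and the function names fit_word_index and to_padded_sequences are hypothetical.

```python
from collections import Counter

def fit_word_index(texts):
    """Map each word to its frequency rank; the most frequent word gets 1."""
    counts = Counter(word for text in texts for word in text.lower().split())
    ranked = counts.most_common()  # sorted by descending count
    return {word: rank for rank, (word, _) in enumerate(ranked, start=1)}

def to_padded_sequences(texts, word_index, num_words, maxlen):
    """Encode texts as index lists, keep indices below num_words, pad in front."""
    rows = []
    for text in texts:
        seq = [word_index[w] for w in text.lower().split()
               if w in word_index and word_index[w] < num_words]
        seq = seq[-maxlen:]  # if too long, keep the last maxlen tokens
        rows.append([0] * (maxlen - len(seq)) + seq)
    return rows

texts = ['year disabled year', 'rehab disabled year', 'year']
word_index = fit_word_index(texts)
padded = to_padded_sequences(texts, word_index, num_words=3, maxlen=4)
for padded_row in padded:
    print(padded_row)
```

Here num_words and maxlen are kept as separate parameters, whereas the answer passes num_words for both, which is why its X has 1000 columns.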
    