Custom persistence class for python-telegram-bot

Question (votes: 1, answers: 1)

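I am developing a simple Telegram chat bot using the python-telegram-bot library. My bot currently uses a ConversationHandler to keep track of the state of the conversation.

I want to make the conversation persistent by storing the conversation state in a MongoDB database.

I am using the mongoengine library for Python to communicate with my DB.

From reading the documentation for BasePersistence (https://python-telegram-bot.readthedocs.io/en/stable/telegram.ext.basepersistence.html), I understand that it is necessary to extend this class with a custom one, let's call it MongoPersistence, and to override the following methods: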

  • get_conversations(name)
  • update_conversation(name, key, new_state)

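The documentation does not specify the structure of the dict returned by get_conversations(name), which also makes it hard to understand how to implement update_conversation(name, key, new_state).

Suppose I have the class below (store_user_data, store_chat_data and store_bot_data are all set to False because I don't want to store this data):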

from telegram.ext import BasePersistence


class MongoPersistence(BasePersistence):

    def __init__(self):
        super(MongoPersistence, self).__init__(store_user_data=False,
                                               store_chat_data=False,
                                               store_bot_data=False)

    def get_conversations(self, name):
        pass

    def update_conversation(self, name, key, new_state):
        pass

How can I implement this class so that my conversation state is fetched from and saved to the DB?

python telegram mongoengine python-telegram-bot
1 answer
5 votes

Conversation persistence

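I suppose the simplest way to implement it is to look at PicklePersistence(). The only example of the dictionary that I have seen is conversations = { name : { (user_id, user_id): state } }, where name is the name given to ConversationHandler(), the tuple key (user_id, user_id) holds the user_id of the person your bot is talking to, and state is the state of the conversation. Well, maybe one of them isn't the user_id, maybe it is the chat_id, but I can't say for sure; I need more guinea pigs.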
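
To make the shape of that dictionary concrete, here is a minimal sketch in plain Python. It assumes the key is a (chat_id, user_id) tuple, which is what ConversationHandler builds with its default per_chat and per_user settings; in a private chat both numbers are the same, which would explain the (user_id, user_id) observation. The handler name, the IDs and the state constants are made up for illustration.

```python
# Hypothetical state constants and handler name, for illustration only.
CHOOSING, TYPING_REPLY = range(2)

conversations = {
    "my_conversation": {              # the `name` given to ConversationHandler
        (1111, 1111): CHOOSING,       # private chat: chat_id == user_id
        (-2222, 3333): TYPING_REPLY,  # group chat: chat_id and user_id differ
    }
}


def get_conversations(name):
    """Return a copy of the {key_tuple: state} mapping for one handler."""
    return conversations.get(name, {}).copy()


def update_conversation(name, key, new_state):
    """Record the new state for one conversation key."""
    conversations.setdefault(name, {})[key] = new_state


update_conversation("my_conversation", (1111, 1111), TYPING_REPLY)
states = get_conversations("my_conversation")
print(states[(1111, 1111)])  # 1
```

An unknown handler name simply yields an empty dict, which is what a handler would see on the very first run.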

To handle tuples as keys, python-telegram-bot includes some tools that help you: encode_conversations_to_json and decode_conversations_from_json.
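
The reason helpers are needed is that JSON (and MongoDB) object keys must be strings, so tuple keys cannot be stored directly. The following stand-alone sketch illustrates the idea using only the standard library; the function names encode_tuple_keys and decode_tuple_keys are hypothetical and are not the library's helpers, although they follow the same principle of JSON-encoding each key.

```python
import json


def encode_tuple_keys(conversations):
    """Serialize {name: {tuple_key: state}} to a JSON string.

    Each tuple key is itself JSON-encoded, e.g. (1, 2) -> "[1, 2]".
    """
    encoded = {
        name: {json.dumps(key): state for key, state in states.items()}
        for name, states in conversations.items()
    }
    return json.dumps(encoded)


def decode_tuple_keys(json_string):
    """Inverse of encode_tuple_keys: turn the string keys back into tuples."""
    decoded = json.loads(json_string)
    return {
        name: {tuple(json.loads(key)): state for key, state in states.items()}
        for name, states in decoded.items()
    }


original = {"my_conversation": {(1111, 1111): 0, (-2222, 3333): 1}}
as_json = encode_tuple_keys(original)
restored = decode_tuple_keys(as_json)
print(restored == original)  # True
```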

Here, on_flush is a variable that tells the code whether to save on every call to update_conversation() (when set to False) or only when the program exits (when set to True).
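
The effect of the flag can be seen in a small sketch that replaces the database with a write counter. InMemoryStore and SketchPersistence are hypothetical names used only for this illustration; the update logic mirrors the pattern described here.

```python
class InMemoryStore:
    """Hypothetical stand-in for the database; it only counts writes."""

    def __init__(self):
        self.writes = 0
        self.saved = None

    def save(self, data):
        self.writes += 1
        self.saved = dict(data)


class SketchPersistence:
    def __init__(self, store, on_flush):
        self.store = store
        self.on_flush = on_flush
        self.conversations = {}

    def update_conversation(self, name, key, new_state):
        if self.conversations.setdefault(name, {}).get(key) == new_state:
            return  # nothing changed, nothing to write
        self.conversations[name][key] = new_state
        if not self.on_flush:
            self.store.save(self.conversations)  # write immediately

    def flush(self):
        self.store.save(self.conversations)  # single write at shutdown


eager_store, lazy_store = InMemoryStore(), InMemoryStore()
eager = SketchPersistence(eager_store, on_flush=False)
lazy = SketchPersistence(lazy_store, on_flush=True)
for state in (0, 1, 2):
    eager.update_conversation("conv", (1, 1), state)
    lazy.update_conversation("conv", (1, 1), state)
print(eager_store.writes, lazy_store.writes)  # 3 0
lazy.flush()
print(lazy_store.writes)  # 1
```

The trade-off is fewer writes against the risk of losing everything since the last flush if the process dies before flush() runs.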

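One last detail: for the moment, the following code only saves to and retrieves from the database; it does not replace or delete anything.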

from telegram.ext import BasePersistence
from config import mongo_URI
from copy import deepcopy
from telegram.utils.helpers import decode_conversations_from_json, encode_conversations_to_json
import mongoengine
import json
from bson import json_util

class Conversations(mongoengine.Document):
    obj = mongoengine.DictField()
    meta = { 'collection': 'Conversations', 'ordering': ['-id']}

class MongoPersistence(BasePersistence):

    def __init__(self):
        super(MongoPersistence, self).__init__(store_user_data=False,
                                               store_chat_data=False,
                                               store_bot_data=False)
        dbname = "persistencedb"
        mongoengine.connect(host=mongo_URI, db=dbname)
        self.conversation_collection = "Conversations"
        self.conversations = None
        self.on_flush = False

    def get_conversations(self, name):
        # Load from the database only once; afterwards use the in-memory copy.
        if not self.conversations:
            # 'ordering': ['-id'] makes first() return the newest document.
            document = Conversations.objects().first()
            stored = {} if document is None else document['obj']
            conversations_json = json_util.dumps(stored)
            self.conversations = decode_conversations_from_json(conversations_json)
        return self.conversations.get(name, {}).copy()

    def update_conversation(self, name, key, new_state):
        if self.conversations.setdefault(name, {}).get(key) == new_state:
            return
        self.conversations[name][key] = new_state
        if not self.on_flush:
            conversations_dic = json_util.loads(encode_conversations_to_json(self.conversations))
            document = Conversations(obj=conversations_dic)
            document.save()

    def flush(self):
        conversations_dic = json_util.loads(encode_conversations_to_json(self.conversations))
        document = Conversations(obj=conversations_dic)
        document.save()
        mongoengine.disconnect()
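
Because every save in the code above creates a new document, the collection grows with each state change and the newest document is the one read back. A minimal sketch of the missing "replace" behaviour is shown below with a hypothetical in-memory FakeCollection standing in for a MongoDB collection; the slot name "conversations" is also an assumption. With mongoengine, one possible equivalent would be a queryset update_one(upsert=True, ...) call, which is not shown or tested here.

```python
class FakeCollection:
    """Hypothetical in-memory stand-in for a MongoDB collection."""

    def __init__(self):
        self.documents = []

    def insert(self, obj):
        # Append-only: every save adds another document.
        self.documents.append({"obj": obj})

    def upsert(self, slot, obj):
        # Replace the document stored under this slot, or insert it if absent.
        for document in self.documents:
            if document.get("slot") == slot:
                document["obj"] = obj
                return
        self.documents.append({"slot": slot, "obj": obj})


append_only, replacing = FakeCollection(), FakeCollection()
for state in (0, 1, 2):
    snapshot = {"conv": {"[1, 1]": state}}
    append_only.insert(snapshot)
    replacing.upsert("conversations", snapshot)

print(len(append_only.documents), len(replacing.documents))  # 3 1
```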

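Beware! Sometimes the conversation requires user_data to have been set up beforehand, and this code does not provide it.

All persistence

Here is a more complete version of the code (replacing documents in the database is still missing).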

from telegram.ext import BasePersistence
from collections import defaultdict
from config import mongo_URI
from copy import deepcopy
from telegram.utils.helpers import decode_user_chat_data_from_json, decode_conversations_from_json, encode_conversations_to_json
import mongoengine
import json
from bson import json_util

class Conversations(mongoengine.Document):
    obj = mongoengine.DictField()
    meta = { 'collection': 'Conversations', 'ordering': ['-id']}

class UserData(mongoengine.Document):
    obj = mongoengine.DictField()
    meta = { 'collection': 'UserData', 'ordering': ['-id']}

class ChatData(mongoengine.Document):
    obj = mongoengine.DictField()
    meta = { 'collection': 'ChatData', 'ordering': ['-id']}

class BotData(mongoengine.Document):
    obj = mongoengine.DictField()
    meta = { 'collection': 'BotData', 'ordering': ['-id']}

class DBHelper():
    """Class to add and get documents from a mongo database using mongoengine
    """
    def __init__(self, dbname="persistencedb"):
        mongoengine.connect(host=mongo_URI, db=dbname)
    def add_item(self, data, collection):
        if collection == "Conversations":
            document = Conversations(obj=data)
        elif collection == "UserData":
            document = UserData(obj=data)
        elif collection == "ChatData":
            document = ChatData(obj=data)
        else:
            document = BotData(obj=data)
        document.save()
    def get_item(self, collection):
        if collection == "Conversations":
            document = Conversations.objects()
        elif collection == "UserData":
            document = UserData.objects()
        elif collection == "ChatData":
            document = ChatData.objects()
        else:
            document = BotData.objects()
        if document.first() is None:
            document = {}
        else:
            document = document.first()['obj']

        return document
    def close(self):
        mongoengine.disconnect()

class DBPersistence(BasePersistence):
    """Uses DBHelper to make the bot persistent on a database.
       It's heavily inspired by PicklePersistence from python-telegram-bot
    """
    def __init__(self):
        super(DBPersistence, self).__init__(store_user_data=True,
                                               store_chat_data=True,
                                               store_bot_data=True)
        self.persistdb = "persistencedb"
        self.conversation_collection = "Conversations"
        self.user_data_collection = "UserData"
        self.chat_data_collection = "ChatData"
        self.bot_data_collection = "BotData"
        self.db = DBHelper()
        self.user_data = None
        self.chat_data = None
        self.bot_data = None
        self.conversations = None
        self.on_flush = False

    def get_conversations(self, name):
        if self.conversations:
            pass
        else:
            conversations_json = json_util.dumps(self.db.get_item(self.conversation_collection))
            self.conversations = decode_conversations_from_json(conversations_json)
        return self.conversations.get(name, {}).copy()

    def update_conversation(self, name, key, new_state):
        if self.conversations.setdefault(name, {}).get(key) == new_state:
            return
        self.conversations[name][key] = new_state
        if not self.on_flush:
            conversations_json = json_util.loads(encode_conversations_to_json(self.conversations))
            self.db.add_item(conversations_json, self.conversation_collection)

    def get_user_data(self):
        if self.user_data:
            pass
        else:
            user_data_json = json_util.dumps(self.db.get_item(self.user_data_collection))
            if user_data_json != '{}':
                self.user_data = decode_user_chat_data_from_json(user_data_json)
            else:
                self.user_data = defaultdict(dict,{})
        return deepcopy(self.user_data)

    def update_user_data(self, user_id, data):
        if self.user_data is None:
            self.user_data = defaultdict(dict)
        # comment next line if you want to save to db every time this function is called
        if self.user_data.get(user_id) == data:
            return
        self.user_data[user_id] = data
        if not self.on_flush:
            user_data_json = json_util.loads(json.dumps(self.user_data))
            self.db.add_item(user_data_json, self.user_data_collection)

    def get_chat_data(self):
        if self.chat_data:
            pass
        else:
            chat_data_json = json_util.dumps(self.db.get_item(self.chat_data_collection))
            if chat_data_json != "{}":
                self.chat_data = decode_user_chat_data_from_json(chat_data_json)
            else:
                self.chat_data = defaultdict(dict,{})
        return deepcopy(self.chat_data)

    def update_chat_data(self, chat_id, data):
        if self.chat_data is None:
            self.chat_data = defaultdict(dict)
        # comment next line if you want to save to db every time this function is called
        if self.chat_data.get(chat_id) == data:
            return
        self.chat_data[chat_id] = data
        if not self.on_flush:
            chat_data_json = json_util.loads(json.dumps(self.chat_data))
            self.db.add_item(chat_data_json, self.chat_data_collection)

    def get_bot_data(self):
        if self.bot_data:
            pass
        else:
            bot_data_json = json_util.dumps(self.db.get_item(self.bot_data_collection))
            self.bot_data = json.loads(bot_data_json)
        return deepcopy(self.bot_data)

    def update_bot_data(self, data):
        if self.bot_data == data:
            return
        self.bot_data = data.copy()
        if not self.on_flush:
            bot_data_json = json_util.loads(json.dumps(self.bot_data))
            self.db.add_item(bot_data_json, self.bot_data_collection)

    def flush(self):
        if self.conversations:
            conversations_json = json_util.loads(encode_conversations_to_json(self.conversations))
            self.db.add_item(conversations_json, self.conversation_collection)
        if self.user_data:
            user_data_json = json_util.loads(json.dumps(self.user_data))
            self.db.add_item(user_data_json, self.user_data_collection)
        if self.chat_data:
            chat_data_json = json_util.loads(json.dumps(self.chat_data))
            self.db.add_item(chat_data_json, self.chat_data_collection)
        if self.bot_data:
            bot_data_json = json_util.loads(json.dumps(self.bot_data))
            self.db.add_item(bot_data_json, self.bot_data_collection)
        self.db.close()

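Two details:

  1. Chat data persistence was not being saved to the database in my tests. It needs more testing; there may be a bug in that part of the code.
  2. For the moment, the only part of the code where on_flush = False works is the conversations. In all the other updates, it seems the call is made after the assignment, so if variable[key] == data is always True and the function returns before saving to the database. That is why there is a comment saying # comment next line if you want to save to db every time this function is called, though doing so results in a lot of saves. If you set on_flush = True and the code stops early (for example, the process is killed), nothing will be saved in the database.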
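
One possible explanation for the second detail, offered as an assumption rather than a verified diagnosis: update_user_data() stores a reference to the very dict the caller keeps mutating, so on the next call the stored value and the incoming value are the same object and always compare equal. The sketch below reproduces that effect in plain Python and shows that storing a deepcopy makes the change detectable again. Both class names and the simulate() helper are hypothetical.

```python
from copy import deepcopy


class AliasingPersistence:
    """Stores a reference to the caller's dict."""

    def __init__(self):
        self.user_data = {}
        self.saves = 0

    def update_user_data(self, user_id, data):
        if self.user_data.get(user_id) == data:
            return
        self.user_data[user_id] = data  # same object as the caller's dict
        self.saves += 1


class CopyingPersistence(AliasingPersistence):
    """Stores an independent copy, so later in-place edits are detected."""

    def update_user_data(self, user_id, data):
        if self.user_data.get(user_id) == data:
            return
        self.user_data[user_id] = deepcopy(data)
        self.saves += 1


def simulate(persistence):
    caller_data = {"name": "Ann"}  # dict owned by the calling code
    persistence.update_user_data(42, caller_data)
    caller_data["age"] = 30        # a handler mutates the dict in place
    persistence.update_user_data(42, caller_data)
    return persistence.saves


print(simulate(AliasingPersistence()), simulate(CopyingPersistence()))  # 1 2
```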