mirror of
https://github.com/open-webui/open-webui.git
synced 2026-04-25 17:15:16 +02:00
577 lines
20 KiB
Python
577 lines
20 KiB
Python
import json
|
|
import time
|
|
import uuid
|
|
from typing import Optional
|
|
|
|
from sqlalchemy.orm import Session
|
|
from open_webui.internal.db import Base, JSONField, get_db, get_db_context
|
|
from open_webui.models.tags import TagModel, Tag, Tags
|
|
from open_webui.models.users import Users, User, UserNameResponse
|
|
from open_webui.models.channels import Channels, ChannelMember
|
|
|
|
|
|
from pydantic import BaseModel, ConfigDict, field_validator
|
|
from sqlalchemy import BigInteger, Boolean, Column, String, Text, JSON
|
|
from sqlalchemy import or_, func, select, and_, text
|
|
from sqlalchemy.sql import exists
|
|
|
|
####################
|
|
# Message DB Schema
|
|
####################
|
|
|
|
|
|
class MessageReaction(Base):
    """ORM row for the ``message_reaction`` table.

    Each row records one user's reaction ``name`` on one message.
    ``MessageTable.add_reaction_to_message`` avoids inserting duplicates by
    checking for an existing (message_id, user_id, name) row first; the
    schema itself declares no such uniqueness constraint.
    """

    __tablename__ = 'message_reaction'

    id = Column(Text, primary_key=True, unique=True)  # uuid4 string
    user_id = Column(Text)  # reacting user's id
    message_id = Column(Text)  # id of the reacted-to message
    name = Column(Text)  # reaction identifier (presumably an emoji/shortcode — TODO confirm)
    created_at = Column(BigInteger)  # epoch nanoseconds (time.time_ns())
|
|
|
|
|
|
class MessageReactionModel(BaseModel):
    """Pydantic view of a ``MessageReaction`` row (built via ``from_attributes``)."""

    model_config = ConfigDict(from_attributes=True)

    id: str
    user_id: str
    message_id: str
    name: str
    created_at: int  # timestamp in epoch nanoseconds (time.time_ns())
|
|
|
|
|
|
class Message(Base):
    """ORM row for the ``message`` table (channel messages and thread replies)."""

    __tablename__ = 'message'

    id = Column(Text, primary_key=True, unique=True)  # uuid4 string

    user_id = Column(Text)  # author's user id
    channel_id = Column(Text, nullable=True)

    # Id of the message this one quotes/replies to (shown inline).
    reply_to_id = Column(Text, nullable=True)
    # Id of the thread root; NULL for top-level channel messages.
    parent_id = Column(Text, nullable=True)

    # Pins
    is_pinned = Column(Boolean, nullable=False, default=False)
    pinned_at = Column(BigInteger, nullable=True)  # time_ns when pinned, else NULL
    pinned_by = Column(Text, nullable=True)  # user id that pinned, else NULL

    content = Column(Text)
    data = Column(JSON, nullable=True)  # free-form payload; schema not defined here
    meta = Column(JSON, nullable=True)  # may hold {'webhook': {'id': ...}} for webhook posts

    created_at = Column(BigInteger)  # time_ns
    updated_at = Column(BigInteger)  # time_ns
|
|
|
|
|
|
class MessageModel(BaseModel):
    """Pydantic view of a ``Message`` row (built via ``from_attributes``)."""

    model_config = ConfigDict(from_attributes=True)

    id: str
    user_id: str
    channel_id: Optional[str] = None

    reply_to_id: Optional[str] = None  # message being replied to, if any
    parent_id: Optional[str] = None  # thread root, None for top-level messages

    # Pins
    is_pinned: bool = False
    pinned_by: Optional[str] = None
    pinned_at: Optional[int] = None  # timestamp in epoch (time_ns)

    content: str
    data: Optional[dict] = None
    meta: Optional[dict] = None

    created_at: int  # timestamp in epoch (time_ns)
    updated_at: int  # timestamp in epoch (time_ns)
|
|
|
|
|
|
####################
|
|
# Forms
|
|
####################
|
|
|
|
|
|
class MessageForm(BaseModel):
    """Payload accepted by ``MessageTable.insert_new_message`` and ``update_message_by_id``."""

    # Not persisted by MessageTable; presumably a client-side id for
    # optimistic UI that callers echo back — TODO confirm.
    temp_id: Optional[str] = None
    content: str
    reply_to_id: Optional[str] = None
    parent_id: Optional[str] = None
    data: Optional[dict] = None  # shallow-merged into existing data on update
    meta: Optional[dict] = None  # shallow-merged into existing meta on update
|
|
|
|
|
|
class Reactions(BaseModel):
    """All reactions with one ``name`` on a message, aggregated."""

    name: str
    # Reacting users as {'id': ..., 'name': ...} dicts (see get_reactions_by_message_id).
    users: list[dict]
    count: int  # number of entries in ``users``
|
|
|
|
|
|
class MessageUserResponse(MessageModel):
    """A message plus display info for its author (a user or a webhook), if resolved."""

    user: Optional[UserNameResponse] = None
|
|
|
|
|
|
class MessageUserSlimResponse(MessageUserResponse):
    """Lightweight message view whose ``data`` is reduced to a presence flag.

    Instead of shipping the full ``data`` payload, ``data`` becomes ``True``
    when the incoming dict has at least one truthy value and ``False``
    otherwise. When ``data`` is omitted entirely it stays ``None``.
    """

    data: bool | None = None

    @field_validator('data', mode='before')
    def convert_data_to_bool(cls, v):
        # Non-dict inputs (an explicit None included) collapse to False;
        # a dict counts as "has data" once any of its values is truthy.
        return isinstance(v, dict) and any(v.values())
|
|
|
|
|
|
class MessageReplyToResponse(MessageUserResponse):
    """A message with its author and, if it replies to another message, that message in slim form."""

    reply_to_message: Optional[MessageUserSlimResponse] = None
|
|
|
|
|
|
class MessageWithReactionsResponse(MessageUserSlimResponse):
    """Slim message view (``data`` as a flag) together with its aggregated reactions."""

    reactions: list[Reactions]
|
|
|
|
|
|
class MessageResponse(MessageReplyToResponse):
    """Full message view built by ``MessageTable.get_message_by_id``: thread stats plus reactions."""

    # No default, so under Pydantic v2 this is required but may be None
    # (None when the message has no thread replies).
    latest_reply_at: Optional[int]
    reply_count: int  # number of messages whose parent_id is this message
    reactions: list[Reactions]
|
|
|
|
|
|
class MessageTable:
    """Data-access layer for channel messages and their reactions.

    Every public method takes an optional ``db`` session and runs inside
    ``get_db_context(db)`` (presumably reusing ``db`` when given and opening a
    session otherwise — defined in ``open_webui.internal.db``).
    """

    ####################
    # Internal helpers
    ####################

    @staticmethod
    def _get_webhook_user_info(message: Message, db: Optional[Session]) -> Optional[dict]:
        """Return author display info for a webhook-authored message.

        Webhook info stored in ``message.meta['webhook']`` takes precedence over
        ``message.user_id``. Returns ``None`` when the message carries no
        ``meta['webhook']['id']`` (i.e. it was not sent by a webhook).
        """
        webhook_info = message.meta.get('webhook') if message.meta else None
        if not (webhook_info and webhook_info.get('id')):
            return None

        webhook_id = webhook_info.get('id')
        # Look the webhook up so its *current* name is shown.
        webhook = Channels.get_webhook_by_id(webhook_id, db=db)
        if webhook:
            return {'id': webhook.id, 'name': webhook.name, 'role': 'webhook'}

        # The webhook was deleted: keep its id but show a placeholder name.
        return {'id': webhook_id, 'name': 'Deleted Webhook', 'role': 'webhook'}

    def _get_user_info(self, message: Message, db: Optional[Session]) -> Optional[dict]:
        """Return author info: the webhook if webhook-authored, else the user row.

        Returns ``None`` when the authoring user no longer exists.
        """
        webhook_user_info = self._get_webhook_user_info(message, db)
        if webhook_user_info is not None:
            return webhook_user_info
        user = Users.get_user_by_id(message.user_id, db=db)
        return user.model_dump() if user else None

    def _get_reply_to_message(self, reply_to_id: Optional[str], db: Optional[Session]) -> Optional[dict]:
        """Load the message being replied to, with its author info, as a dict.

        The result is only ever validated into ``MessageUserSlimResponse``,
        which keeps the message fields and ``user`` and ignores everything
        else. The replied-to message's own reply chain, reactions and thread
        stats are therefore not loaded. Following the chain would cost queries
        proportional to its depth, and could recurse arbitrarily deep, only for
        the result to be discarded.

        Returns ``None`` when ``reply_to_id`` is empty or the target is gone.
        """
        if not reply_to_id:
            return None
        reply_to = db.get(Message, reply_to_id)
        if not reply_to:
            return None
        return {
            **MessageModel.model_validate(reply_to).model_dump(),
            'user': self._get_user_info(reply_to, db),
        }

    def _to_reply_to_response(self, message: Message, db: Optional[Session]) -> MessageReplyToResponse:
        """Build the list-item response used by the thread/channel listings.

        Only webhook-authored messages get ``user`` populated here; messages
        from regular users are returned with ``user=None``. The replied-to
        message, if any, always carries its resolved author.
        """
        return MessageReplyToResponse.model_validate(
            {
                **MessageModel.model_validate(message).model_dump(),
                'user': self._get_webhook_user_info(message, db),
                'reply_to_message': self._get_reply_to_message(message.reply_to_id, db),
            }
        )

    ####################
    # Messages
    ####################

    def insert_new_message(
        self,
        form_data: MessageForm,
        channel_id: str,
        user_id: str,
        db: Optional[Session] = None,
    ) -> Optional[MessageModel]:
        """Create a message in ``channel_id`` authored by ``user_id``.

        The author is first joined to the channel via ``Channels.join_channel``.
        The message starts unpinned, with ``created_at == updated_at``
        (``time.time_ns()``). Returns the stored message.
        """
        with get_db_context(db) as db:
            # Called for its side effect; the returned membership is not needed.
            Channels.join_channel(channel_id, user_id)

            ts = int(time.time_ns())
            message = MessageModel(
                id=str(uuid.uuid4()),
                user_id=user_id,
                channel_id=channel_id,
                reply_to_id=form_data.reply_to_id,
                parent_id=form_data.parent_id,
                is_pinned=False,
                pinned_at=None,
                pinned_by=None,
                content=form_data.content,
                data=form_data.data,
                meta=form_data.meta,
                created_at=ts,
                updated_at=ts,
            )
            result = Message(**message.model_dump())

            db.add(result)
            db.commit()
            db.refresh(result)
            return MessageModel.model_validate(result)

    def get_message_by_id(
        self,
        id: str,
        include_thread_replies: Optional[bool] = True,
        db: Optional[Session] = None,
    ) -> Optional[MessageResponse]:
        """Return the full view of one message, or ``None`` if it does not exist.

        The view includes the resolved author (webhook or user), the replied-to
        message in slim form, and the reactions. When
        ``include_thread_replies`` is true it also includes the thread stats
        (``reply_count`` and ``latest_reply_at``); otherwise those are
        ``0``/``None``.
        """
        with get_db_context(db) as db:
            message = db.get(Message, id)
            if not message:
                return None

            reply_count, latest_reply_at = 0, None
            if include_thread_replies:
                # Aggregate in SQL rather than materialising every reply just
                # to count it and read the newest timestamp.
                reply_count, latest_reply_at = (
                    db.query(func.count(Message.id), func.max(Message.created_at))
                    .filter(Message.parent_id == id)
                    .one()
                )

            return MessageResponse.model_validate(
                {
                    **MessageModel.model_validate(message).model_dump(),
                    'user': self._get_user_info(message, db),
                    'reply_to_message': self._get_reply_to_message(message.reply_to_id, db),
                    'latest_reply_at': latest_reply_at,
                    'reply_count': reply_count,
                    'reactions': self.get_reactions_by_message_id(id, db=db),
                }
            )

    def get_thread_replies_by_message_id(self, id: str, db: Optional[Session] = None) -> list[MessageReplyToResponse]:
        """Return all replies whose ``parent_id`` is ``id``, newest first."""
        with get_db_context(db) as db:
            replies = db.query(Message).filter_by(parent_id=id).order_by(Message.created_at.desc()).all()
            return [self._to_reply_to_response(reply, db) for reply in replies]

    def get_reply_user_ids_by_message_id(self, id: str, db: Optional[Session] = None) -> list[str]:
        """Return the author id of every thread reply to ``id`` (duplicates kept)."""
        with get_db_context(db) as db:
            rows = db.query(Message.user_id).filter(Message.parent_id == id).all()
            return [user_id for (user_id,) in rows]

    def get_messages_by_channel_id(
        self,
        channel_id: str,
        skip: int = 0,
        limit: int = 50,
        db: Optional[Session] = None,
    ) -> list[MessageReplyToResponse]:
        """Return one page of top-level messages in a channel, newest first."""
        with get_db_context(db) as db:
            page = (
                db.query(Message)
                .filter_by(channel_id=channel_id, parent_id=None)
                .order_by(Message.created_at.desc())
                .offset(skip)
                .limit(limit)
                .all()
            )
            return [self._to_reply_to_response(message, db) for message in page]

    def get_messages_by_parent_id(
        self,
        channel_id: str,
        parent_id: str,
        skip: int = 0,
        limit: int = 50,
        db: Optional[Session] = None,
    ) -> list[MessageReplyToResponse]:
        """Return one page of a thread, newest reply first.

        When the page is short (fewer than ``limit`` replies), the thread's end
        has been reached and the parent message itself is appended after the
        replies. Returns ``[]`` when the parent does not exist or belongs to a
        different channel.
        """
        with get_db_context(db) as db:
            parent = db.get(Message, parent_id)
            # The parent must live in the requested channel. Otherwise a caller
            # (presumably authorised only for ``channel_id``) could read a
            # message from another channel by passing its id as ``parent_id``.
            if not parent or parent.channel_id != channel_id:
                return []

            page = (
                db.query(Message)
                .filter_by(channel_id=channel_id, parent_id=parent_id)
                .order_by(Message.created_at.desc())
                .offset(skip)
                .limit(limit)
                .all()
            )

            # A short page means there are no more replies: include the root.
            if len(page) < limit:
                page.append(parent)

            return [self._to_reply_to_response(message, db) for message in page]

    def get_last_message_by_channel_id(self, channel_id: str, db: Optional[Session] = None) -> Optional[MessageModel]:
        """Return the newest message in the channel (thread replies included), or ``None``."""
        with get_db_context(db) as db:
            message = db.query(Message).filter_by(channel_id=channel_id).order_by(Message.created_at.desc()).first()
            return MessageModel.model_validate(message) if message else None

    def get_pinned_messages_by_channel_id(
        self,
        channel_id: str,
        skip: int = 0,
        limit: int = 50,
        db: Optional[Session] = None,
    ) -> list[MessageModel]:
        """Return one page of pinned messages in a channel, most recently pinned first."""
        with get_db_context(db) as db:
            pinned = (
                db.query(Message)
                .filter_by(channel_id=channel_id, is_pinned=True)
                .order_by(Message.pinned_at.desc())
                .offset(skip)
                .limit(limit)
                .all()
            )
            return [MessageModel.model_validate(message) for message in pinned]

    def update_message_by_id(
        self, id: str, form_data: MessageForm, db: Optional[Session] = None
    ) -> Optional[MessageModel]:
        """Replace a message's content and shallow-merge its ``data``/``meta``.

        Keys from ``form_data`` override existing keys; other existing keys are
        kept. Returns the updated message, or ``None`` if ``id`` is unknown.
        """
        with get_db_context(db) as db:
            message = db.get(Message, id)
            if not message:
                return None

            message.content = form_data.content
            # Assign fresh dicts (rather than mutating in place) so the JSON
            # columns are seen as changed.
            message.data = {**(message.data or {}), **(form_data.data or {})}
            message.meta = {**(message.meta or {}), **(form_data.meta or {})}
            message.updated_at = int(time.time_ns())

            db.commit()
            db.refresh(message)
            return MessageModel.model_validate(message)

    def update_is_pinned_by_id(
        self,
        id: str,
        is_pinned: bool,
        pinned_by: Optional[str] = None,
        db: Optional[Session] = None,
    ) -> Optional[MessageModel]:
        """Pin or unpin a message.

        Pinning stamps ``pinned_at`` (``time.time_ns()``) and ``pinned_by``.
        Unpinning clears both. Returns the updated message, or ``None`` if
        ``id`` is unknown.
        """
        with get_db_context(db) as db:
            message = db.get(Message, id)
            if not message:
                return None

            message.is_pinned = is_pinned
            message.pinned_at = int(time.time_ns()) if is_pinned else None
            message.pinned_by = pinned_by if is_pinned else None

            db.commit()
            db.refresh(message)
            return MessageModel.model_validate(message)

    def get_unread_message_count(
        self,
        channel_id: str,
        user_id: str,
        last_read_at: Optional[int] = None,
        db: Optional[Session] = None,
    ) -> int:
        """Count top-level channel messages created after ``last_read_at``.

        A falsy ``last_read_at`` counts from 0. When ``user_id`` is given, that
        user's own messages are excluded.
        """
        with get_db_context(db) as db:
            query = db.query(Message).filter(
                Message.channel_id == channel_id,
                Message.parent_id.is_(None),  # only count top-level messages
                Message.created_at > (last_read_at or 0),
            )
            if user_id:
                query = query.filter(Message.user_id != user_id)
            return query.count()

    ####################
    # Reactions
    ####################

    def add_reaction_to_message(
        self, id: str, user_id: str, name: str, db: Optional[Session] = None
    ) -> Optional[MessageReactionModel]:
        """Add reaction ``name`` by ``user_id`` to message ``id``.

        If the same (message, user, name) reaction already exists, it is
        returned unchanged instead of inserting a duplicate.
        """
        with get_db_context(db) as db:
            existing_reaction = db.query(MessageReaction).filter_by(message_id=id, user_id=user_id, name=name).first()
            if existing_reaction:
                return MessageReactionModel.model_validate(existing_reaction)

            reaction = MessageReactionModel(
                id=str(uuid.uuid4()),
                user_id=user_id,
                message_id=id,
                name=name,
                created_at=int(time.time_ns()),
            )
            result = MessageReaction(**reaction.model_dump())
            db.add(result)
            db.commit()
            db.refresh(result)
            return MessageReactionModel.model_validate(result)

    def get_reactions_by_message_id(self, id: str, db: Optional[Session] = None) -> list[Reactions]:
        """Return the reactions on message ``id``, grouped by reaction name.

        Groups appear in the order their first row is returned by the query.
        """
        with get_db_context(db) as db:
            # Join User so reactor names come back in the same query. Note the
            # inner join drops reactions whose user row no longer exists.
            rows = (
                db.query(MessageReaction, User)
                .join(User, MessageReaction.user_id == User.id)
                .filter(MessageReaction.message_id == id)
                .all()
            )

            grouped: dict[str, dict] = {}
            for reaction, user in rows:
                entry = grouped.setdefault(reaction.name, {'name': reaction.name, 'users': [], 'count': 0})
                entry['users'].append({'id': user.id, 'name': user.name})
                entry['count'] += 1

            return [Reactions(**entry) for entry in grouped.values()]

    def remove_reaction_by_id_and_user_id_and_name(
        self, id: str, user_id: str, name: str, db: Optional[Session] = None
    ) -> bool:
        """Delete ``user_id``'s reaction ``name`` on message ``id``; always returns True."""
        with get_db_context(db) as db:
            db.query(MessageReaction).filter_by(message_id=id, user_id=user_id, name=name).delete()
            db.commit()
            return True

    def delete_reactions_by_id(self, id: str, db: Optional[Session] = None) -> bool:
        """Delete every reaction on message ``id``; always returns True."""
        with get_db_context(db) as db:
            db.query(MessageReaction).filter_by(message_id=id).delete()
            db.commit()
            return True

    ####################
    # Deletion / search
    ####################

    def delete_replies_by_id(self, id: str, db: Optional[Session] = None) -> bool:
        """Delete every thread reply whose ``parent_id`` is ``id``; always returns True.

        Reactions on those replies are not touched here.
        """
        with get_db_context(db) as db:
            db.query(Message).filter_by(parent_id=id).delete()
            db.commit()
            return True

    def delete_message_by_id(self, id: str, db: Optional[Session] = None) -> bool:
        """Delete message ``id`` and all of its reactions; always returns True.

        Thread replies are not deleted here (see ``delete_replies_by_id``).
        """
        with get_db_context(db) as db:
            db.query(Message).filter_by(id=id).delete()
            db.query(MessageReaction).filter_by(message_id=id).delete()
            db.commit()
            return True

    def search_messages_by_channel_ids(
        self,
        channel_ids: list[str],
        query: str,
        start_timestamp: Optional[int] = None,
        end_timestamp: Optional[int] = None,
        limit: int = 10,
        db: Optional[Session] = None,
    ) -> list[MessageModel]:
        """Search messages in the given channels by case-insensitive substring.

        ``query`` is matched literally: ``%`` and ``_`` in it are escaped, not
        treated as LIKE wildcards. Falsy ``start_timestamp``/``end_timestamp``
        values apply no bound. Results are newest first, capped at ``limit``.
        """
        if not channel_ids:
            return []

        # Escape LIKE metacharacters. '/' is the escape character, the same
        # portable choice SQLAlchemy's ``autoescape`` makes; a backslash
        # behaves inconsistently across backends.
        escaped = query.replace('/', '//').replace('%', '/%').replace('_', '/_')

        with get_db_context(db) as db:
            query_builder = db.query(Message).filter(
                Message.channel_id.in_(channel_ids),
                Message.content.ilike(f'%{escaped}%', escape='/'),
            )

            if start_timestamp:
                query_builder = query_builder.filter(Message.created_at >= start_timestamp)
            if end_timestamp:
                query_builder = query_builder.filter(Message.created_at <= end_timestamp)

            messages = query_builder.order_by(Message.created_at.desc()).limit(limit).all()
            return [MessageModel.model_validate(msg) for msg in messages]
|
|
|
|
|
|
# Module-level MessageTable instance. MessageTable keeps no per-instance
# state, so a single shared object is sufficient.
Messages = MessageTable()
|