# Files
# open-webui/backend/open_webui/models/messages.py
# Timothy Jaeryang Baek de3317e26b refac
# 2026-03-17 17:58:01 -05:00
#
# 577 lines
# 20 KiB
# Python
import json
import time
import uuid
from typing import Optional
from sqlalchemy.orm import Session
from open_webui.internal.db import Base, JSONField, get_db, get_db_context
from open_webui.models.tags import TagModel, Tag, Tags
from open_webui.models.users import Users, User, UserNameResponse
from open_webui.models.channels import Channels, ChannelMember
from pydantic import BaseModel, ConfigDict, field_validator
from sqlalchemy import BigInteger, Boolean, Column, String, Text, JSON
from sqlalchemy import or_, func, select, and_, text
from sqlalchemy.sql import exists
####################
# Message DB Schema
####################
class MessageReaction(Base):
    """ORM row recording one reaction ``name`` left by ``user_id`` on ``message_id``.

    MessageTable.add_reaction_to_message skips the insert when an identical
    (message_id, user_id, name) row already exists. No database-level
    uniqueness constraint on that triple is declared here.
    """

    __tablename__ = 'message_reaction'

    id = Column(Text, primary_key=True, unique=True)  # UUID4 string generated in add_reaction_to_message
    user_id = Column(Text)
    # Id of the reacted-to Message; no ForeignKey is declared, so cleanup is manual
    # (delete_message_by_id / delete_reactions_by_id).
    message_id = Column(Text)
    name = Column(Text)  # reaction identifier; presumably an emoji shortcode — confirm with the frontend
    created_at = Column(BigInteger)  # epoch nanoseconds (time.time_ns() in add_reaction_to_message)
class MessageReactionModel(BaseModel):
    """Pydantic view of a MessageReaction row.

    ``from_attributes=True`` lets ``model_validate`` accept ORM objects directly.
    """

    model_config = ConfigDict(from_attributes=True)

    id: str
    user_id: str
    message_id: str
    name: str
    created_at: int  # timestamp in epoch (nanoseconds; written with time.time_ns())
class Message(Base):
    """ORM row for a channel message.

    ``parent_id`` links a thread reply to the message that started the thread.
    Top-level channel messages have ``parent_id`` None. ``reply_to_id`` points
    at the message being replied to. Pin state is written by
    MessageTable.update_is_pinned_by_id.
    """

    __tablename__ = 'message'

    id = Column(Text, primary_key=True, unique=True)
    user_id = Column(Text)
    channel_id = Column(Text, nullable=True)
    reply_to_id = Column(Text, nullable=True)
    parent_id = Column(Text, nullable=True)

    # Pins
    is_pinned = Column(Boolean, nullable=False, default=False)
    pinned_at = Column(BigInteger, nullable=True)  # time_ns when pinned; None after unpinning
    pinned_by = Column(Text, nullable=True)  # presumably the pinning user's id; None after unpinning

    content = Column(Text)
    data = Column(JSON, nullable=True)
    # Free-form JSON. A 'webhook' entry carrying an 'id' marks the message as
    # webhook-posted; MessageTable then reports the webhook as the author.
    meta = Column(JSON, nullable=True)
    created_at = Column(BigInteger)  # time_ns
    updated_at = Column(BigInteger)  # time_ns
class MessageModel(BaseModel):
    """Pydantic view of a Message row.

    ``from_attributes=True`` lets ``model_validate`` accept ORM objects directly.
    """

    model_config = ConfigDict(from_attributes=True)

    id: str
    user_id: str
    channel_id: Optional[str] = None
    reply_to_id: Optional[str] = None  # message being replied to
    parent_id: Optional[str] = None  # thread starter; None for top-level messages

    # Pins
    is_pinned: bool = False
    pinned_by: Optional[str] = None
    pinned_at: Optional[int] = None  # timestamp in epoch (time_ns)

    content: str
    data: Optional[dict] = None
    meta: Optional[dict] = None
    created_at: int  # timestamp in epoch (time_ns)
    updated_at: int  # timestamp in epoch (time_ns)
####################
# Forms
####################
class MessageForm(BaseModel):
    """Client payload for creating or editing a message.

    MessageTable.insert_new_message copies content, reply_to_id, parent_id,
    data and meta into the new row. MessageTable.update_message_by_id replaces
    ``content`` and shallow-merges ``data``/``meta`` into the stored values.
    """

    # Not read anywhere in this module; presumably a client-side correlation id — confirm with callers.
    temp_id: Optional[str] = None
    content: str
    reply_to_id: Optional[str] = None  # id of the message being replied to
    parent_id: Optional[str] = None  # thread starter id; None for a top-level channel message
    data: Optional[dict] = None
    meta: Optional[dict] = None
class Reactions(BaseModel):
    """One reaction ``name`` on a message, aggregated across the users who left it.

    Built by MessageTable.get_reactions_by_message_id.
    """

    name: str
    users: list[dict]  # one {'id': ..., 'name': ...} dict per reacting user
    count: int  # number of entries in ``users``
class MessageUserResponse(MessageModel):
    """A message plus its author's public name info.

    For webhook-posted messages, MessageTable fills ``user`` with the webhook's
    id/name and role 'webhook'. ``user`` may also be None.
    """

    user: Optional[UserNameResponse] = None
class MessageUserSlimResponse(MessageUserResponse):
    """Author-annotated message whose ``data`` payload is reduced to a flag.

    A supplied ``data`` becomes True only when it is a dict holding at least
    one truthy value. Any other supplied value, including None or a non-dict,
    becomes False. An omitted field keeps its default of None.
    """

    data: bool | None = None

    @field_validator('data', mode='before')
    @classmethod
    def convert_data_to_bool(cls, v):
        # isinstance() short-circuits to False for non-dicts; for dicts, any
        # truthy value (non-empty list, string, etc.) flips the flag on.
        return isinstance(v, dict) and any(v.values())
class MessageReplyToResponse(MessageUserResponse):
    """Message with author info and, when it is a reply, a slim copy of the replied-to message."""

    # Slim form: the replied-to message's ``data`` is collapsed to a bool.
    reply_to_message: Optional[MessageUserSlimResponse] = None
class MessageWithReactionsResponse(MessageUserSlimResponse):
    """Slim message (``data`` collapsed to a bool) plus its aggregated reactions.

    NOTE(review): no use of this model appears in this module; confirm callers elsewhere.
    """

    reactions: list[Reactions]
class MessageResponse(MessageReplyToResponse):
    """Full single-message view returned by MessageTable.get_message_by_id."""

    # created_at (time_ns) of the newest thread reply, or None when there are none.
    # Declared without a default, so the key must be supplied (its value may be None).
    latest_reply_at: Optional[int]
    reply_count: int  # number of messages whose parent_id is this message
    reactions: list[Reactions]
class MessageTable:
    """Data-access layer for channel messages and their reactions.

    Every public method takes an optional SQLAlchemy ``Session`` and runs
    inside ``get_db_context(db)``. Presumably that reuses ``db`` when given and
    opens a session otherwise — confirm in open_webui.internal.db.
    """

    # Escape character used for LIKE/ILIKE patterns. '/' is used instead of a
    # backslash because backslash escaping differs between SQL dialects.
    _LIKE_ESCAPE = '/'

    ####################
    # Internal helpers
    ####################

    @classmethod
    def _escape_like(cls, value: str) -> str:
        """Escape LIKE wildcards (``%``, ``_``) and the escape char so ``value`` matches literally."""
        esc = cls._LIKE_ESCAPE
        # The escape character must be doubled first, or later escapes would be re-escaped.
        return value.replace(esc, esc * 2).replace('%', esc + '%').replace('_', esc + '_')

    @staticmethod
    def _get_webhook_user_info(message: Message, db: Session) -> Optional[dict]:
        """Return author info for a webhook-posted message, or None otherwise.

        A message counts as webhook-posted when ``meta['webhook']`` carries an
        ``id``. The webhook's current name is looked up by that id. If the
        webhook no longer exists, a 'Deleted Webhook' placeholder is returned.
        """
        webhook_info = message.meta.get('webhook') if message.meta else None
        if not webhook_info or not webhook_info.get('id'):
            return None

        webhook = Channels.get_webhook_by_id(webhook_info.get('id'), db=db)
        if webhook:
            return {
                'id': webhook.id,
                'name': webhook.name,
                'role': 'webhook',
            }
        # Webhook was deleted, use placeholder
        return {
            'id': webhook_info.get('id'),
            'name': 'Deleted Webhook',
            'role': 'webhook',
        }

    def _get_user_info(self, message: Message, db: Session) -> Optional[dict]:
        """Return the message author: webhook info if present, else the user's record as a dict.

        Returns None when the author is neither a webhook nor an existing user.
        """
        webhook_user = self._get_webhook_user_info(message, db)
        if webhook_user is not None:
            return webhook_user
        user = Users.get_user_by_id(message.user_id, db=db)
        return user.model_dump() if user else None

    def _get_reply_to_message(self, reply_to_id: Optional[str], db: Session) -> Optional[dict]:
        """Load the replied-to message in the shape MessageUserSlimResponse needs.

        Only the MessageModel columns plus ``user`` are built. That response
        type has no reactions, thread stats or nested reply_to, so calling
        get_message_by_id here would load data that is then thrown away, and
        would walk the whole reply_to chain recursively.
        Returns None when ``reply_to_id`` is empty or the message is gone.
        """
        if not reply_to_id:
            return None
        reply_to = db.get(Message, reply_to_id)
        if not reply_to:
            return None
        return {
            **MessageModel.model_validate(reply_to).model_dump(),
            'user': self._get_user_info(reply_to, db),
        }

    def _to_reply_to_response(self, message: Message, db: Session) -> MessageReplyToResponse:
        """Convert a Message row into a MessageReplyToResponse for list endpoints.

        ``user`` is filled only for webhook-posted messages. For ordinary
        messages it stays None; presumably callers resolve authors in bulk —
        confirm before changing.
        """
        return MessageReplyToResponse.model_validate(
            {
                **MessageModel.model_validate(message).model_dump(),
                'user': self._get_webhook_user_info(message, db),
                'reply_to_message': self._get_reply_to_message(message.reply_to_id, db),
            }
        )

    ####################
    # Create / read
    ####################

    def insert_new_message(
        self,
        form_data: MessageForm,
        channel_id: str,
        user_id: str,
        db: Optional[Session] = None,
    ) -> Optional[MessageModel]:
        """Create a message in ``channel_id`` authored by ``user_id`` and return it.

        The id is a random UUID4. ``created_at`` and ``updated_at`` share one
        ``time.time_ns()`` value. Pin fields start cleared.
        """
        with get_db_context(db) as db:
            # Called for its side effect only (return value unused); presumably
            # makes sure the author is a channel member — confirm in Channels.
            Channels.join_channel(channel_id, user_id)

            ts = int(time.time_ns())
            message = MessageModel(
                id=str(uuid.uuid4()),
                user_id=user_id,
                channel_id=channel_id,
                reply_to_id=form_data.reply_to_id,
                parent_id=form_data.parent_id,
                is_pinned=False,
                pinned_at=None,
                pinned_by=None,
                content=form_data.content,
                data=form_data.data,
                meta=form_data.meta,
                created_at=ts,
                updated_at=ts,
            )
            result = Message(**message.model_dump())
            db.add(result)
            db.commit()
            db.refresh(result)
            return MessageModel.model_validate(result)

    def get_message_by_id(
        self,
        id: str,
        include_thread_replies: Optional[bool] = True,
        db: Optional[Session] = None,
    ) -> Optional[MessageResponse]:
        """Return the message ``id`` with author, reply target and reactions.

        When ``include_thread_replies`` is true, ``reply_count`` and
        ``latest_reply_at`` describe the messages whose parent_id is ``id``.
        Otherwise they are 0 and None. Returns None if the message does not exist.
        """
        with get_db_context(db) as db:
            message = db.get(Message, id)
            if not message:
                return None

            reply_count, latest_reply_at = 0, None
            if include_thread_replies:
                # Aggregate in SQL: only the count and newest timestamp are
                # needed, not the fully resolved replies.
                reply_count, latest_reply_at = (
                    db.query(func.count(Message.id), func.max(Message.created_at))
                    .filter(Message.parent_id == id)
                    .one()
                )

            return MessageResponse.model_validate(
                {
                    **MessageModel.model_validate(message).model_dump(),
                    # Webhook info in meta takes precedence over user_id.
                    'user': self._get_user_info(message, db),
                    'reply_to_message': self._get_reply_to_message(message.reply_to_id, db),
                    'latest_reply_at': latest_reply_at,
                    'reply_count': reply_count,
                    'reactions': self.get_reactions_by_message_id(id, db=db),
                }
            )

    def get_thread_replies_by_message_id(self, id: str, db: Optional[Session] = None) -> list[MessageReplyToResponse]:
        """Return every message whose parent_id is ``id``, newest first."""
        with get_db_context(db) as db:
            replies = db.query(Message).filter_by(parent_id=id).order_by(Message.created_at.desc()).all()
            return [self._to_reply_to_response(reply, db) for reply in replies]

    def get_reply_user_ids_by_message_id(self, id: str, db: Optional[Session] = None) -> list[str]:
        """Return the author user_id of each reply in thread ``id``, one entry per reply (duplicates kept)."""
        with get_db_context(db) as db:
            rows = db.query(Message.user_id).filter(Message.parent_id == id).all()
            return [user_id for (user_id,) in rows]

    def get_messages_by_channel_id(
        self,
        channel_id: str,
        skip: int = 0,
        limit: int = 50,
        db: Optional[Session] = None,
    ) -> list[MessageReplyToResponse]:
        """Return one page of top-level (non-thread) messages in a channel, newest first."""
        with get_db_context(db) as db:
            rows = (
                db.query(Message)
                .filter_by(channel_id=channel_id, parent_id=None)
                .order_by(Message.created_at.desc())
                .offset(skip)
                .limit(limit)
                .all()
            )
            return [self._to_reply_to_response(row, db) for row in rows]

    def get_messages_by_parent_id(
        self,
        channel_id: str,
        parent_id: str,
        skip: int = 0,
        limit: int = 50,
        db: Optional[Session] = None,
    ) -> list[MessageReplyToResponse]:
        """Return one page of thread ``parent_id`` in ``channel_id``, newest first.

        On the last page (fewer than ``limit`` replies), the thread's starting
        message is appended at the end. Returns [] when the parent does not
        exist or belongs to a different channel.
        """
        with get_db_context(db) as db:
            parent = db.get(Message, parent_id)
            # Replies are filtered by channel_id below. Apply the same check to
            # the parent so a message from another channel is never returned here.
            if not parent or parent.channel_id != channel_id:
                return []

            rows = (
                db.query(Message)
                .filter_by(channel_id=channel_id, parent_id=parent_id)
                .order_by(Message.created_at.desc())
                .offset(skip)
                .limit(limit)
                .all()
            )
            # A short page means this is the final page of replies, so the
            # (oldest) thread-starting message belongs at its end.
            if len(rows) < limit:
                rows.append(parent)
            return [self._to_reply_to_response(row, db) for row in rows]

    def get_last_message_by_channel_id(self, channel_id: str, db: Optional[Session] = None) -> Optional[MessageModel]:
        """Return the most recently created message in the channel (threads included), or None."""
        with get_db_context(db) as db:
            message = db.query(Message).filter_by(channel_id=channel_id).order_by(Message.created_at.desc()).first()
            return MessageModel.model_validate(message) if message else None

    def get_pinned_messages_by_channel_id(
        self,
        channel_id: str,
        skip: int = 0,
        limit: int = 50,
        db: Optional[Session] = None,
    ) -> list[MessageModel]:
        """Return one page of pinned messages in a channel, most recently pinned first."""
        with get_db_context(db) as db:
            rows = (
                db.query(Message)
                .filter_by(channel_id=channel_id, is_pinned=True)
                .order_by(Message.pinned_at.desc())
                .offset(skip)
                .limit(limit)
                .all()
            )
            return [MessageModel.model_validate(row) for row in rows]

    ####################
    # Update
    ####################

    def update_message_by_id(
        self, id: str, form_data: MessageForm, db: Optional[Session] = None
    ) -> Optional[MessageModel]:
        """Edit message ``id`` and return it, or None if it does not exist.

        ``content`` is replaced. ``data`` and ``meta`` are shallow-merged: keys
        in ``form_data`` override stored ones, and other stored keys are kept.
        """
        with get_db_context(db) as db:
            message = db.get(Message, id)
            if not message:
                return None

            message.content = form_data.content
            # Assign fresh dicts so SQLAlchemy detects the JSON column change.
            message.data = {**(message.data or {}), **(form_data.data or {})}
            message.meta = {**(message.meta or {}), **(form_data.meta or {})}
            message.updated_at = int(time.time_ns())
            db.commit()
            db.refresh(message)
            return MessageModel.model_validate(message)

    def update_is_pinned_by_id(
        self,
        id: str,
        is_pinned: bool,
        pinned_by: Optional[str] = None,
        db: Optional[Session] = None,
    ) -> Optional[MessageModel]:
        """Pin or unpin message ``id`` and return it, or None if it does not exist.

        Pinning stamps ``pinned_at`` with time_ns and records ``pinned_by``.
        Unpinning clears both.
        """
        with get_db_context(db) as db:
            message = db.get(Message, id)
            if not message:
                return None

            message.is_pinned = is_pinned
            message.pinned_at = int(time.time_ns()) if is_pinned else None
            message.pinned_by = pinned_by if is_pinned else None
            db.commit()
            db.refresh(message)
            return MessageModel.model_validate(message)

    def get_unread_message_count(
        self,
        channel_id: str,
        user_id: str,
        last_read_at: Optional[int] = None,
        db: Optional[Session] = None,
    ) -> int:
        """Count top-level messages in the channel created after ``last_read_at``.

        A falsy ``last_read_at`` counts all of them. Messages authored by
        ``user_id`` are excluded when ``user_id`` is truthy.
        """
        with get_db_context(db) as db:
            query = db.query(Message).filter(
                Message.channel_id == channel_id,
                Message.parent_id.is_(None),  # only count top-level messages
                Message.created_at > (last_read_at or 0),
            )
            if user_id:
                query = query.filter(Message.user_id != user_id)
            return query.count()

    ####################
    # Reactions
    ####################

    def add_reaction_to_message(
        self, id: str, user_id: str, name: str, db: Optional[Session] = None
    ) -> Optional[MessageReactionModel]:
        """Record reaction ``name`` by ``user_id`` on message ``id``.

        If the same reaction already exists, it is returned unchanged instead
        of inserting a duplicate.
        """
        with get_db_context(db) as db:
            existing_reaction = db.query(MessageReaction).filter_by(message_id=id, user_id=user_id, name=name).first()
            if existing_reaction:
                return MessageReactionModel.model_validate(existing_reaction)

            reaction = MessageReactionModel(
                id=str(uuid.uuid4()),
                user_id=user_id,
                message_id=id,
                name=name,
                created_at=int(time.time_ns()),
            )
            result = MessageReaction(**reaction.model_dump())
            db.add(result)
            db.commit()
            db.refresh(result)
            return MessageReactionModel.model_validate(result)

    def get_reactions_by_message_id(self, id: str, db: Optional[Session] = None) -> list[Reactions]:
        """Return the reactions on message ``id``, grouped by reaction name.

        The inner join with User means reactions whose user row no longer
        exists are omitted.
        """
        with get_db_context(db) as db:
            # JOIN User so all user info is fetched in one query
            results = (
                db.query(MessageReaction, User)
                .join(User, MessageReaction.user_id == User.id)
                .filter(MessageReaction.message_id == id)
                .all()
            )

            grouped: dict[str, dict] = {}
            for reaction, user in results:
                entry = grouped.setdefault(
                    reaction.name,
                    {'name': reaction.name, 'users': [], 'count': 0},
                )
                entry['users'].append({'id': user.id, 'name': user.name})
                entry['count'] += 1

            return [Reactions(**entry) for entry in grouped.values()]

    def remove_reaction_by_id_and_user_id_and_name(
        self, id: str, user_id: str, name: str, db: Optional[Session] = None
    ) -> bool:
        """Delete reaction ``name`` by ``user_id`` on message ``id``. Returns True even if none existed."""
        with get_db_context(db) as db:
            db.query(MessageReaction).filter_by(message_id=id, user_id=user_id, name=name).delete()
            db.commit()
            return True

    def delete_reactions_by_id(self, id: str, db: Optional[Session] = None) -> bool:
        """Delete every reaction on message ``id``. Always returns True."""
        with get_db_context(db) as db:
            db.query(MessageReaction).filter_by(message_id=id).delete()
            db.commit()
            return True

    ####################
    # Delete / search
    ####################

    def delete_replies_by_id(self, id: str, db: Optional[Session] = None) -> bool:
        """Delete every message whose parent_id is ``id``.

        Their reactions are not touched here. Always returns True.
        """
        with get_db_context(db) as db:
            db.query(Message).filter_by(parent_id=id).delete()
            db.commit()
            return True

    def delete_message_by_id(self, id: str, db: Optional[Session] = None) -> bool:
        """Delete message ``id`` and its reactions in one commit.

        Thread replies are not deleted here. Always returns True.
        """
        with get_db_context(db) as db:
            db.query(Message).filter_by(id=id).delete()
            # Delete all reactions to this message
            db.query(MessageReaction).filter_by(message_id=id).delete()
            db.commit()
            return True

    def search_messages_by_channel_ids(
        self,
        channel_ids: list[str],
        query: str,
        start_timestamp: Optional[int] = None,
        end_timestamp: Optional[int] = None,
        limit: int = 10,
        db: Optional[Session] = None,
    ) -> list[MessageModel]:
        """Search messages in specified channels by content.

        This is a case-insensitive substring match on ``content``. ``%``,
        ``_`` and ``/`` in ``query`` match literally rather than acting as
        wildcards. Truthy ``start_timestamp``/``end_timestamp`` bound
        ``created_at`` inclusively. At most ``limit`` results are returned,
        newest first.
        """
        with get_db_context(db) as db:
            pattern = f'%{self._escape_like(query)}%'
            query_builder = db.query(Message).filter(
                Message.channel_id.in_(channel_ids),
                Message.content.ilike(pattern, escape=self._LIKE_ESCAPE),
            )
            if start_timestamp:
                query_builder = query_builder.filter(Message.created_at >= start_timestamp)
            if end_timestamp:
                query_builder = query_builder.filter(Message.created_at <= end_timestamp)

            messages = query_builder.order_by(Message.created_at.desc()).limit(limit).all()
            return [MessageModel.model_validate(msg) for msg in messages]
# Shared module-level MessageTable instance. MessageTable keeps no
# per-instance state, so one instance can serve every caller.
Messages = MessageTable()