mirror of
https://github.com/open-webui/open-webui.git
synced 2026-04-26 09:35:29 +02:00
enh/refac: notes
This commit is contained in:
@@ -7,12 +7,15 @@ from functools import lru_cache
|
||||
from open_webui.internal.db import Base, get_db
|
||||
from open_webui.models.groups import Groups
|
||||
from open_webui.utils.access_control import has_access
|
||||
from open_webui.models.users import Users, UserResponse
|
||||
from open_webui.models.users import User, UserModel, Users, UserResponse
|
||||
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from sqlalchemy import BigInteger, Boolean, Column, String, Text, JSON
|
||||
from sqlalchemy import or_, func, select, and_, text
|
||||
from sqlalchemy.dialects.postgresql import JSONB
|
||||
|
||||
|
||||
from sqlalchemy import or_, func, select, and_, text, cast, or_, and_, func
|
||||
from sqlalchemy.sql import exists
|
||||
|
||||
####################
|
||||
@@ -75,7 +78,63 @@ class NoteUserResponse(NoteModel):
|
||||
user: Optional[UserResponse] = None
|
||||
|
||||
|
||||
class NoteItemResponse(BaseModel):
|
||||
id: str
|
||||
title: str
|
||||
data: Optional[dict]
|
||||
updated_at: int
|
||||
created_at: int
|
||||
user: Optional[UserResponse] = None
|
||||
|
||||
|
||||
class NoteListResponse(BaseModel):
|
||||
items: list[NoteUserResponse]
|
||||
total: int
|
||||
|
||||
|
||||
class NoteTable:
|
||||
def _has_permission(self, db, query, filter: dict, permission: str = "read"):
|
||||
group_ids = filter.get("group_ids", [])
|
||||
user_id = filter.get("user_id")
|
||||
|
||||
dialect_name = db.bind.dialect.name
|
||||
|
||||
# Public access
|
||||
conditions = []
|
||||
if group_ids or user_id:
|
||||
conditions.extend(
|
||||
[
|
||||
Note.access_control.is_(None),
|
||||
cast(Note.access_control, String) == "null",
|
||||
]
|
||||
)
|
||||
|
||||
# User-level permission
|
||||
if user_id:
|
||||
conditions.append(Note.user_id == user_id)
|
||||
|
||||
# Group-level permission
|
||||
if group_ids:
|
||||
group_conditions = []
|
||||
for gid in group_ids:
|
||||
if dialect_name == "sqlite":
|
||||
group_conditions.append(
|
||||
Note.access_control[permission]["group_ids"].contains([gid])
|
||||
)
|
||||
elif dialect_name == "postgresql":
|
||||
group_conditions.append(
|
||||
cast(
|
||||
Note.access_control[permission]["group_ids"],
|
||||
JSONB,
|
||||
).contains([gid])
|
||||
)
|
||||
conditions.append(or_(*group_conditions))
|
||||
|
||||
if conditions:
|
||||
query = query.filter(or_(*conditions))
|
||||
|
||||
return query
|
||||
|
||||
def insert_new_note(
|
||||
self,
|
||||
form_data: NoteForm,
|
||||
@@ -110,15 +169,103 @@ class NoteTable:
|
||||
notes = query.all()
|
||||
return [NoteModel.model_validate(note) for note in notes]
|
||||
|
||||
def search_notes(
|
||||
self, user_id: str, filter: dict = {}, skip: int = 0, limit: int = 30
|
||||
) -> NoteListResponse:
|
||||
with get_db() as db:
|
||||
query = db.query(Note, User).outerjoin(User, User.id == Note.user_id)
|
||||
if filter:
|
||||
query_key = filter.get("query")
|
||||
if query_key:
|
||||
query = query.filter(
|
||||
or_(
|
||||
Note.title.ilike(f"%{query_key}%"),
|
||||
Note.data["content"]["md"].ilike(f"%{query_key}%"),
|
||||
)
|
||||
)
|
||||
|
||||
view_option = filter.get("view_option")
|
||||
if view_option == "created":
|
||||
query = query.filter(Note.user_id == user_id)
|
||||
elif view_option == "shared":
|
||||
query = query.filter(Note.user_id != user_id)
|
||||
|
||||
# Apply access control filtering
|
||||
query = self._has_permission(
|
||||
db,
|
||||
query,
|
||||
filter,
|
||||
permission="write",
|
||||
)
|
||||
|
||||
order_by = filter.get("order_by")
|
||||
direction = filter.get("direction")
|
||||
|
||||
if order_by == "name":
|
||||
if direction == "asc":
|
||||
query = query.order_by(Note.title.asc())
|
||||
else:
|
||||
query = query.order_by(Note.title.desc())
|
||||
elif order_by == "created_at":
|
||||
if direction == "asc":
|
||||
query = query.order_by(Note.created_at.asc())
|
||||
else:
|
||||
query = query.order_by(Note.created_at.desc())
|
||||
elif order_by == "updated_at":
|
||||
if direction == "asc":
|
||||
query = query.order_by(Note.updated_at.asc())
|
||||
else:
|
||||
query = query.order_by(Note.updated_at.desc())
|
||||
else:
|
||||
query = query.order_by(Note.updated_at.desc())
|
||||
|
||||
else:
|
||||
query = query.order_by(Note.updated_at.desc())
|
||||
|
||||
for key, value in filter.items():
|
||||
query = query.filter(getattr(Note, key).ilike(f"%{value}%"))
|
||||
|
||||
# Count BEFORE pagination
|
||||
total = query.count()
|
||||
|
||||
if skip:
|
||||
query = query.offset(skip)
|
||||
if limit:
|
||||
query = query.limit(limit)
|
||||
|
||||
items = query.all()
|
||||
|
||||
notes = []
|
||||
for note, user in items:
|
||||
notes.append(
|
||||
NoteUserResponse(
|
||||
**NoteModel.model_validate(note).model_dump(),
|
||||
user=(
|
||||
UserResponse(**UserModel.model_validate(user).model_dump())
|
||||
if user
|
||||
else None
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
return NoteListResponse(items=notes, total=total)
|
||||
|
||||
def get_notes_by_user_id(
|
||||
self,
|
||||
user_id: str,
|
||||
permission: str = "read",
|
||||
skip: Optional[int] = None,
|
||||
limit: Optional[int] = None,
|
||||
) -> list[NoteModel]:
|
||||
with get_db() as db:
|
||||
query = db.query(Note).filter(Note.user_id == user_id)
|
||||
query = query.order_by(Note.updated_at.desc())
|
||||
user_group_ids = [
|
||||
group.id for group in Groups.get_groups_by_member_id(user_id)
|
||||
]
|
||||
|
||||
query = db.query(Note).order_by(Note.updated_at.desc())
|
||||
query = self._has_permission(
|
||||
db, query, {"user_id": user_id, "group_ids": user_group_ids}, permission
|
||||
)
|
||||
|
||||
if skip is not None:
|
||||
query = query.offset(skip)
|
||||
@@ -128,56 +275,6 @@ class NoteTable:
|
||||
notes = query.all()
|
||||
return [NoteModel.model_validate(note) for note in notes]
|
||||
|
||||
def get_notes_by_permission(
|
||||
self,
|
||||
user_id: str,
|
||||
permission: str = "write",
|
||||
skip: Optional[int] = None,
|
||||
limit: Optional[int] = None,
|
||||
) -> list[NoteModel]:
|
||||
with get_db() as db:
|
||||
user_groups = Groups.get_groups_by_member_id(user_id)
|
||||
user_group_ids = {group.id for group in user_groups}
|
||||
|
||||
# Order newest-first. We stream to keep memory usage low.
|
||||
query = (
|
||||
db.query(Note)
|
||||
.order_by(Note.updated_at.desc())
|
||||
.execution_options(stream_results=True)
|
||||
.yield_per(256)
|
||||
)
|
||||
|
||||
results: list[NoteModel] = []
|
||||
n_skipped = 0
|
||||
|
||||
for note in query:
|
||||
# Fast-pass #1: owner
|
||||
if note.user_id == user_id:
|
||||
permitted = True
|
||||
# Fast-pass #2: public/open
|
||||
elif note.access_control is None:
|
||||
# Technically this should mean public access for both read and write, but we'll only do read for now
|
||||
# We might want to change this behavior later
|
||||
permitted = permission == "read"
|
||||
else:
|
||||
permitted = has_access(
|
||||
user_id, permission, note.access_control, user_group_ids
|
||||
)
|
||||
|
||||
if not permitted:
|
||||
continue
|
||||
|
||||
# Apply skip AFTER permission filtering so it counts only accessible notes
|
||||
if skip and n_skipped < skip:
|
||||
n_skipped += 1
|
||||
continue
|
||||
|
||||
results.append(NoteModel.model_validate(note))
|
||||
if limit is not None and len(results) >= limit:
|
||||
break
|
||||
|
||||
return results
|
||||
|
||||
def get_note_by_id(self, id: str) -> Optional[NoteModel]:
|
||||
with get_db() as db:
|
||||
note = db.query(Note).filter(Note.id == id).first()
|
||||
|
||||
Reference in New Issue
Block a user