diff --git a/.github/workflows/python-app.yml b/.github/workflows/python-app.yml
index f65d788c..d1be84b0 100644
--- a/.github/workflows/python-app.yml
+++ b/.github/workflows/python-app.yml
@@ -26,7 +26,7 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
- pip install pytest
+ pip install pytest pytest-asyncio hypothesis
if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
- name: Test with pytest
run: |
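
The workflow now installs pytest-asyncio and hypothesis alongside pytest. As a rough illustration of why those plugins are needed (generic examples only, not code from this repository), tests that rely on them look roughly like this:

    import asyncio
    import pytest
    from hypothesis import given, strategies as st

    @pytest.mark.asyncio
    async def test_event_loop_is_available():
        # pytest-asyncio runs coroutine tests on its own event loop.
        assert await asyncio.sleep(0) is None

    @given(st.lists(st.integers()))
    def test_sorting_is_idempotent(xs):
        # hypothesis generates many inputs; the property must hold for all of them.
        assert sorted(sorted(xs)) == sorted(xs)
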
diff --git a/README.md b/README.md
index b6f349c6..06f20ca2 100644
--- a/README.md
+++ b/README.md
@@ -26,6 +26,7 @@ Current status: Working/In-Progress
- Launch with `python3 -m tldw_chatbook.app` (working on pip packaging)
- Inspiration from [Elia Chat](https://github.com/darrenburns/elia)
+#### Chat Features
Chat Features
@@ -57,6 +58,7 @@ Current status: Working/In-Progress
- Support for searching/loading/editing/saving Notes from the Notes Database.
+#### Notes & Media Features
Notes & Media Features
@@ -75,7 +77,7 @@ Current status: Working/In-Progress
-
-
+#### Local LLM Management Features
Local LLM Inference
diff --git a/Tests/DB/__init__.py b/STests/Sync/__init__.py
similarity index 100%
rename from Tests/DB/__init__.py
rename to STests/Sync/__init__.py
diff --git a/Tests/Sync/test_sync_client-1.py b/STests/Sync/test_sync_client-1.py
similarity index 100%
rename from Tests/Sync/test_sync_client-1.py
rename to STests/Sync/test_sync_client-1.py
diff --git a/Tests/Sync/test_sync_client-2.py b/STests/Sync/test_sync_client-2.py
similarity index 99%
rename from Tests/Sync/test_sync_client-2.py
rename to STests/Sync/test_sync_client-2.py
index 96255185..94aed185 100644
--- a/Tests/Sync/test_sync_client-2.py
+++ b/STests/Sync/test_sync_client-2.py
@@ -14,11 +14,11 @@
import requests
-from tldw_cli.tldw_app.DB.Media_DB import Database, DatabaseError
+from tldw_chatbook.DB.Client_Media_DB_v2 import MediaDatabase as Database, DatabaseError
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
-from tldw_cli.tldw_app.DB.Sync_Client import ClientSyncEngine, SYNC_BATCH_SIZE # Your client sync engine
+from tldw_chatbook.DB.Sync_Client import ClientSyncEngine, SYNC_BATCH_SIZE # Your client sync engine
diff --git a/Tests/MediaDB2/test_sync_client.py b/STests/Sync/test_sync_client.py
similarity index 100%
rename from Tests/MediaDB2/test_sync_client.py
rename to STests/Sync/test_sync_client.py
diff --git a/Tests/MediaDB2/test_sync_server.py b/STests/Sync/test_sync_server.py
similarity index 100%
rename from Tests/MediaDB2/test_sync_server.py
rename to STests/Sync/test_sync_server.py
diff --git a/Tests/MediaDB2/Sync/__init__.py b/STests/__init__.py
similarity index 100%
rename from Tests/MediaDB2/Sync/__init__.py
rename to STests/__init__.py
diff --git a/Tests/ChaChaNotesDB/test_chachanotes_db.py b/Tests/ChaChaNotesDB/test_chachanotes_db.py
index 9a8f69a0..7167f7b3 100644
--- a/Tests/ChaChaNotesDB/test_chachanotes_db.py
+++ b/Tests/ChaChaNotesDB/test_chachanotes_db.py
@@ -1,4 +1,4 @@
-# test_chacha_notes_db.py
+# test_chachanotes_db.py
#
#
# Imports
@@ -8,27 +8,34 @@
import uuid
from datetime import datetime, timezone
from pathlib import Path
-import os # For :memory: check
+
#
# Third-Party Imports
+from hypothesis import strategies as st
+from hypothesis import given, settings, HealthCheck
#
# Local Imports
-from tldw_Server_API.app.core.DB_Management.ChaChaNotes_DB import (
+# --- UPDATED IMPORT PATH ---
+from tldw_chatbook.DB.ChaChaNotes_DB import (
CharactersRAGDB,
CharactersRAGDBError,
SchemaError,
InputError,
ConflictError
)
+
+
#
#######################################################################################################################
#
# Functions:
+
# --- Fixtures ---
@pytest.fixture
def client_id():
+ """Provides a consistent client ID for tests."""
return "test_client_001"
@@ -39,11 +46,11 @@ def db_path(tmp_path):
@pytest.fixture(scope="function")
-def db_instance(db_path, client_id): # Add db_path and tmp_path back
+def db_instance(db_path, client_id):
"""Creates a DB instance for each test, ensuring a fresh database."""
current_db_path = Path(db_path)
- # Clean up any existing files
+ # Clean up any existing files from previous runs to be safe
for suffix in ["", "-wal", "-shm"]:
p = Path(str(current_db_path) + suffix)
if p.exists():
@@ -54,12 +61,12 @@ def db_instance(db_path, client_id): # Add db_path and tmp_path back
db = None
try:
- db = CharactersRAGDB(current_db_path, client_id) # Use current_db_path
+ db = CharactersRAGDB(current_db_path, client_id)
yield db
finally:
if db:
db.close_connection()
- # Additional cleanup
+ # Additional cleanup after test completes
for suffix in ["", "-wal", "-shm"]:
p = Path(str(current_db_path) + suffix)
if p.exists():
@@ -71,18 +78,39 @@ def db_instance(db_path, client_id): # Add db_path and tmp_path back
@pytest.fixture
def mem_db_instance(client_id):
- """Creates an in-memory DB instance."""
+ """Creates an in-memory DB instance for tests that don't need file persistence."""
db = CharactersRAGDB(":memory:", client_id)
yield db
db.close_connection()
+@pytest.fixture
+def sample_card(db_instance: CharactersRAGDB) -> dict:
+ """A fixture that adds a sample card to the DB and returns its data."""
+ card_data = _create_sample_card_data("FromFixture")
+ card_id = db_instance.add_character_card(card_data)
+ # Return the full record from the DB, which includes ID, version, etc.
+ return db_instance.get_character_card_by_id(card_id)
+
+# Similar fixtures can be added for conversations, messages, and other entities.
+@pytest.fixture
+def sample_conv(db_instance: CharactersRAGDB, sample_card: dict) -> dict:
+ """Adds a sample conversation linked to the sample_card."""
+ conv_data = {
+ "character_id": sample_card['id'],
+ "title": "Conversation From Fixture"
+ }
+ conv_id = db_instance.add_conversation(conv_data)
+ return db_instance.get_conversation_by_id(conv_id)
+
# --- Helper Functions ---
def get_current_utc_timestamp_iso():
+ """Returns the current UTC time in ISO 8601 format, as used by the DB."""
return datetime.now(timezone.utc).isoformat(timespec='milliseconds').replace('+00:00', 'Z')
def _create_sample_card_data(name_suffix="", client_id_override=None):
+ """Creates a sample character card data dictionary."""
return {
"name": f"Test Character {name_suffix}",
"description": "A test character.",
@@ -90,18 +118,18 @@ def _create_sample_card_data(name_suffix="", client_id_override=None):
"scenario": "A test scenario.",
"image": b"testimagebytes",
"first_message": "Hello, test!",
- "alternate_greetings": json.dumps(["Hi", "Hey"]), # Ensure JSON strings for direct use
+ "alternate_greetings": json.dumps(["Hi", "Hey"]),
"tags": json.dumps(["test", "sample"]),
"extensions": json.dumps({"custom_field": "value"}),
- "client_id": client_id_override # For testing specific client_id scenarios
+ "client_id": client_id_override
}
# --- Test Cases ---
class TestDBInitialization:
- def test_db_creation(self, db_path, client_id):
- current_db_path = Path(db_path) # Ensure it's a Path object
+ def test_db_creation_and_schema_version(self, db_path, client_id):
+ current_db_path = Path(db_path)
assert not current_db_path.exists()
db = CharactersRAGDB(current_db_path, client_id)
assert current_db_path.exists()
@@ -109,33 +137,32 @@ def test_db_creation(self, db_path, client_id):
# Check schema version
conn = db.get_connection()
- version = \
- conn.execute("SELECT version FROM db_schema_version WHERE schema_name = ?", (db._SCHEMA_NAME,)).fetchone()[
- 'version']
- assert version == db._CURRENT_SCHEMA_VERSION
+ version_row = conn.execute("SELECT version FROM db_schema_version WHERE schema_name = ?",
+ (db._SCHEMA_NAME,)).fetchone()
+ assert version_row is not None
+ assert version_row['version'] == db._CURRENT_SCHEMA_VERSION
db.close_connection()
- def test_in_memory_db(self, client_id):
+ def test_in_memory_db_initialization(self, client_id):
db = CharactersRAGDB(":memory:", client_id)
assert db.is_memory_db
assert db.client_id == client_id
- # Check schema version for in-memory
conn = db.get_connection()
- version = \
- conn.execute("SELECT version FROM db_schema_version WHERE schema_name = ?", (db._SCHEMA_NAME,)).fetchone()[
- 'version']
- assert version == db._CURRENT_SCHEMA_VERSION
+ version_row = conn.execute("SELECT version FROM db_schema_version WHERE schema_name = ?",
+ (db._SCHEMA_NAME,)).fetchone()
+ assert version_row is not None
+ assert version_row['version'] == db._CURRENT_SCHEMA_VERSION
db.close_connection()
- def test_missing_client_id(self, db_path):
+ def test_initialization_with_missing_client_id(self, db_path):
with pytest.raises(ValueError, match="Client ID cannot be empty or None."):
CharactersRAGDB(db_path, "")
with pytest.raises(ValueError, match="Client ID cannot be empty or None."):
CharactersRAGDB(db_path, None)
- def test_reopen_db(self, db_path, client_id):
+ def test_reopening_db_preserves_schema(self, db_path, client_id):
db1 = CharactersRAGDB(db_path, client_id)
- v1 = db1._get_db_version(db1.get_connection()) # Assuming _get_db_version is still available for tests
+ v1 = db1._get_db_version(db1.get_connection())
db1.close_connection()
db2 = CharactersRAGDB(db_path, "another_client")
@@ -144,17 +171,16 @@ def test_reopen_db(self, db_path, client_id):
assert v2 == CharactersRAGDB._CURRENT_SCHEMA_VERSION
db2.close_connection()
- def test_schema_newer_than_code(self, db_path, client_id):
+ def test_opening_db_with_newer_schema_raises_error(self, db_path, client_id):
db = CharactersRAGDB(db_path, client_id)
conn = db.get_connection()
- # Manually set a newer version
+ newer_version = CharactersRAGDB._CURRENT_SCHEMA_VERSION + 1
conn.execute("UPDATE db_schema_version SET version = ? WHERE schema_name = ?",
- (CharactersRAGDB._CURRENT_SCHEMA_VERSION + 1, CharactersRAGDB._SCHEMA_NAME))
+ (newer_version, CharactersRAGDB._SCHEMA_NAME))
conn.commit()
db.close_connection()
- # Match the wrapped message from CharactersRAGDBError
- expected_message_part = "Database initialization failed: Schema initialization/migration for 'rag_char_chat_schema' failed: Database schema 'rag_char_chat_schema' version .* is newer than supported by code"
+ expected_message_part = f"version \\({newer_version}\\) is newer than supported by code \\({CharactersRAGDB._CURRENT_SCHEMA_VERSION}\\)"
with pytest.raises(CharactersRAGDBError, match=expected_message_part):
CharactersRAGDB(db_path, client_id)
@@ -163,27 +189,26 @@ class TestCharacterCards:
def test_add_character_card(self, db_instance: CharactersRAGDB):
card_data = _create_sample_card_data("Add")
card_id = db_instance.add_character_card(card_data)
- assert card_id is not None
assert isinstance(card_id, int)
retrieved = db_instance.get_character_card_by_id(card_id)
assert retrieved is not None
assert retrieved["name"] == card_data["name"]
assert retrieved["description"] == card_data["description"]
- assert retrieved["image"] == card_data["image"] # BLOB check
- assert isinstance(retrieved["alternate_greetings"], list) # Check deserialization
+ assert retrieved["image"] == card_data["image"]
+ assert isinstance(retrieved["alternate_greetings"], list)
assert retrieved["alternate_greetings"] == json.loads(card_data["alternate_greetings"])
- assert retrieved["client_id"] == db_instance.client_id # Ensure instance client_id is used
+ assert retrieved["client_id"] == db_instance.client_id
assert retrieved["version"] == 1
assert not retrieved["deleted"]
- def test_add_character_card_missing_name(self, db_instance: CharactersRAGDB):
+ def test_add_character_card_with_missing_name_raises_error(self, db_instance: CharactersRAGDB):
card_data = _create_sample_card_data("MissingName")
del card_data["name"]
with pytest.raises(InputError, match="Required field 'name' is missing"):
db_instance.add_character_card(card_data)
- def test_add_character_card_duplicate_name(self, db_instance: CharactersRAGDB):
+ def test_add_character_card_with_duplicate_name_raises_error(self, db_instance: CharactersRAGDB):
card_data = _create_sample_card_data("Duplicate")
db_instance.add_character_card(card_data)
with pytest.raises(ConflictError, match=f"Character card with name '{card_data['name']}' already exists"):
@@ -194,220 +219,135 @@ def test_get_character_card_by_id_not_found(self, db_instance: CharactersRAGDB):
def test_get_character_card_by_name(self, db_instance: CharactersRAGDB):
card_data = _create_sample_card_data("ByName")
- db_instance.add_character_card(card_data)
+ card_id = db_instance.add_character_card(card_data)
retrieved = db_instance.get_character_card_by_name(card_data["name"])
assert retrieved is not None
- assert retrieved["description"] == card_data["description"]
+ assert retrieved["id"] == card_id
def test_list_character_cards(self, db_instance: CharactersRAGDB):
- assert db_instance.list_character_cards() == []
+ # A new DB instance should contain exactly one default card.
+ initial_cards = db_instance.list_character_cards()
+ assert len(initial_cards) == 1
+ assert initial_cards[0]['name'] == 'Default Assistant'
+
card_data1 = _create_sample_card_data("List1")
card_data2 = _create_sample_card_data("List2")
db_instance.add_character_card(card_data1)
db_instance.add_character_card(card_data2)
+
+ # The list should now contain 3 cards (1 default + 2 new)
cards = db_instance.list_character_cards()
- assert len(cards) == 2
- # Sort by name for predictable order if names are unique and sortable
- sorted_cards = sorted(cards, key=lambda c: c['name'])
- assert sorted_cards[0]["name"] == card_data1["name"]
- assert sorted_cards[1]["name"] == card_data2["name"]
-
- def test_update_character_card(self, db_instance: CharactersRAGDB):
- card_data_initial = _create_sample_card_data("Update")
- card_id = db_instance.add_character_card(card_data_initial)
- assert card_id is not None
-
- original_card = db_instance.get_character_card_by_id(card_id)
- assert original_card is not None
- initial_expected_version = original_card['version'] # Should be 1
-
- update_payload = {
- "description": "Updated Description", # Keep this as it was working
- "personality": "More Testy" # NEW: Very simple string for personality
- }
-
- # Determine how many version bumps to expect from this payload
- # (This is a simplified count for this test; the DB method handles the actual bumps)
- num_updatable_fields_in_payload = 0
- if "description" in update_payload:
- num_updatable_fields_in_payload += 1
- if "personality" in update_payload:
- num_updatable_fields_in_payload += 1
- # Add other fields from update_payload if they are individually updated
-
- # If no actual fields are updated, metadata update still bumps version once
- final_expected_version_bump = num_updatable_fields_in_payload if num_updatable_fields_in_payload > 0 else 1
- if not update_payload: # If payload is empty, version should still bump once due to metadata update
- final_expected_version_bump = 1
-
- updated = db_instance.update_character_card(card_id, update_payload, expected_version=initial_expected_version)
+ assert len(cards) == 3
+
+    # The newly added cards can still be checked by filtering out the default card.
+ added_card_names = {c['name'] for c in cards if c['name'] != 'Default Assistant'}
+ assert added_card_names == {card_data1['name'], card_data2['name']}
+
+ def test_update_character_card(self, db_instance: CharactersRAGDB, sample_card: dict):
+ update_payload = {"description": "Updated Description"}
+ updated = db_instance.update_character_card(
+ sample_card['id'],
+ update_payload,
+ expected_version=sample_card['version']
+ )
assert updated is True
- retrieved = db_instance.get_character_card_by_id(card_id)
- assert retrieved is not None
+ retrieved = db_instance.get_character_card_by_id(sample_card['id'])
assert retrieved["description"] == "Updated Description"
- assert retrieved["personality"] == "More Testy"
- assert retrieved["name"] == card_data_initial["name"] # Unchanged
+ assert retrieved["version"] == sample_card['version'] + 1
- # Adjust expected version based on sequential updates
- assert retrieved["version"] == initial_expected_version + 1
-
- def test_update_character_card_version_conflict(self, db_instance: CharactersRAGDB):
- card_data = _create_sample_card_data("VersionConflict")
- card_id = db_instance.add_character_card(card_data)
- assert card_id is not None
-
- # Original version is 1
- client_expected_version = 1
+ def test_update_character_card_with_version_conflict_raises_error(self, db_instance: CharactersRAGDB):
+ card_id = db_instance.add_character_card(_create_sample_card_data("VersionConflict"))
# Simulate another client's update, bumping DB version to 2
- conn = db_instance.get_connection()
- conn.execute("UPDATE character_cards SET version = 2, client_id = 'other_client' WHERE id = ?", (card_id,))
- conn.commit()
+ db_instance.update_character_card(card_id, {"description": "First update"}, expected_version=1)
+ # Client tries to update with old expected_version=1
update_payload = {"description": "Conflict Update"}
- # Updated match string to be more flexible or exact
- expected_error_regex = r"Update failed: version mismatch \(db has 2, client expected 1\) for character_cards ID 1\."
+ expected_error_regex = r"version mismatch \(db has 2, client expected 1\)"
with pytest.raises(ConflictError, match=expected_error_regex):
- db_instance.update_character_card(card_id, update_payload, expected_version=client_expected_version)
+ db_instance.update_character_card(card_id, update_payload, expected_version=1)
- def test_update_character_card_not_found(self, db_instance: CharactersRAGDB):
- with pytest.raises(ConflictError,
- match="Record not found in character_cards."): # Match new _get_current_db_version error
+ def test_update_character_card_not_found_raises_error(self, db_instance: CharactersRAGDB):
+ with pytest.raises(ConflictError, match="Record not found in character_cards"):
db_instance.update_character_card(999, {"description": "Not Found"}, expected_version=1)
- def test_soft_delete_character_card(self, db_instance: CharactersRAGDB):
- card_data = _create_sample_card_data("Delete")
- card_id = db_instance.add_character_card(card_data)
- assert card_id is not None
-
- original_card = db_instance.get_character_card_by_id(card_id)
- assert original_card is not None
- expected_version_for_first_delete = original_card['version'] # Should be 1
-
- deleted = db_instance.soft_delete_character_card(card_id, expected_version=expected_version_for_first_delete)
+ def test_soft_delete_character_card(self, db_instance: CharactersRAGDB, sample_card: dict):
+ deleted = db_instance.soft_delete_character_card(
+ sample_card['id'],
+ expected_version=sample_card['version']
+ )
assert deleted is True
+ assert db_instance.get_character_card_by_id(sample_card['id']) is None
- retrieved_after_first_delete = db_instance.get_character_card_by_id(card_id) # Should be None
- assert retrieved_after_first_delete is None
+ def test_soft_delete_is_idempotent(self, db_instance: CharactersRAGDB):
+ card_id = db_instance.add_character_card(_create_sample_card_data("IdempotentDelete"))
+ db_instance.soft_delete_character_card(card_id, expected_version=1)
+ # Calling delete again on an already deleted record should succeed
+ assert db_instance.soft_delete_character_card(card_id, expected_version=1) is True
+ # Verify version didn't change again
conn = db_instance.get_connection()
- raw_retrieved_after_first_delete = conn.execute("SELECT * FROM character_cards WHERE id = ?",
- (card_id,)).fetchone()
- assert raw_retrieved_after_first_delete is not None
- assert raw_retrieved_after_first_delete["deleted"] == 1
- assert raw_retrieved_after_first_delete["version"] == expected_version_for_first_delete + 1 # Version is now 2
-
- # Attempt to delete again with the *original* expected_version (which is now incorrect: 1).
- # The soft_delete_character_card method should recognize the card is already deleted
- # and treat this as an idempotent success, returning True.
- # The internal _get_current_db_version would raise "Record is soft-deleted",
- # which soft_delete_character_card catches and handles.
- assert db_instance.soft_delete_character_card(card_id,
- expected_version=expected_version_for_first_delete) is True
-
- # Verify version didn't change again (it's still 2 from the first delete)
- still_deleted_card_info = conn.execute("SELECT version, deleted FROM character_cards WHERE id = ?",
- (card_id,)).fetchone()
- assert still_deleted_card_info is not None
- assert still_deleted_card_info["deleted"] == 1
- assert still_deleted_card_info['version'] == expected_version_for_first_delete + 1 # Still version 2
-
- # Test idempotent success: calling soft_delete on an already deleted record
- # with its *current correct version* should also succeed.
- current_version_of_deleted_card = raw_retrieved_after_first_delete['version'] # This is 2
- assert db_instance.soft_delete_character_card(card_id, expected_version=current_version_of_deleted_card) is True
+ raw_record = conn.execute("SELECT version FROM character_cards WHERE id = ?", (card_id,)).fetchone()
+ assert raw_record["version"] == 2
def test_search_character_cards(self, db_instance: CharactersRAGDB):
- card1_data = _create_sample_card_data("Searchable Alpha")
+ card1_data = _create_sample_card_data("Search Alpha")
card1_data["description"] = "Unique keyword: ZYX"
- card2_data = _create_sample_card_data("Searchable Beta")
- card2_data["system_prompt"] = "Contains ZYX too"
+ card2_data = _create_sample_card_data("Search Beta")
+ card2_data["system_prompt"] = "Also has ZYX"
card3_data = _create_sample_card_data("Unsearchable")
db_instance.add_character_card(card1_data)
- db_instance.add_character_card(card2_data)
+ card2_id = db_instance.add_character_card(card2_data)
db_instance.add_character_card(card3_data)
results = db_instance.search_character_cards("ZYX")
assert len(results) == 2
- names = [r["name"] for r in results]
+ names = {r["name"] for r in results}
assert card1_data["name"] in names
assert card2_data["name"] in names
- # Test search after delete
- card1 = db_instance.get_character_card_by_name(card1_data["name"])
- assert card1 is not None
- db_instance.soft_delete_character_card(card1["id"], expected_version=card1["version"])
+ # Test search after soft-deleting one of the results
+ card2 = db_instance.get_character_card_by_id(card2_id)
+ db_instance.soft_delete_character_card(card2["id"], expected_version=card2["version"])
+
results_after_delete = db_instance.search_character_cards("ZYX")
assert len(results_after_delete) == 1
- assert results_after_delete[0]["name"] == card2_data["name"]
-
+ assert results_after_delete[0]["name"] == card1_data["name"]
+
+ @pytest.mark.parametrize(
+ "field_to_remove, expected_error, error_match",
+ [
+ ("name", InputError, "Required field 'name' is missing"),
+ # Assuming you add a required 'creator' field later
+ # ("creator", InputError, "Required field 'creator' is missing"),
+ ]
+ )
+ def test_add_card_missing_required_fields(self, db_instance, field_to_remove, expected_error, error_match):
+ card_data = _create_sample_card_data("MissingFields")
+ del card_data[field_to_remove]
+ with pytest.raises(expected_error, match=error_match):
+ db_instance.add_character_card(card_data)
class TestConversationsAndMessages:
@pytest.fixture
def char_id(self, db_instance):
card_id = db_instance.add_character_card(_create_sample_card_data("ConvChar"))
- assert card_id is not None
return card_id
def test_add_conversation(self, db_instance: CharactersRAGDB, char_id):
- conv_data = {
- "id": str(uuid.uuid4()),
- "character_id": char_id,
- "title": "Test Conversation"
- }
+ conv_data = {"id": str(uuid.uuid4()), "character_id": char_id, "title": "Test Conversation"}
conv_id = db_instance.add_conversation(conv_data)
assert conv_id == conv_data["id"]
retrieved = db_instance.get_conversation_by_id(conv_id)
- assert retrieved is not None
assert retrieved["title"] == "Test Conversation"
assert retrieved["character_id"] == char_id
- assert retrieved["root_id"] == conv_id # Default root_id
assert retrieved["version"] == 1
assert retrieved["client_id"] == db_instance.client_id
- def test_add_conversation_duplicate_id(self, db_instance: CharactersRAGDB, char_id):
- conv_id_val = str(uuid.uuid4())
- conv_data = {"id": conv_id_val, "character_id": char_id, "title": "First"}
- db_instance.add_conversation(conv_data)
-
- conv_data_dup = {"id": conv_id_val, "character_id": char_id, "title": "Duplicate"}
- with pytest.raises(ConflictError, match=f"Conversation with ID '{conv_id_val}' already exists"):
- db_instance.add_conversation(conv_data_dup)
-
- def test_add_message(self, db_instance: CharactersRAGDB, char_id):
+ def test_add_message_and_get_for_conversation(self, db_instance: CharactersRAGDB, char_id):
conv_id = db_instance.add_conversation({"character_id": char_id, "title": "MsgConv"})
- assert conv_id is not None
-
- msg_data = {
- "conversation_id": conv_id,
- "sender": "user",
- "content": "Hello there!"
- }
- msg_id = db_instance.add_message(msg_data)
- assert msg_id is not None
-
- retrieved_msg = db_instance.get_message_by_id(msg_id)
- assert retrieved_msg is not None
- assert retrieved_msg["sender"] == "user"
- assert retrieved_msg["content"] == "Hello there!"
- assert retrieved_msg["conversation_id"] == conv_id
- assert retrieved_msg["version"] == 1
- assert retrieved_msg["client_id"] == db_instance.client_id
-
- # Test adding message to non-existent conversation
- msg_data_bad_conv = {
- "conversation_id": str(uuid.uuid4()),
- "sender": "user",
- "content": "Test"
- }
- with pytest.raises(InputError, match="Cannot add message: Conversation ID .* not found or deleted"):
- db_instance.add_message(msg_data_bad_conv)
-
- def test_get_messages_for_conversation_ordering(self, db_instance: CharactersRAGDB, char_id):
- conv_id = db_instance.add_conversation({"character_id": char_id, "title": "OrderedMsgConv"})
- assert conv_id is not None
msg1_id = db_instance.add_message(
{"conversation_id": conv_id, "sender": "user", "content": "First", "timestamp": "2023-01-01T10:00:00Z"})
msg2_id = db_instance.add_message(
@@ -423,517 +363,206 @@ def test_get_messages_for_conversation_ordering(self, db_instance: CharactersRAG
assert messages_desc[0]["id"] == msg2_id
assert messages_desc[1]["id"] == msg1_id
- def test_update_conversation(self, db_instance: CharactersRAGDB, char_id: int):
- # 1. Setup: Add an initial conversation with a SIMPLE title
- initial_title = "AlphaTitleOne" # Simple, unique for this test run
- conv_id = db_instance.add_conversation({
- "character_id": char_id,
- "title": initial_title
- })
- assert conv_id is not None, "Failed to add initial conversation"
-
- # 2. Verify initial state in main table
- original_conv_main = db_instance.get_conversation_by_id(conv_id)
- assert original_conv_main is not None, "Failed to retrieve initial conversation"
- assert original_conv_main['title'] == initial_title, "Initial title mismatch in main table"
- initial_expected_version = original_conv_main['version']
- assert initial_expected_version == 1, "Initial version should be 1"
-
- # 3. Verify initial FTS state (new title is searchable)
- try:
- initial_fts_results = db_instance.search_conversations_by_title(initial_title)
- assert len(initial_fts_results) == 1, \
- f"FTS Pre-Update: Expected 1 result for '{initial_title}', got {len(initial_fts_results)}. Results: {initial_fts_results}"
- assert initial_fts_results[0]['id'] == conv_id, \
- f"FTS Pre-Update: Conversation ID {conv_id} not found in initial search results for '{initial_title}'."
- except Exception as e:
- pytest.fail(f"Failed during initial FTS check for '{initial_title}': {e}")
-
- # 4. Define update payload with a SIMPLE, DIFFERENT title
- updated_title = "BetaTitleTwo" # Simple, unique, different from initial_title
- updated_rating = 5
- update_payload = {"title": updated_title, "rating": updated_rating}
-
- # 5. Perform the update
- updated = db_instance.update_conversation(conv_id, update_payload, expected_version=initial_expected_version)
- assert updated is True, "update_conversation returned False"
-
- # 6. Verify updated state in main table
- retrieved_after_update = db_instance.get_conversation_by_id(conv_id)
- assert retrieved_after_update is not None, "Failed to retrieve conversation after update"
- assert retrieved_after_update["title"] == updated_title, "Title was not updated correctly in main table"
- assert retrieved_after_update["rating"] == updated_rating, "Rating was not updated correctly"
- assert retrieved_after_update[
- "version"] == initial_expected_version + 1, "Version did not increment correctly after update"
-
- # 7. Verify FTS state after update
- # Search for the NEW title
- try:
- search_results_new_title = db_instance.search_conversations_by_title(updated_title)
- assert len(search_results_new_title) == 1, \
- f"FTS Post-Update: Expected 1 result for new title '{updated_title}', got {len(search_results_new_title)}. Results: {search_results_new_title}"
- assert search_results_new_title[0]['id'] == conv_id, \
- f"FTS Post-Update: Conversation ID {conv_id} not found in search results for new title '{updated_title}'."
- except Exception as e:
- pytest.fail(f"Failed during FTS check for new title '{updated_title}': {e}")
-
- # Search for the OLD title
- try:
- search_results_old_title = db_instance.search_conversations_by_title(initial_title)
- found_old_title_for_this_conv_via_match = any(r['id'] == conv_id for r in search_results_old_title)
-
- if found_old_title_for_this_conv_via_match:
- print(
- f"\nINFO (FTS Nuance): FTS MATCH found old title '{initial_title}' for conv_id {conv_id} immediately after update.")
- print(f" Search results for old title via MATCH: {search_results_old_title}")
-
- # Verify the actual content stored in FTS table for this rowid
- # This confirms the trigger updated the FTS table's data record.
- conn_debug = db_instance.get_connection() # Get a connection
-
- # Get the rowid of the conversation from the main table first
- main_table_rowid_cursor = conn_debug.execute("SELECT rowid FROM conversations WHERE id = ?", (conv_id,))
- main_table_rowid_row = main_table_rowid_cursor.fetchone()
- assert main_table_rowid_row is not None, f"Could not fetch rowid from main 'conversations' table for id {conv_id}"
- target_conv_rowid = main_table_rowid_row['rowid']
-
- fts_content_cursor = conn_debug.execute(
- "SELECT title FROM conversations_fts WHERE rowid = ?",
- (target_conv_rowid,) # Use the actual rowid from the main table
- )
- fts_content_row = fts_content_cursor.fetchone()
-
- current_fts_content_title = "FTS ROW NOT FOUND (SHOULD EXIST)"
- if fts_content_row:
- current_fts_content_title = fts_content_row['title']
- else: # This case should ideally not happen if the new title was inserted
- print(
- f"ERROR: FTS row for rowid {target_conv_rowid} not found directly after update, but MATCH found it.")
-
- print(
- f" Actual content in conversations_fts.title for rowid {target_conv_rowid}: '{current_fts_content_title}'")
-
- assert current_fts_content_title == updated_title, \
- f"FTS CONTENT CHECK FAILED: Stored FTS content for rowid {target_conv_rowid} of conv_id {conv_id} was '{current_fts_content_title}', expected '{updated_title}'."
-
- # The following assertion is expected to FAIL due to FTS5 MATCH "stickiness"
- # It demonstrates that while the FTS data record is updated, MATCH might still find old terms immediately.
- # To make the overall test "pass" while acknowledging this, this line would be commented out or adjusted.
- # assert not found_old_title_for_this_conv_via_match, \
- # (f"FTS MATCH BEHAVIOR: Old title '{initial_title}' was STILL MATCHED for conversation ID {conv_id} "
- # f"after update, even though FTS content for its rowid ({target_conv_rowid}) is now '{current_fts_content_title}'. "
- # f"This highlights FTS5's eventual consistency for MATCH queries post-update.")
- else:
- # This is the ideal immediate outcome: old title is not found by MATCH.
- assert not found_old_title_for_this_conv_via_match # This will pass if branch is taken
-
- except Exception as e:
- pytest.fail(f"Failed during FTS check for old title '{initial_title}': {e}")
-
- def test_soft_delete_conversation_and_messages(self, db_instance: CharactersRAGDB, char_id):
- # Setup: Conversation with messages
- conv_title_for_delete_test = "DeleteConvForFTS"
- conv_id = db_instance.add_conversation({"character_id": char_id, "title": conv_title_for_delete_test})
- assert conv_id is not None
- msg1_id = db_instance.add_message({"conversation_id": conv_id, "sender": "user", "content": "Msg1"})
- assert msg1_id is not None
-
+ def test_update_conversation_and_fts(self, db_instance: CharactersRAGDB, char_id: int):
+ initial_title = "AlphaTitleOne"
+ conv_id = db_instance.add_conversation({"character_id": char_id, "title": initial_title})
original_conv = db_instance.get_conversation_by_id(conv_id)
- assert original_conv is not None
- expected_version = original_conv['version']
- # Verify it's in FTS before delete
- results_before_delete = db_instance.search_conversations_by_title(conv_title_for_delete_test)
- assert len(results_before_delete) == 1, "Conversation should be in FTS before soft delete"
- assert results_before_delete[0]['id'] == conv_id
+ # Verify FTS state before update
+ assert len(db_instance.search_conversations_by_title(initial_title)) == 1
- # Soft delete conversation
- deleted = db_instance.soft_delete_conversation(conv_id, expected_version=expected_version)
- assert deleted is True
- assert db_instance.get_conversation_by_id(conv_id) is None
+ # Perform update
+ updated_title = "BetaTitleTwo"
+ db_instance.update_conversation(conv_id, {"title": updated_title}, expected_version=original_conv['version'])
- msg1 = db_instance.get_message_by_id(msg1_id)
- assert msg1 is not None
- assert msg1["conversation_id"] == conv_id
+ # Verify FTS state after update
+ assert len(db_instance.search_conversations_by_title(updated_title)) == 1
+ assert len(db_instance.search_conversations_by_title(initial_title)) == 0, "FTS should not find the old title"
- # FTS search for conversation should not find it
- # UNCOMMENTED AND VERIFIED:
- results_after_delete = db_instance.search_conversations_by_title(conv_title_for_delete_test)
- assert len(results_after_delete) == 0, "FTS search should not find the soft-deleted conversation"
+ def test_soft_delete_conversation_and_fts(self, db_instance: CharactersRAGDB, char_id):
+ conv_title_for_delete_test = "DeleteConvForFTS"
+ conv_id = db_instance.add_conversation({"character_id": char_id, "title": conv_title_for_delete_test})
+ original_conv = db_instance.get_conversation_by_id(conv_id)
- def test_conversation_fts_search(self, db_instance: CharactersRAGDB, char_id):
- conv_id1 = db_instance.add_conversation({"character_id": char_id, "title": "Unique Alpha Search Term"})
- conv_id2 = db_instance.add_conversation({"character_id": char_id, "title": "Another Alpha For Test"})
- db_instance.add_conversation({"character_id": char_id, "title": "Beta Content Only"})
+ assert len(db_instance.search_conversations_by_title(conv_title_for_delete_test)) == 1
- results_alpha = db_instance.search_conversations_by_title("Alpha")
- assert len(results_alpha) == 2
- found_ids_alpha = {r['id'] for r in results_alpha}
- assert conv_id1 in found_ids_alpha
- assert conv_id2 in found_ids_alpha
+ db_instance.soft_delete_conversation(conv_id, expected_version=original_conv['version'])
- results_unique = db_instance.search_conversations_by_title("Unique")
- assert len(results_unique) == 1
- assert results_unique[0]['id'] == conv_id1
+ assert db_instance.get_conversation_by_id(conv_id) is None
+ assert len(db_instance.search_conversations_by_title(
+ conv_title_for_delete_test)) == 0, "FTS should not find soft-deleted conversation"
- def test_search_messages_by_content_FIXED_JOIN(self, db_instance: CharactersRAGDB, char_id):
- # This test specifically validates the FTS join fix for messages (TEXT PK)
+ def test_search_messages_by_content(self, db_instance: CharactersRAGDB, char_id):
conv_id = db_instance.add_conversation({"character_id": char_id, "title": "MessageSearchConv"})
- assert conv_id is not None
msg1_data = {"id": str(uuid.uuid4()), "conversation_id": conv_id, "sender": "user",
"content": "UniqueMessageContentAlpha"}
- msg2_data = {"id": str(uuid.uuid4()), "conversation_id": conv_id, "sender": "ai", "content": "Another phrase"}
-
db_instance.add_message(msg1_data)
- db_instance.add_message(msg2_data)
results = db_instance.search_messages_by_content("UniqueMessageContentAlpha")
assert len(results) == 1
assert results[0]["id"] == msg1_data["id"]
- assert results[0]["content"] == msg1_data["content"]
-
- # Test search within a specific conversation
- results_conv_specific = db_instance.search_messages_by_content("UniqueMessageContentAlpha",
- conversation_id=conv_id)
- assert len(results_conv_specific) == 1
- assert results_conv_specific[0]["id"] == msg1_data["id"]
-
- # Test search for content in another conversation (should not be found if conv_id is specified)
- other_conv_id = db_instance.add_conversation({"character_id": char_id, "title": "Other MessageSearchConv"})
- assert other_conv_id is not None
- db_instance.add_message({"id": str(uuid.uuid4()), "conversation_id": other_conv_id, "sender": "user",
- "content": "UniqueMessageContentAlpha In Other"})
- results_other_conv = db_instance.search_messages_by_content("UniqueMessageContentAlpha",
- conversation_id=other_conv_id)
- assert len(results_other_conv) == 1
- assert results_other_conv[0]["content"] == "UniqueMessageContentAlpha In Other"
- results_original_conv_again = db_instance.search_messages_by_content("UniqueMessageContentAlpha",
- conversation_id=conv_id)
- assert len(results_original_conv_again) == 1
- assert results_original_conv_again[0]["id"] == msg1_data["id"]
-
-
-class TestNotes:
- def test_add_note(self, db_instance: CharactersRAGDB):
- note_id = db_instance.add_note("Test Note Title", "This is the content of the note.")
- assert isinstance(note_id, str) # UUID
-
- retrieved = db_instance.get_note_by_id(note_id)
- assert retrieved is not None
- assert retrieved["title"] == "Test Note Title"
- assert retrieved["content"] == "This is the content of the note."
- assert retrieved["version"] == 1
- assert not retrieved["deleted"]
-
- def test_add_note_empty_title(self, db_instance: CharactersRAGDB):
- with pytest.raises(InputError, match="Note title cannot be empty."):
- db_instance.add_note("", "Content")
-
- def test_add_note_duplicate_id(self, db_instance: CharactersRAGDB):
- fixed_id = str(uuid.uuid4())
- db_instance.add_note("First Note", "Content1", note_id=fixed_id)
- with pytest.raises(ConflictError, match=f"Note with ID '{fixed_id}' already exists."):
- db_instance.add_note("Second Note", "Content2", note_id=fixed_id)
-
- def test_update_note(self, db_instance: CharactersRAGDB):
+ # @pytest.mark.parametrize(
+ # "msg_data, raises_error",
+ # [
+ # ({"content": "Hello", "image_data": None, "image_mime_type": None}, False),
+ # ({"content": "", "image_data": b'img', "image_mime_type": "image/png"}, False),
+ # ({"content": "Hello", "image_data": b'img', "image_mime_type": "image/png"}, False),
+ # # Failure cases
+ # ({"content": "", "image_data": None, "image_mime_type": None}, True), # Both missing
+ # ({"content": None, "image_data": None, "image_mime_type": None}, True), # Both missing
+ # ({"content": "", "image_data": b'img', "image_mime_type": None}, True), # Mime type missing
+ # ]
+ # )
+ # def test_add_message_content_requirements(self, db_instance, sample_conv, msg_data, raises_error):
+ # full_payload = {
+ # "conversation_id": sample_conv['id'],
+ # "sender": "user",
+ # **msg_data
+ # }
+ #
+ # if raises_error:
+ # with pytest.raises((InputError, TypeError)): # TypeError if content is None
+ # db_instance.add_message(full_payload)
+ # else:
+ # msg_id = db_instance.add_message(full_payload)
+ # assert msg_id is not None
+
+
+
+class TestNotesAndKeywords:
+ def test_add_and_update_note(self, db_instance: CharactersRAGDB):
note_id = db_instance.add_note("Original Title", "Original Content")
- assert note_id is not None
+ assert isinstance(note_id, str)
original_note = db_instance.get_note_by_id(note_id)
- assert original_note is not None
- expected_version = original_note['version'] # Should be 1
-
- updated = db_instance.update_note(note_id, {"title": "Updated Title", "content": "Updated Content"},
- expected_version=expected_version)
+ updated = db_instance.update_note(note_id, {"title": "Updated Title"},
+ expected_version=original_note['version'])
assert updated is True
retrieved = db_instance.get_note_by_id(note_id)
- assert retrieved is not None
assert retrieved["title"] == "Updated Title"
- assert retrieved["content"] == "Updated Content"
- assert retrieved["version"] == expected_version + 1
-
- def test_list_notes(self, db_instance: CharactersRAGDB):
- assert db_instance.list_notes() == []
- id1 = db_instance.add_note("Note A", "Content A")
- # Introduce a slight delay or ensure timestamps are distinct if relying on last_modified for order
- # For this test, assuming add_note sets distinct last_modified or order is by insertion for simple tests
- id2 = db_instance.add_note("Note B", "Content B")
- notes = db_instance.list_notes()
- assert len(notes) == 2
- # Default order is last_modified DESC
- # To make it robust, fetch and compare timestamps or ensure test data forces order
- # For simplicity, if Note B is added after Note A, it should appear first
- note_ids_in_order = [n['id'] for n in notes]
- if id1 and id2: # Ensure they were created
- # This assertion depends on the exact timing of creation and how last_modified is set.
- # A more robust test would explicitly set created_at/last_modified if possible,
- # or query and sort by a reliable field.
- # For now, we assume recent additions are first due to DESC order.
- assert note_ids_in_order[0] == id2
- assert note_ids_in_order[1] == id1
-
- def test_search_notes(self, db_instance: CharactersRAGDB):
- db_instance.add_note("Alpha Note", "Contains a keyword ZYX")
- db_instance.add_note("Beta Note", "Another one with ZYX in title")
- db_instance.add_note("Gamma Note", "Nothing special")
-
- # DEBUGGING:
- # conn = db_instance.get_connection()
- # fts_content = conn.execute("SELECT rowid, title, content FROM notes_fts;").fetchall()
- # print("\nNotes FTS Content:")
- # for row in fts_content:
- # print(dict(row))
- # END DEBUGGING
-
- results = db_instance.search_notes("ZYX")
- assert len(results) == 2
- titles = sorted([r['title'] for r in results]) # Sort for predictable assertion
- assert titles == ["Alpha Note", "Beta Note"]
-
+ assert retrieved["version"] == original_note['version'] + 1
-class TestKeywordsAndCollections:
- def test_add_keyword(self, db_instance: CharactersRAGDB):
- keyword_id = db_instance.add_keyword(" TestKeyword ") # Test stripping
- assert keyword_id is not None
- retrieved = db_instance.get_keyword_by_id(keyword_id)
- assert retrieved is not None
- assert retrieved["keyword"] == "TestKeyword"
- assert retrieved["version"] == 1
+ def test_add_keyword_and_undelete(self, db_instance: CharactersRAGDB):
+ keyword_id = db_instance.add_keyword("TestKeyword")
+ kw_v1 = db_instance.get_keyword_by_id(keyword_id)
- def test_add_keyword_duplicate_active(self, db_instance: CharactersRAGDB):
- db_instance.add_keyword("UniqueKeyword")
- with pytest.raises(ConflictError, match="'UniqueKeyword' already exists and is active"):
- db_instance.add_keyword("UniqueKeyword")
+ db_instance.soft_delete_keyword(keyword_id, expected_version=kw_v1['version'])
+ assert db_instance.get_keyword_by_id(keyword_id) is None
- def test_add_keyword_undelete(self, db_instance: CharactersRAGDB):
- keyword_id = db_instance.add_keyword("ToDeleteAndReadd")
- assert keyword_id is not None
-
- # Get current version for soft delete
- keyword_v1 = db_instance.get_keyword_by_id(keyword_id)
- assert keyword_v1 is not None
-
- db_instance.soft_delete_keyword(keyword_id, expected_version=keyword_v1['version']) # v2, deleted
-
- # Adding same keyword should undelete and update
- # The add_keyword method's undelete logic might not need an explicit expected_version
- # from the client for this specific "add which might undelete" scenario.
- # It internally handles its own version check if it finds a deleted record.
- new_keyword_id = db_instance.add_keyword("ToDeleteAndReadd")
- assert new_keyword_id == keyword_id # Should be the same ID
+ # Adding same keyword again should undelete it
+ new_keyword_id = db_instance.add_keyword("TestKeyword")
+ assert new_keyword_id == keyword_id
retrieved = db_instance.get_keyword_by_id(keyword_id)
- assert retrieved is not None
assert not retrieved["deleted"]
- # Version logic:
- # 1 (initial add)
- # 2 (soft_delete_keyword with expected_version=1)
- # 3 (add_keyword causing undelete, which itself bumps version)
- assert retrieved["version"] == 3
-
- def test_add_keyword_collection(self, db_instance: CharactersRAGDB):
- coll_id = db_instance.add_keyword_collection("My Collection")
- assert coll_id is not None
- retrieved = db_instance.get_keyword_collection_by_id(coll_id)
- assert retrieved is not None
- assert retrieved["name"] == "My Collection"
- assert retrieved["parent_id"] is None
+ assert retrieved["version"] == 3 # 1(add) -> 2(delete) -> 3(undelete/update)
- child_coll_id = db_instance.add_keyword_collection("Child Collection", parent_id=coll_id)
- assert child_coll_id is not None
- retrieved_child = db_instance.get_keyword_collection_by_id(child_coll_id)
- assert retrieved_child is not None
- assert retrieved_child["parent_id"] == coll_id
-
- def test_link_conversation_to_keyword(self, db_instance: CharactersRAGDB):
+ def test_link_and_unlink_conversation_to_keyword(self, db_instance: CharactersRAGDB):
char_id = db_instance.add_character_card(_create_sample_card_data("LinkChar"))
- assert char_id is not None
conv_id = db_instance.add_conversation({"character_id": char_id, "title": "LinkConv"})
- assert conv_id is not None
kw_id = db_instance.add_keyword("Linkable")
- assert kw_id is not None
assert db_instance.link_conversation_to_keyword(conv_id, kw_id) is True
keywords = db_instance.get_keywords_for_conversation(conv_id)
assert len(keywords) == 1
assert keywords[0]["id"] == kw_id
- # Test idempotency
- assert db_instance.link_conversation_to_keyword(conv_id, kw_id) is False # Already linked
+ # Test idempotency of linking
+ assert db_instance.link_conversation_to_keyword(conv_id, kw_id) is False
- # Test unlinking
assert db_instance.unlink_conversation_from_keyword(conv_id, kw_id) is True
assert len(db_instance.get_keywords_for_conversation(conv_id)) == 0
- assert db_instance.unlink_conversation_from_keyword(conv_id, kw_id) is False # Already unlinked
- # Similar tests for other link types:
- # link_collection_to_keyword, link_note_to_keyword
+ # Test idempotency of unlinking
+ assert db_instance.unlink_conversation_from_keyword(conv_id, kw_id) is False
class TestSyncLog:
- def test_sync_log_entry_on_add_character(self, db_instance: CharactersRAGDB):
+ def test_sync_log_entry_on_add_and_update_character(self, db_instance: CharactersRAGDB):
initial_log_max_id = db_instance.get_latest_sync_log_change_id()
card_data = _create_sample_card_data("SyncLogChar")
card_id = db_instance.add_character_card(card_data)
- assert card_id is not None
-
- log_entries = db_instance.get_sync_log_entries(since_change_id=initial_log_max_id) # Get new entries
-
- char_log_entry = None
- for entry in log_entries: # Search among new entries
- if entry["entity"] == "character_cards" and entry["entity_id"] == str(card_id) and entry[
- "operation"] == "create":
- char_log_entry = entry
- break
-
- assert char_log_entry is not None
- assert char_log_entry["payload"]["name"] == card_data["name"]
- assert char_log_entry["payload"]["version"] == 1
- assert char_log_entry["client_id"] == db_instance.client_id
-
- def test_sync_log_entry_on_update_character(self, db_instance: CharactersRAGDB):
- card_id = db_instance.add_character_card(_create_sample_card_data("SyncUpdateChar"))
- assert card_id is not None
-
- original_card = db_instance.get_character_card_by_id(card_id)
- assert original_card is not None
- expected_version = original_card['version']
-
- latest_change_id = db_instance.get_latest_sync_log_change_id()
-
- db_instance.update_character_card(card_id, {"description": "Updated for Sync"},
- expected_version=expected_version)
-
- new_entries = db_instance.get_sync_log_entries(since_change_id=latest_change_id)
- assert len(new_entries) >= 1
- update_log_entry = None
- for entry in new_entries:
- if entry["entity"] == "character_cards" and entry["entity_id"] == str(card_id) and entry[
- "operation"] == "update":
- update_log_entry = entry
- break
-
- assert update_log_entry is not None
- assert update_log_entry["payload"]["description"] == "Updated for Sync"
- assert update_log_entry["payload"]["version"] == expected_version + 1
-
- def test_sync_log_entry_on_soft_delete_character(self, db_instance: CharactersRAGDB):
+ log_entries = db_instance.get_sync_log_entries(since_change_id=initial_log_max_id)
+ create_entry = next((e for e in log_entries if e["entity"] == "character_cards" and e["operation"] == "create"),
+ None)
+ assert create_entry is not None
+ assert create_entry["entity_id"] == str(card_id)
+ assert create_entry["payload"]["name"] == card_data["name"]
+
+ # Test update
+ latest_change_id_after_add = db_instance.get_latest_sync_log_change_id()
+ db_instance.update_character_card(card_id, {"description": "Updated for Sync"}, expected_version=1)
+
+ update_log_entries = db_instance.get_sync_log_entries(since_change_id=latest_change_id_after_add)
+ update_entry = next(
+ (e for e in update_log_entries if e["entity"] == "character_cards" and e["operation"] == "update"), None)
+ assert update_entry is not None
+ assert update_entry["payload"]["description"] == "Updated for Sync"
+ assert update_entry["payload"]["version"] == 2
+
+ def test_sync_log_on_soft_delete_character(self, db_instance: CharactersRAGDB):
card_id = db_instance.add_character_card(_create_sample_card_data("SyncDeleteChar"))
- assert card_id is not None
-
- original_card = db_instance.get_character_card_by_id(card_id)
- assert original_card is not None
- expected_version = original_card['version']
-
latest_change_id = db_instance.get_latest_sync_log_change_id()
- db_instance.soft_delete_character_card(card_id, expected_version=expected_version)
+ db_instance.soft_delete_character_card(card_id, expected_version=1)
new_entries = db_instance.get_sync_log_entries(since_change_id=latest_change_id)
- delete_log_entry = None
- for entry in new_entries:
- if entry["entity"] == "character_cards" and entry["entity_id"] == str(card_id) and entry[
- "operation"] == "delete":
- delete_log_entry = entry
- break
-
- assert delete_log_entry is not None
- # assert delete_log_entry["payload"]["deleted"] is True # Original failing line
- assert delete_log_entry["payload"]["deleted"] == 1 # If JSON payload has integer 1 for true
- # OR, if you expect a boolean true after json.loads and your DB stores it in a way that json.loads makes it bool:
- # assert delete_log_entry["payload"]["deleted"] is True
- # For SQLite storing boolean as 0/1, json.loads(payload_with_integer_1) will keep it as integer 1.
- assert delete_log_entry["payload"]["version"] == expected_version + 1
+ delete_entry = next((e for e in new_entries if e["entity"] == "character_cards" and e["operation"] == "delete"),
+ None)
+ assert delete_entry is not None
+ assert delete_entry["entity_id"] == str(card_id)
+ assert delete_entry["payload"]["deleted"] == 1 # Stored as integer
+ assert delete_entry["payload"]["version"] == 2
def test_sync_log_for_link_tables(self, db_instance: CharactersRAGDB):
char_id = db_instance.add_character_card(_create_sample_card_data("SyncLinkChar"))
- assert char_id is not None
conv_id = db_instance.add_conversation({"character_id": char_id, "title": "SyncLinkConv"})
- assert conv_id is not None
kw_id = db_instance.add_keyword("SyncLinkable")
- assert kw_id is not None
-
latest_change_id = db_instance.get_latest_sync_log_change_id()
+
db_instance.link_conversation_to_keyword(conv_id, kw_id)
- new_entries = db_instance.get_sync_log_entries(since_change_id=latest_change_id)
- link_log_entry = None
- expected_entity_id = f"{conv_id}_{kw_id}"
- for entry in new_entries:
- if entry["entity"] == "conversation_keywords" and entry["entity_id"] == expected_entity_id and entry[
- "operation"] == "create":
- link_log_entry = entry
- break
-
- assert link_log_entry is not None
- assert link_log_entry["payload"]["conversation_id"] == conv_id
- assert link_log_entry["payload"]["keyword_id"] == kw_id
+ link_entries = db_instance.get_sync_log_entries(since_change_id=latest_change_id)
+ link_entry = next(
+ (e for e in link_entries if e["entity"] == "conversation_keywords" and e["operation"] == "create"), None)
+ assert link_entry is not None
+ assert link_entry["payload"]["conversation_id"] == conv_id
+ assert link_entry["payload"]["keyword_id"] == kw_id
+ # Test unlink
latest_change_id_after_link = db_instance.get_latest_sync_log_change_id()
db_instance.unlink_conversation_from_keyword(conv_id, kw_id)
- new_entries_unlink = db_instance.get_sync_log_entries(since_change_id=latest_change_id_after_link)
- unlink_log_entry = None
- for entry in new_entries_unlink:
- if entry["entity"] == "conversation_keywords" and entry["entity_id"] == expected_entity_id and entry[
- "operation"] == "delete":
- unlink_log_entry = entry
- break
- assert unlink_log_entry is not None
+ unlink_entries = db_instance.get_sync_log_entries(since_change_id=latest_change_id_after_link)
+ unlink_entry = next(
+ (e for e in unlink_entries if e["entity"] == "conversation_keywords" and e["operation"] == "delete"), None)
+ assert unlink_entry is not None
+ assert unlink_entry["entity_id"] == f"{conv_id}_{kw_id}"
class TestTransactions:
def test_transaction_commit(self, db_instance: CharactersRAGDB):
- card_data1_name = "Trans1 Character"
- card_data2_name = "Trans2 Character"
-
- with db_instance.transaction() as conn: # Get conn from context for direct execution
- # Direct execution to test transaction atomicity without involving add_character_card's own transaction
- conn.execute(
- "INSERT INTO character_cards (name, description, client_id, last_modified, version) VALUES (?, ?, ?, ?, ?)",
- (card_data1_name, "Desc1", db_instance.client_id, get_current_utc_timestamp_iso(), 1)
- )
- id1_row = conn.execute("SELECT id FROM character_cards WHERE name = ?", (card_data1_name,)).fetchone()
- assert id1_row is not None
- id1 = id1_row['id']
-
- conn.execute(
- "INSERT INTO character_cards (name, description, client_id, last_modified, version) VALUES (?, ?, ?, ?, ?)",
- (card_data2_name, "Desc2", db_instance.client_id, get_current_utc_timestamp_iso(), 1)
- )
-
- retrieved1 = db_instance.get_character_card_by_id(id1)
- retrieved2 = db_instance.get_character_card_by_name(card_data2_name)
- assert retrieved1 is not None
- assert retrieved2 is not None
+ with db_instance.transaction() as conn:
+ conn.execute("INSERT INTO character_cards (name, client_id) VALUES (?, ?)",
+ ("Trans1", db_instance.client_id))
+ conn.execute("INSERT INTO character_cards (name, client_id) VALUES (?, ?)",
+ ("Trans2", db_instance.client_id))
+
+ assert db_instance.get_character_card_by_name("Trans1") is not None
+ assert db_instance.get_character_card_by_name("Trans2") is not None
def test_transaction_rollback(self, db_instance: CharactersRAGDB):
- card_data_name = "TransRollback Character"
initial_count = len(db_instance.list_character_cards())
-
with pytest.raises(sqlite3.IntegrityError):
- with db_instance.transaction() as conn: # Get conn from context
- # First insert (will be part of transaction)
- conn.execute(
- "INSERT INTO character_cards (name, description, client_id, last_modified, version) VALUES (?, ?, ?, ?, ?)",
- (card_data_name, "DescRollback", db_instance.client_id, get_current_utc_timestamp_iso(), 1)
- )
- # Second insert that causes an error (duplicate unique name)
- conn.execute(
- "INSERT INTO character_cards (name, description, client_id, last_modified, version) VALUES (?, ?, ?, ?, ?)",
- (card_data_name, "DescRollbackFail", db_instance.client_id, get_current_utc_timestamp_iso(), 1)
- )
-
- # Check that the first insert was rolled back
+ with db_instance.transaction() as conn:
+ conn.execute("INSERT INTO character_cards (name, client_id) VALUES (?, ?)",
+ ("TransRollback", db_instance.client_id))
+ # This will fail due to duplicate name, causing a rollback
+ conn.execute("INSERT INTO character_cards (name, client_id) VALUES (?, ?)",
+ ("TransRollback", db_instance.client_id))
+
assert len(db_instance.list_character_cards()) == initial_count
- assert db_instance.get_character_card_by_name(card_data_name) is None
-
-# More tests can be added for:
-# - Specific FTS trigger behavior (though search tests cover them indirectly)
-# - Behavior of ON DELETE CASCADE / ON UPDATE CASCADE where applicable (e.g., true deletion of character should cascade to conversations IF hard delete was used and schema supported it)
-# - More complex conflict scenarios with multiple clients (harder to simulate perfectly in unit tests without multiple DB instances writing to the same file).
-# - All permutations of linking and unlinking for all link tables.
-# - All specific error conditions for each method (e.g. InputError for various fields).
\ No newline at end of file
+ assert db_instance.get_character_card_by_name("TransRollback") is None
\ No newline at end of file
diff --git a/Tests/ChaChaNotesDB/test_chachanotes_db_properties.py b/Tests/ChaChaNotesDB/test_chachanotes_db_properties.py
new file mode 100644
index 00000000..23b9ebdd
--- /dev/null
+++ b/Tests/ChaChaNotesDB/test_chachanotes_db_properties.py
@@ -0,0 +1,1128 @@
+# test_chachanotes_db_properties.py
+#
+# Property-based tests for the ChaChaNotes_DB library using Hypothesis.
+#
+# Imports
+import uuid
+import pytest
+import json
+from pathlib import Path
+import sqlite3
+import threading
+import time
+#
+# Third-Party Imports
+from hypothesis import given, strategies as st, settings, HealthCheck
+from hypothesis.stateful import RuleBasedStateMachine, rule, precondition, Bundle
+#
+# Local Imports
+from tldw_chatbook.DB.ChaChaNotes_DB import (
+ CharactersRAGDB,
+ InputError,
+ CharactersRAGDBError,
+ ConflictError
+)
+#
+########################################################################################################################
+#
+# Functions:
+# --- Hypothesis Tests ---
+
+settings.register_profile(
+    "db_friendly",
+    # DB operations can be slow, so relax Hypothesis's default deadline.
+    deadline=1000,
+    suppress_health_check=[
+        HealthCheck.too_slow,
+        # db_instance is a function-scoped pytest fixture that Hypothesis reuses across
+        # generated examples, so this health check must be suppressed explicitly.
+        HealthCheck.function_scoped_fixture,
+    ]
+)
+settings.load_profile("db_friendly")
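+# Note: a per-test @settings(...) decorator takes precedence over this profile for the
+# fields it names, with the profile filling in the rest; the tests below that restate
+# suppress_health_check are therefore overriding this list, not extending it.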
+
+
+# Define a strategy for a non-zero integer to add to the version
+st_version_offset = st.integers().filter(lambda x: x != 0)
+
+
+# --- Fixtures (duplicated from the main DB test module so this file is self-contained) ---
+
+@pytest.fixture
+def client_id():
+ """Provides a consistent client ID for tests."""
+ return "hypothesis_client"
+
+
+@pytest.fixture
+def db_path(tmp_path):
+ """Provides a temporary path for the database file for each test."""
+ return tmp_path / "prop_test_db.sqlite"
+
+
+@pytest.fixture(scope="function")
+def db_instance(db_path, client_id):
+ """Creates a DB instance for each test, ensuring a fresh database."""
+ current_db_path = Path(db_path)
+ # Ensure no leftover files from a failed previous run
+ for suffix in ["", "-wal", "-shm"]:
+ p = Path(str(current_db_path) + suffix)
+ if p.exists():
+ p.unlink(missing_ok=True)
+
+ db = CharactersRAGDB(current_db_path, client_id)
+ yield db
+ db.close_connection()
+
+
+# --- Hypothesis Strategies ---
+# These strategies define how to generate random, valid data for our database objects.
+
+# A strategy for text fields that cannot be empty or just whitespace.
+st_required_text = st.text(min_size=1, max_size=100).filter(lambda s: s.strip())
+
+# A strategy for optional text or binary fields.
+st_optional_text = st.one_of(st.none(), st.text(max_size=500))
+st_optional_binary = st.one_of(st.none(), st.binary(max_size=1024))
+
+# A strategy for fields that are stored as JSON strings in the DB.
+# We generate a Python list/dict and then map it to a JSON string.
+st_json_list = st.lists(st.text(max_size=50)).map(json.dumps)
+st_json_dict = st.dictionaries(st.text(max_size=20), st.text(max_size=100)).map(json.dumps)
+
+
+@st.composite
+def st_character_card_data(draw):
+ """A composite strategy to generate a dictionary of character card data."""
+ # `draw` is a function that pulls a value from another strategy.
+ name = draw(st_required_text)
+
+    # The schema seeds a 'Default Assistant' card, so rename any generated name that
+    # collides with it rather than risking a spurious ConflictError.
+    if name == 'Default Assistant':
+        name += "_hypothesis"
+
+ return {
+ "name": name,
+ "description": draw(st_optional_text),
+ "personality": draw(st_optional_text),
+ "scenario": draw(st_optional_text),
+ "system_prompt": draw(st_optional_text),
+ "image": draw(st_optional_binary),
+ "post_history_instructions": draw(st_optional_text),
+ "first_message": draw(st_optional_text),
+ "message_example": draw(st_optional_text),
+ "creator_notes": draw(st_optional_text),
+ "alternate_greetings": draw(st.one_of(st.none(), st_json_list)),
+ "tags": draw(st.one_of(st.none(), st_json_list)),
+ "creator": draw(st_optional_text),
+ "character_version": draw(st_optional_text),
+ "extensions": draw(st.one_of(st.none(), st_json_dict)),
+ }
+
+
+@st.composite
+def st_note_data(draw):
+ """Generates a dictionary for a note (title and content)."""
+ return {
+ "title": draw(st_required_text),
+ "content": draw(st.text(max_size=2000)) # Content can be empty
+ }
+
+
+# --- Property Test Classes ---
+
+class TestCharacterCardProperties:
+ """Property-based tests for Character Cards."""
+
+ @settings(suppress_health_check=[HealthCheck.function_scoped_fixture])
+ @given(card_data=st_character_card_data())
+ def test_character_card_roundtrip(self, db_instance: CharactersRAGDB, card_data: dict):
+ """
+ Property: If we add a character card, retrieving it should return the exact same data.
+ This is a "round-trip" test.
+ """
+ try:
+ card_id = db_instance.add_character_card(card_data)
+ except ConflictError:
+ # Hypothesis might generate the same name twice. This is not a failure of the
+ # DB logic, so we just skip this test case.
+ return
+
+ retrieved_card = db_instance.get_character_card_by_id(card_id)
+
+ assert retrieved_card is not None
+ assert retrieved_card["version"] == 1
+ assert not retrieved_card["deleted"]
+
+ # Compare original data with retrieved data
+ for key, value in card_data.items():
+ if key in db_instance._CHARACTER_CARD_JSON_FIELDS and value is not None:
+ # JSON fields are deserialized, so we compare to the parsed version
+ assert retrieved_card[key] == json.loads(value)
+ else:
+ assert retrieved_card[key] == value
+
+ @settings(suppress_health_check=[HealthCheck.function_scoped_fixture])
+ @given(initial_card=st_character_card_data(), update_payload=st_character_card_data())
+ def test_update_increments_version_and_changes_data(self, db_instance: CharactersRAGDB, initial_card: dict,
+ update_payload: dict):
+ """
+ Property: A successful update must increment the version number by exactly 1
+ and correctly apply the new data.
+ """
+ try:
+ card_id = db_instance.add_character_card(initial_card)
+ except ConflictError:
+ return # Skip if initial card name conflicts
+
+ original_card = db_instance.get_character_card_by_id(card_id)
+
+ try:
+ success = db_instance.update_character_card(card_id, update_payload,
+ expected_version=original_card['version'])
+ except ConflictError as e:
+ # An update can legitimately fail if the new name is already taken.
+ # We accept this as a valid outcome.
+ assert "already exists" in str(e)
+ return
+
+ assert success is True
+
+ updated_card = db_instance.get_character_card_by_id(card_id)
+ assert updated_card is not None
+ assert updated_card['version'] == original_card['version'] + 1
+
+ # Verify the payload was applied
+ assert updated_card['name'] == update_payload['name']
+ assert updated_card['description'] == update_payload['description']
+
+ @settings(suppress_health_check=[HealthCheck.function_scoped_fixture])
+ @given(card_data=st_character_card_data())
+ def test_soft_delete_makes_item_unfindable(self, db_instance: CharactersRAGDB, card_data: dict):
+ """
+ Property: After soft-deleting an item, it should not be retrievable by
+ the standard `get` or `list` methods, but should exist in the DB with deleted=1.
+ """
+ try:
+ card_id = db_instance.add_character_card(card_data)
+ except ConflictError:
+ return
+
+ original_card = db_instance.get_character_card_by_id(card_id)
+
+ # Perform the soft delete
+ success = db_instance.soft_delete_character_card(card_id, expected_version=original_card['version'])
+ assert success is True
+
+ # Assert it's no longer findable via public methods
+ assert db_instance.get_character_card_by_id(card_id) is None
+
+ all_cards = db_instance.list_character_cards()
+ assert card_id not in [c['id'] for c in all_cards]
+
+ # Assert its raw state in the DB is correct
+ conn = db_instance.get_connection()
+ raw_record = conn.execute("SELECT deleted, version FROM character_cards WHERE id = ?", (card_id,)).fetchone()
+ assert raw_record is not None
+ assert raw_record['deleted'] == 1
+ assert raw_record['version'] == original_card['version'] + 1
+
+ @settings(suppress_health_check=[HealthCheck.function_scoped_fixture])
+ @given(
+ initial_card=st_character_card_data(),
+ update_payload=st_character_card_data(),
+ version_offset=st_version_offset
+ )
+ def test_update_with_stale_version_always_fails(self, db_instance: CharactersRAGDB, initial_card: dict,
+ update_payload: dict, version_offset: int):
+ """
+ Property: Attempting to update a record with an incorrect `expected_version`
+ must always raise a ConflictError.
+ """
+ try:
+ card_id = db_instance.add_character_card(initial_card)
+ except ConflictError:
+ return # Skip if initial card name conflicts
+
+ original_card = db_instance.get_character_card_by_id(card_id)
+
+ # Use the generated non-zero offset to create a stale version
+ stale_version = original_card['version'] + version_offset
+
+ with pytest.raises(ConflictError):
+ db_instance.update_character_card(card_id, update_payload, expected_version=stale_version)
+
+ @settings(suppress_health_check=[HealthCheck.function_scoped_fixture])
+ @given(
+ initial_card=st_character_card_data(),
+ update_payload=st_character_card_data()
+ )
+ def test_update_does_not_change_immutable_fields(self, db_instance: CharactersRAGDB, initial_card: dict,
+ update_payload: dict):
+ """
+ Property: The `update` method must not change immutable fields like `id` and `created_at`,
+ even if they are passed in the payload.
+ """
+ try:
+ card_id = db_instance.add_character_card(initial_card)
+ except ConflictError:
+ return
+
+ original_card = db_instance.get_character_card_by_id(card_id)
+
+ # Add immutable fields to the update payload to try and change them
+ malicious_payload = update_payload.copy()
+ malicious_payload['id'] = 99999 # Try to change the ID
+ malicious_payload['created_at'] = "1999-01-01T00:00:00Z" # Try to change creation time
+
+ try:
+ db_instance.update_character_card(card_id, malicious_payload, expected_version=original_card['version'])
+ except ConflictError:
+ # This can happen if the update_payload name conflicts, which is a valid outcome.
+ return
+
+ updated_card = db_instance.get_character_card_by_id(card_id)
+
+ # Assert that the immutable fields did NOT change.
+ assert updated_card['id'] == original_card['id']
+ assert updated_card['created_at'] == original_card['created_at']
+
+
+class TestNoteAndKeywordProperties:
+ """Property-based tests for Notes, Keywords, and their linking."""
+
+ @settings(suppress_health_check=[HealthCheck.function_scoped_fixture])
+ @given(note_data=st_note_data())
+ def test_note_roundtrip(self, db_instance: CharactersRAGDB, note_data: dict):
+ """
+ Property: A created note, when retrieved, has the same data,
+ accounting for any sanitization (like stripping whitespace).
+ """
+ note_id = db_instance.add_note(**note_data)
+ assert note_id is not None
+
+ retrieved = db_instance.get_note_by_id(note_id)
+
+ assert retrieved is not None
+        # Titles are stored with surrounding whitespace stripped, so compare against the stripped value.
+        assert retrieved['title'] == note_data['title'].strip()
+ assert retrieved['content'] == note_data['content']
+ assert retrieved['version'] == 1
+
+ @settings(suppress_health_check=[HealthCheck.function_scoped_fixture])
+ @given(keyword=st_required_text)
+ def test_add_keyword_is_idempotent_on_undelete(self, db_instance: CharactersRAGDB, keyword: str):
+ """
+ Property: Adding a keyword that was previously soft-deleted should reactivate
+ it, not create a new one, and its version should be correctly incremented.
+ """
+ # 1. Add for the first time
+ try:
+ kw_id_v1 = db_instance.add_keyword(keyword)
+ except ConflictError:
+ return
+
+ kw_v1 = db_instance.get_keyword_by_id(kw_id_v1)
+ assert kw_v1['version'] == 1
+
+ # 2. Soft delete it
+ db_instance.soft_delete_keyword(kw_id_v1, expected_version=1)
+ kw_v2_raw = db_instance.get_connection().execute("SELECT * FROM keywords WHERE id = ?", (kw_id_v1,)).fetchone()
+ assert kw_v2_raw['deleted'] == 1
+ assert kw_v2_raw['version'] == 2
+
+ # 3. Add it again (should trigger undelete)
+ kw_id_v3 = db_instance.add_keyword(keyword)
+
+ # Assert it's the same record
+ assert kw_id_v3 == kw_id_v1
+
+ kw_v3 = db_instance.get_keyword_by_id(kw_id_v3)
+ assert kw_v3 is not None
+ assert not kw_v3['deleted']
+ # The version should be 3 (1=create, 2=delete, 3=undelete)
+ assert kw_v3['version'] == 3
+
+ @settings(suppress_health_check=[HealthCheck.function_scoped_fixture])
+ @given(note_data=st_note_data(), keyword_text=st_required_text)
+ def test_linking_and_unlinking_properties(self, db_instance: CharactersRAGDB, note_data: dict, keyword_text: str):
+ """
+ Property: Linking two items should make them appear in each other's "get_links"
+ methods, and unlinking should remove them.
+ """
+ try:
+ note_id = db_instance.add_note(**note_data)
+ keyword_id = db_instance.add_keyword(keyword_text)
+ except ConflictError:
+ return # Skip if hypothesis generates conflicting data
+
+ # Initially, no links should exist
+ assert db_instance.get_keywords_for_note(note_id) == []
+ assert db_instance.get_notes_for_keyword(keyword_id) == []
+
+ # --- Test Linking ---
+ link_success = db_instance.link_note_to_keyword(note_id, keyword_id)
+ assert link_success is True
+
+ # Check that linking again is idempotent (returns False)
+ link_again_success = db_instance.link_note_to_keyword(note_id, keyword_id)
+ assert link_again_success is False
+
+ # Verify the link exists from both sides
+ keywords_for_note = db_instance.get_keywords_for_note(note_id)
+ assert len(keywords_for_note) == 1
+ assert keywords_for_note[0]['id'] == keyword_id
+
+ notes_for_keyword = db_instance.get_notes_for_keyword(keyword_id)
+ assert len(notes_for_keyword) == 1
+ assert notes_for_keyword[0]['id'] == note_id
+
+ # --- Test Unlinking ---
+ unlink_success = db_instance.unlink_note_from_keyword(note_id, keyword_id)
+ assert unlink_success is True
+
+ # Check that unlinking again is idempotent (returns False)
+ unlink_again_success = db_instance.unlink_note_from_keyword(note_id, keyword_id)
+ assert unlink_again_success is False
+
+ # Verify the link is gone
+ assert db_instance.get_keywords_for_note(note_id) == []
+ assert db_instance.get_notes_for_keyword(keyword_id) == []
+
+
+class TestAdvancedProperties:
+ @settings(suppress_health_check=[HealthCheck.function_scoped_fixture])
+ @given(note_data=st_note_data())
+ def test_soft_deleted_item_is_not_in_fts(self, db_instance: CharactersRAGDB, note_data: dict):
+ """
+ Property: Once an item is soft-deleted, it must not appear in FTS search results.
+ """
+ # Ensure the title has a unique, searchable term.
+ unique_term = str(uuid.uuid4())
+ note_data['title'] = f"{note_data['title']} {unique_term}"
+
+ note_id = db_instance.add_note(**note_data)
+ original_note = db_instance.get_note_by_id(note_id)
+
+ # 1. Verify it IS searchable before deletion
+ search_results_before = db_instance.search_notes(unique_term)
+ assert len(search_results_before) == 1
+ assert search_results_before[0]['id'] == note_id
+
+ # 2. Soft-delete the note
+ db_instance.soft_delete_note(note_id, expected_version=original_note['version'])
+
+ # 3. Verify it is NOT searchable after deletion
+ search_results_after = db_instance.search_notes(unique_term)
+ assert len(search_results_after) == 0
+
+ @settings(suppress_health_check=[HealthCheck.function_scoped_fixture])
+ @given(note_data=st_note_data())
+ def test_add_creates_sync_log_entry(self, db_instance: CharactersRAGDB, note_data: dict):
+ """
+ Property: Adding a new item must create exactly one 'create' operation
+ in the sync_log for that item.
+ """
+ latest_change_id_before = db_instance.get_latest_sync_log_change_id()
+
+ # Add the note (this action should be logged by a trigger)
+ note_id = db_instance.add_note(**note_data)
+
+ new_log_entries = db_instance.get_sync_log_entries(since_change_id=latest_change_id_before)
+
+ # There should be exactly one new entry
+ assert len(new_log_entries) == 1
+
+ log_entry = new_log_entries[0]
+ assert log_entry['entity'] == 'notes'
+ assert log_entry['entity_id'] == note_id
+ assert log_entry['operation'] == 'create'
+ assert log_entry['client_id'] == db_instance.client_id
+ assert log_entry['version'] == 1
+ assert log_entry['payload']['title'] == note_data['title'].strip() # The log stores the stripped version
+ assert log_entry['payload']['content'] == note_data['content']
+
+ @settings(suppress_health_check=[HealthCheck.function_scoped_fixture])
+ @given(note_data=st_note_data(), update_content=st.text(max_size=500))
+ def test_update_creates_sync_log_entry(self, db_instance: CharactersRAGDB, note_data: dict, update_content: str):
+ """
+ Property: Updating an item must create exactly one 'update' operation
+ in the sync_log for that item.
+ """
+ note_id = db_instance.add_note(**note_data)
+ latest_change_id_before = db_instance.get_latest_sync_log_change_id()
+
+ # Update the note
+ update_payload = {'content': update_content}
+ db_instance.update_note(note_id, update_payload, expected_version=1)
+
+ new_log_entries = db_instance.get_sync_log_entries(since_change_id=latest_change_id_before)
+
+ assert len(new_log_entries) == 1
+ log_entry = new_log_entries[0]
+ assert log_entry['entity'] == 'notes'
+ assert log_entry['entity_id'] == note_id
+ assert log_entry['operation'] == 'update'
+ assert log_entry['version'] == 2
+ assert log_entry['payload']['content'] == update_content
+
+ @settings(suppress_health_check=[HealthCheck.function_scoped_fixture])
+ @given(note_data=st_note_data())
+ def test_delete_creates_sync_log_entry(self, db_instance: CharactersRAGDB, note_data: dict):
+ """
+ Property: Soft-deleting an item must create exactly one 'delete' operation
+ in the sync_log for that item.
+ """
+ note_id = db_instance.add_note(**note_data)
+ latest_change_id_before = db_instance.get_latest_sync_log_change_id()
+
+ # Delete the note
+ db_instance.soft_delete_note(note_id, expected_version=1)
+
+ new_log_entries = db_instance.get_sync_log_entries(since_change_id=latest_change_id_before)
+
+ assert len(new_log_entries) == 1
+ log_entry = new_log_entries[0]
+ assert log_entry['entity'] == 'notes'
+ assert log_entry['entity_id'] == note_id
+ assert log_entry['operation'] == 'delete'
+ assert log_entry['version'] == 2
+ assert log_entry['payload']['deleted'] == 1
+
+ @settings(suppress_health_check=[HealthCheck.function_scoped_fixture])
+ @given(note_data=st_note_data(), keyword_text=st_required_text)
+ def test_link_action_creates_correct_sync_log(self, db_instance: CharactersRAGDB, note_data: dict,
+ keyword_text: str):
+ """
+ Property: The `link_note_to_keyword` action must create a sync_log entry
+ with the correct entity, IDs, and operation in its payload.
+ """
+ try:
+ note_id = db_instance.add_note(**note_data)
+ kw_id = db_instance.add_keyword(keyword_text)
+ except ConflictError:
+ return
+
+ latest_change_id_before = db_instance.get_latest_sync_log_change_id()
+
+ # Action
+ db_instance.link_note_to_keyword(note_id, kw_id)
+
+ # There should be exactly one new log entry from the linking.
+ # We ignore the create logs for the note and keyword.
+ link_log_entries = db_instance.get_sync_log_entries(since_change_id=latest_change_id_before,
+ entity_type='note_keywords')
+ assert len(link_log_entries) == 1
+
+ log_entry = link_log_entries[0]
+ assert log_entry['entity'] == 'note_keywords'
+ assert log_entry['operation'] == 'create'
+ assert log_entry['entity_id'] == f"{note_id}_{kw_id}"
+ assert log_entry['payload']['note_id'] == note_id
+ assert log_entry['payload']['keyword_id'] == kw_id
+
+
+ @settings(suppress_health_check=[HealthCheck.function_scoped_fixture])
+ @given(note_data=st_note_data(), keyword_text=st_required_text)
+ def test_unlink_action_creates_correct_sync_log(self, db_instance: CharactersRAGDB, note_data: dict,
+ keyword_text: str):
+ """
+ Property: The `unlink_note_to_keyword` action must create a sync_log entry
+ with the correct entity, IDs, and operation.
+ """
+ try:
+ note_id = db_instance.add_note(**note_data)
+ kw_id = db_instance.add_keyword(keyword_text)
+ except ConflictError:
+ return
+
+ db_instance.link_note_to_keyword(note_id, kw_id)
+ latest_change_id_before = db_instance.get_latest_sync_log_change_id()
+
+ # Action
+ db_instance.unlink_note_from_keyword(note_id, kw_id)
+
+ unlink_log_entries = db_instance.get_sync_log_entries(since_change_id=latest_change_id_before,
+ entity_type='note_keywords')
+ assert len(unlink_log_entries) == 1
+
+ log_entry = unlink_log_entries[0]
+ assert log_entry['entity'] == 'note_keywords'
+ assert log_entry['operation'] == 'delete'
+ assert log_entry['entity_id'] == f"{note_id}_{kw_id}"
+ assert log_entry['payload']['note_id'] == note_id
+ assert log_entry['payload']['keyword_id'] == kw_id
+
+
+@pytest.fixture
+def populated_conversation(db_instance: CharactersRAGDB):
+ """A fixture to create a character, conversation, and message for cascade tests."""
+ card_id = db_instance.add_character_card({'name': 'Cascade Test Character'})
+ card = db_instance.get_character_card_by_id(card_id)
+
+ conv_id = db_instance.add_conversation({'character_id': card['id'], 'title': 'Cascade Conv'})
+ conv = db_instance.get_conversation_by_id(conv_id)
+
+ msg_id = db_instance.add_message({'conversation_id': conv['id'], 'sender': 'user', 'content': 'Cascade Msg'})
+ msg = db_instance.get_message_by_id(msg_id)
+
+ return {"card": card, "conv": conv, "msg": msg}
+
+
+class TestCascadeAndLinkingProperties:
+ def test_soft_deleting_conversation_makes_messages_unfindable(self, db_instance: CharactersRAGDB,
+ populated_conversation):
+ """
+ Property: After a conversation is soft-deleted, its messages should not be
+ returned by get_messages_for_conversation.
+ """
+ conv = populated_conversation['conv']
+ msg = populated_conversation['msg']
+
+ # 1. Verify message and conversation exist before
+ assert db_instance.get_message_by_id(msg['id']) is not None
+ assert len(db_instance.get_messages_for_conversation(conv['id'])) == 1
+
+ # 2. Soft-delete the parent conversation
+ db_instance.soft_delete_conversation(conv['id'], expected_version=conv['version'])
+
+ # 3. Verify the messages are now un-findable via the main query method
+ messages = db_instance.get_messages_for_conversation(conv['id'])
+ assert messages == []
+
+ # 4. As per our current design, the message record itself still exists (orphaned).
+ # This is an important check of the current behavior.
+ assert db_instance.get_message_by_id(msg['id']) is not None
+
+ def test_hard_deleting_conversation_cascades_to_messages(self, db_instance: CharactersRAGDB,
+ populated_conversation):
+ """
+ Property: A hard DELETE on a conversation should cascade and delete its
+ messages, enforcing the FOREIGN KEY ... ON DELETE CASCADE constraint.
+ """
+ conv = populated_conversation['conv']
+ msg = populated_conversation['msg']
+
+ # 1. Verify message exists before
+ assert db_instance.get_message_by_id(msg['id']) is not None
+
+ # 2. Perform a hard delete in a clean transaction
+ with db_instance.transaction() as conn:
+ conn.execute("DELETE FROM conversations WHERE id = ?", (conv['id'],))
+
+ # 3. Now the message should be truly gone.
+ assert db_instance.get_message_by_id(msg['id']) is None
+
+ @settings(suppress_health_check=[HealthCheck.function_scoped_fixture])
+ @given(keyword_text=st_required_text)
+ def test_deleting_keyword_cascades_to_link_tables(self, db_instance: CharactersRAGDB, populated_conversation,
+ keyword_text: str):
+ """
+ Property: Deleting a keyword should remove all links to it from linking tables
+ due to ON DELETE CASCADE.
+ """
+ conv = populated_conversation['conv']
+ try:
+ kw_id = db_instance.add_keyword(keyword_text)
+ except ConflictError:
+ return
+
+ # Link the conversation and keyword
+ db_instance.link_conversation_to_keyword(conv['id'], kw_id)
+
+ # Verify link exists
+ keywords = db_instance.get_keywords_for_conversation(conv['id'])
+ assert len(keywords) == 1
+ assert keywords[0]['id'] == kw_id
+
+ # Soft-delete the keyword
+ keyword = db_instance.get_keyword_by_id(kw_id)
+ db_instance.soft_delete_keyword(kw_id, keyword['version'])
+
+ # The link should now be gone when we retrieve it, because the JOIN will fail
+ # on `k.deleted = 0`. This tests the query logic.
+ keywords_after_delete = db_instance.get_keywords_for_conversation(conv['id'])
+ assert keywords_after_delete == []
+
+
+ @settings(suppress_health_check=[HealthCheck.function_scoped_fixture])
+ @given(note_data=st_note_data(), keyword_text=st_required_text)
+ def test_linking_is_idempotent(self, db_instance: CharactersRAGDB, note_data: dict, keyword_text: str):
+ """
+ Property: Calling a link function multiple times has the same effect as calling it once.
+ The first call should return True (1 row affected), subsequent calls should return False (0 rows affected).
+ """
+ try:
+ note_id = db_instance.add_note(**note_data)
+ kw_id = db_instance.add_keyword(keyword_text)
+ except ConflictError:
+ return
+
+ # First call should succeed and return True
+ assert db_instance.link_note_to_keyword(note_id, kw_id) is True
+
+ # Second call should do nothing and return False
+ assert db_instance.link_note_to_keyword(note_id, kw_id) is False
+
+ # Third call should also do nothing and return False
+ assert db_instance.link_note_to_keyword(note_id, kw_id) is False
+
+ # Verify there is still only one link
+ keywords = db_instance.get_keywords_for_note(note_id)
+ assert len(keywords) == 1
+
+
+ @settings(suppress_health_check=[HealthCheck.function_scoped_fixture])
+ @given(note_data=st_note_data(), keyword_text=st_required_text)
+ def test_unlinking_is_idempotent(self, db_instance: CharactersRAGDB, note_data: dict, keyword_text: str):
+ """
+ Property: Calling an unlink function on a non-existent link does nothing.
+ Calling it on an existing link works once, then does nothing on subsequent calls.
+ """
+ try:
+ note_id = db_instance.add_note(**note_data)
+ kw_id = db_instance.add_keyword(keyword_text)
+ except ConflictError:
+ return
+
+ # 1. Unlinking a non-existent link should return False
+ assert db_instance.unlink_note_from_keyword(note_id, kw_id) is False
+
+ # 2. Create the link
+ db_instance.link_note_to_keyword(note_id, kw_id)
+ assert len(db_instance.get_keywords_for_note(note_id)) == 1
+
+ # 3. First unlink should succeed and return True
+ assert db_instance.unlink_note_from_keyword(note_id, kw_id) is True
+
+ # 4. Second unlink should fail and return False
+ assert db_instance.unlink_note_from_keyword(note_id, kw_id) is False
+
+ # Verify the link is gone
+ assert len(db_instance.get_keywords_for_note(note_id)) == 0
+
+
+# ==========================================================
+# == STATE MACHINE SECTION
+# ==========================================================
+
+class NoteLifecycleMachine(RuleBasedStateMachine):
+ """
+ This class defines the rules and state for our test.
+ It is not run directly by pytest.
+ """
+
+ def __init__(self):
+ super().__init__()
+ self.db = None # This will be injected by the test class
+ self.note_id = None
+ self.expected_version = 0
+ self.is_deleted = True
+
+ notes = Bundle('notes')
+
+ @rule(target=notes, note_data=st_note_data())
+ def create_note(self, note_data):
+ # We only want to test the lifecycle of one note per machine run.
+ if self.note_id is not None:
+ return
+
+ self.note_id = self.db.add_note(**note_data)
+ self.is_deleted = False
+ self.expected_version = 1
+
+ retrieved = self.db.get_note_by_id(self.note_id)
+ assert retrieved is not None
+ return self.note_id
+
+ @rule(note_id=notes, update_data=st_note_data())
+ def update_note(self, note_id, update_data):
+ if self.note_id is None or self.is_deleted:
+ return
+
+ success = self.db.update_note(note_id, update_data, self.expected_version)
+ assert success
+ self.expected_version += 1
+
+ retrieved = self.db.get_note_by_id(self.note_id)
+ assert retrieved is not None
+ assert retrieved['version'] == self.expected_version
+
+ @rule(note_id=notes)
+ def soft_delete_note(self, note_id):
+ if self.note_id is None or self.is_deleted:
+ return
+
+ success = self.db.soft_delete_note(note_id, self.expected_version)
+ assert success
+ self.expected_version += 1
+ self.is_deleted = True
+
+ assert self.db.get_note_by_id(self.note_id) is None
+
+
+# This class IS the test. pytest will discover it.
+# It inherits our rules and provides the `db_instance` fixture.
+@settings(suppress_health_check=[HealthCheck.function_scoped_fixture, HealthCheck.too_slow], max_examples=50)
+class TestNoteLifecycleAsTest(NoteLifecycleMachine):
+
+ @pytest.fixture(autouse=True)
+ def inject_db(self, db_instance):
+ """Injects the clean db_instance fixture into the state machine for each test run."""
+ self.db = db_instance
+
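+# Note: the stock Hypothesis way to run a RuleBasedStateMachine is to expose
+# `NoteLifecycleMachine.TestCase`, but that route offers no hook for pytest fixtures;
+# the subclass-plus-autouse-fixture pattern above works around that by injecting
+# `db_instance` into the machine for each run.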
+
+# ==========================================================
+# == Character Card State Machine
+# ==========================================================
+
+class CharacterCardLifecycleMachine(RuleBasedStateMachine):
+ """Models the lifecycle of a CharacterCard."""
+
+ def __init__(self):
+ super().__init__()
+ self.db = None
+ self.card_id = None
+ self.card_name = None
+ self.expected_version = 0
+ self.is_deleted = True
+
+ cards = Bundle('cards')
+
+ @rule(target=cards, card_data=st_character_card_data())
+ def create_card(self, card_data):
+ # Only create one card per machine run for simplicity.
+ if self.card_id is not None:
+ return
+
+ try:
+ new_id = self.db.add_character_card(card_data)
+ except ConflictError:
+ # It's possible for hypothesis to generate a duplicate name
+ # in its sequence. We treat this as "no action taken".
+ return
+
+ self.card_id = new_id
+ self.card_name = card_data['name']
+ self.expected_version = 1
+ self.is_deleted = False
+
+ retrieved = self.db.get_character_card_by_id(self.card_id)
+ assert retrieved is not None
+ assert retrieved['name'] == self.card_name
+ return self.card_id
+
+ @rule(card_id=cards, update_data=st_character_card_data())
+ def update_card(self, card_id, update_data):
+ if self.card_id is None or self.is_deleted:
+ return
+
+ try:
+ success = self.db.update_character_card(card_id, update_data, self.expected_version)
+ assert success
+ self.expected_version += 1
+ self.card_name = update_data['name'] # Name can change
+ except ConflictError as e:
+ # Update can fail legitimately if the new name is already taken.
+ assert "already exists" in str(e)
+ # The state of our card hasn't changed, so we just return.
+ return
+
+ retrieved = self.db.get_character_card_by_id(self.card_id)
+ assert retrieved is not None
+ assert retrieved['version'] == self.expected_version
+ assert retrieved['name'] == self.card_name
+
+ @rule(card_id=cards)
+ def soft_delete_card(self, card_id):
+ if self.card_id is None or self.is_deleted:
+ return
+
+ success = self.db.soft_delete_character_card(card_id, self.expected_version)
+ assert success
+ self.expected_version += 1
+ self.is_deleted = True
+
+ assert self.db.get_character_card_by_id(self.card_id) is None
+ assert self.db.get_character_card_by_name(self.card_name) is None
+
+
+# The pytest test class that runs the machine
+@settings(suppress_health_check=[HealthCheck.function_scoped_fixture, HealthCheck.too_slow], max_examples=50)
+class TestCharacterCardLifecycle(CharacterCardLifecycleMachine):
+
+ @pytest.fixture(autouse=True)
+ def inject_db(self, db_instance):
+ self.db = db_instance
+
+
+st_message_content = st.text(max_size=1000)
+
+
+# ==========================================================
+# == Conversation/Message State Machine
+# ==========================================================
+
+class ConversationMachine(RuleBasedStateMachine):
+ """
+ Models creating a conversation and adding messages to it.
+ This machine tests the integrity of a single conversation over time.
+ """
+
+ def __init__(self):
+ super().__init__()
+ self.db = None
+ self.card_id = None
+ self.conv_id = None
+ self.message_count = 0
+
+ @precondition(lambda self: self.card_id is None)
+ @rule()
+ def create_character(self):
+ """Setup step: create a character to host the conversation."""
+ self.card_id = self.db.add_character_card({'name': 'Chat Host Character'})
+ assert self.card_id is not None
+
+ @precondition(lambda self: self.card_id is not None and self.conv_id is None)
+ @rule()
+ def create_conversation(self):
+ """Create the main conversation for this test run."""
+ self.conv_id = self.db.add_conversation({'character_id': self.card_id, 'title': 'Test Chat'})
+ assert self.conv_id is not None
+
+ @precondition(lambda self: self.conv_id is not None)
+ @rule(content=st_message_content, sender=st.sampled_from(['user', 'ai']))
+ def add_message(self, content, sender):
+ """Add a new message to the existing conversation."""
+ if not content and not sender: # Ensure message has some substance
+ return
+
+ msg_id = self.db.add_message({
+ 'conversation_id': self.conv_id,
+ 'sender': sender,
+ 'content': content
+ })
+ assert msg_id is not None
+ self.message_count += 1
+
+ def teardown(self):
+ """
+ This method is called at the end of a state machine run.
+ We use it to check the final state of the system.
+ """
+ if self.conv_id is not None:
+ messages = self.db.get_messages_for_conversation(self.conv_id)
+ assert len(messages) == self.message_count
+
+
+# The pytest test class that runs the machine
+@settings(suppress_health_check=[HealthCheck.function_scoped_fixture, HealthCheck.too_slow], max_examples=20,
+ stateful_step_count=50)
+class TestConversationInteractions(ConversationMachine):
+
+ @pytest.fixture(autouse=True)
+ def inject_db(self, db_instance):
+ self.db = db_instance
+
+
+class TestDataIntegrity:
+ def test_add_conversation_with_nonexistent_character_fails(self, db_instance: CharactersRAGDB):
+ """
+ Property: Cannot create a conversation for a character_id that does not exist.
+ This tests the FOREIGN KEY constraint.
+ """
+ non_existent_char_id = 99999
+ conv_data = {"character_id": non_existent_char_id, "title": "Orphan Conversation"}
+
+        # SQLite raises an IntegrityError for the FK violation; the DB layer is expected
+        # to wrap it in CharactersRAGDBError.
+ with pytest.raises(CharactersRAGDBError, match="FOREIGN KEY constraint failed"):
+ db_instance.add_conversation(conv_data)
+
+ def test_add_message_to_nonexistent_conversation_fails(self, db_instance: CharactersRAGDB):
+ """
+ Property: Cannot add a message to a conversation_id that does not exist.
+ """
+ non_existent_conv_id = "a-fake-uuid-string"
+ msg_data = {
+ "conversation_id": non_existent_conv_id,
+ "sender": "user",
+ "content": "Message to nowhere"
+ }
+
+        # add_message performs a pre-flight existence check and raises InputError, so this
+        # exercises the application-level validation rather than the FK constraint.
+ with pytest.raises(InputError, match="Conversation ID .* not found or deleted"):
+ db_instance.add_message(msg_data)
+
+ def test_rating_outside_range_fails(self, db_instance: CharactersRAGDB):
+ """
+ Property: Conversation rating must be between 1 and 5.
+ This tests the CHECK constraint via the public API.
+ """
+ card_id = db_instance.add_character_card({'name': 'Rating Test Character'})
+
+ # Test the application-level check in `update_conversation`
+ conv_id = db_instance.add_conversation({'character_id': card_id, "title": "Rating Conv"})
+ with pytest.raises(InputError, match="Rating must be between 1 and 5"):
+ db_instance.update_conversation(conv_id, {"rating": 0}, expected_version=1)
+ with pytest.raises(InputError, match="Rating must be between 1 and 5"):
+ db_instance.update_conversation(conv_id, {"rating": 6}, expected_version=1)
+
+ # Test the DB-level CHECK constraint directly by calling the wrapped execute_query
+ with pytest.raises(CharactersRAGDBError, match="Database constraint violation"):
+ # We start a transaction to ensure atomicity, but call the DB method
+ # that handles exception wrapping.
+ with db_instance.transaction():
+ db_instance.execute_query("UPDATE conversations SET rating = 10 WHERE id = ?", (conv_id,))
+
+
+class TestConcurrency:
+ def test_each_thread_gets_a_separate_connection(self, db_instance: CharactersRAGDB):
+ """
+ Property: The `_get_thread_connection` method must provide a unique
+ connection object for each thread.
+ """
+ connection_ids = set()
+ lock = threading.Lock()
+
+ def get_and_store_conn_id():
+ conn = db_instance.get_connection()
+ with lock:
+ connection_ids.add(id(conn))
+
+ threads = [threading.Thread(target=get_and_store_conn_id) for _ in range(5)]
+ for t in threads:
+ t.start()
+ for t in threads:
+ t.join()
+
+ # If threading.local is working, there should be 5 unique connection IDs.
+ assert len(connection_ids) == 5
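+        # sqlite3 connections cannot be shared across threads by default
+        # (check_same_thread=True), so handing out one threading.local-backed
+        # connection per thread is the behaviour this test pins down.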
+
+ def test_wal_mode_allows_concurrent_reads_during_write_transaction(self, db_instance: CharactersRAGDB):
+ """
+ Property: In WAL mode, one thread can read from the DB while another
+ thread has an open write transaction.
+ """
+ card_id = db_instance.add_character_card({'name': 'Concurrent Read Test'})
+
+ # A threading.Event to signal when the write transaction has started
+ write_transaction_started = threading.Event()
+ read_result = []
+
+ def writer_thread():
+ with db_instance.transaction():
+ db_instance.update_character_card(card_id, {'description': 'long update'}, 1)
+ write_transaction_started.set() # Signal that the transaction is open
+ time.sleep(0.2) # Hold the transaction open
+ # Transaction commits here
+
+ def reader_thread():
+ write_transaction_started.wait() # Wait until the writer is in its transaction
+ # This read should succeed immediately and not be blocked by the writer.
+ card = db_instance.get_character_card_by_id(card_id)
+ read_result.append(card)
+
+ w = threading.Thread(target=writer_thread)
+ r = threading.Thread(target=reader_thread)
+
+ w.start()
+ r.start()
+ w.join()
+ r.join()
+
+ # The reader thread should have completed successfully and read the *original* state.
+ assert len(read_result) == 1
+ assert read_result[0] is not None
+ assert read_result[0]['description'] is None # It read the state before the writer committed.
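+        # WAL readers operate on a snapshot of the last committed state, so the reader is
+        # never blocked by the writer's open transaction and sees the pre-update row.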
+
+
+class TestComplexQueries:
+ def test_get_keywords_for_conversation_filters_deleted_keywords(self, db_instance: CharactersRAGDB):
+ """
+ Property: When fetching keywords for a conversation, soft-deleted
+ keywords must be excluded from the results.
+ """
+ card_id = db_instance.add_character_card({'name': 'Filter Test'})
+ conv_id = db_instance.add_conversation({'character_id': card_id})
+
+ kw1_id = db_instance.add_keyword("Active Keyword")
+ kw2_id = db_instance.add_keyword("Keyword to be Deleted")
+ kw2 = db_instance.get_keyword_by_id(kw2_id)
+
+ db_instance.link_conversation_to_keyword(conv_id, kw1_id)
+ db_instance.link_conversation_to_keyword(conv_id, kw2_id)
+
+ # Verify both are present initially
+ assert len(db_instance.get_keywords_for_conversation(conv_id)) == 2
+
+ # Soft-delete one of the keywords
+ db_instance.soft_delete_keyword(kw2_id, kw2['version'])
+
+ # Fetch again and verify only the active one remains
+ remaining_keywords = db_instance.get_keywords_for_conversation(conv_id)
+ assert len(remaining_keywords) == 1
+ assert remaining_keywords[0]['id'] == kw1_id
+ assert remaining_keywords[0]['keyword'] == "Active Keyword"
+
+
+class TestDBOperations:
+ def test_backup_and_restore_correctness(self, db_instance: CharactersRAGDB, tmp_path: Path):
+ """
+ Property: A database created from a backup file must contain the exact
+ same data as the original database at the time of backup.
+ """
+ # 1. Populate the original database with known data
+ card_data = {'name': 'Backup Test Card', 'description': 'Data to be saved'}
+ card_id = db_instance.add_character_card(card_data)
+ original_card = db_instance.get_character_card_by_id(card_id)
+
+ # 2. Perform the backup
+ backup_path = tmp_path / "test_backup.db"
+ assert db_instance.backup_database(str(backup_path)) is True
+ assert backup_path.exists()
+
+ # 3. Close the original DB connection
+ db_instance.close_connection()
+
+ # 4. Open the backup database as a new instance
+ backup_db = CharactersRAGDB(backup_path, "restore_client")
+
+ # 5. Verify the data
+ restored_card = backup_db.get_character_card_by_id(card_id)
+ assert restored_card is not None
+
+ # Compare the entire dictionaries
+ # sqlite.Row objects need to be converted to dicts for direct comparison
+ assert dict(restored_card) == dict(original_card)
+
+ # Also check list methods
+ all_cards = backup_db.list_character_cards()
+        # The schema seeds a default character card, so the backup contains two cards.
+ assert len(all_cards) == 2
+
+ backup_db.close_connection()
+
+
+#
+# End of test_chachanotes_db_properties.py
+########################################################################################################################
diff --git a/Tests/Character_Chat/test_character_chat.py b/Tests/Character_Chat/test_character_chat.py
index c1089c0e..e76500eb 100644
--- a/Tests/Character_Chat/test_character_chat.py
+++ b/Tests/Character_Chat/test_character_chat.py
@@ -1,155 +1,323 @@
-# test_property_character_chat_lib.py
-
-import unittest
-from hypothesis import given, strategies as st, settings, HealthCheck
-import re
-
-import Character_Chat_Lib as ccl
-
-class TestCharacterChatLibProperty(unittest.TestCase):
-
- # --- replace_placeholders ---
- @given(text=st.text(),
- char_name=st.one_of(st.none(), st.text(min_size=1, max_size=50)),
- user_name=st.one_of(st.none(), st.text(min_size=1, max_size=50)))
- @settings(suppress_health_check=[HealthCheck.too_slow])
- def test_replace_placeholders_properties(self, text, char_name, user_name):
- processed = ccl.replace_placeholders(text, char_name, user_name)
- self.assertIsInstance(processed, str)
-
- expected_char = char_name if char_name else "Character"
- expected_user = user_name if user_name else "User"
-
- if "{{char}}" in text:
- self.assertIn(expected_char, processed)
- if "{{user}}" in text:
- self.assertIn(expected_user, processed)
- if "" in text:
- self.assertIn(expected_char, processed)
- if "" in text:
- self.assertIn(expected_user, processed)
-
- # If no placeholders, text should be identical
- placeholders = ['{{char}}', '{{user}}', '{{random_user}}', '', '']
- if not any(p in text for p in placeholders):
- self.assertEqual(processed, text)
-
- @given(text=st.one_of(st.none(), st.just("")), char_name=st.text(), user_name=st.text())
- def test_replace_placeholders_empty_or_none_input_text(self, text, char_name, user_name):
- self.assertEqual(ccl.replace_placeholders(text, char_name, user_name), "")
-
- # --- replace_user_placeholder ---
- @given(history=st.lists(st.tuples(st.one_of(st.none(), st.text()), st.one_of(st.none(), st.text()))),
- user_name=st.one_of(st.none(), st.text(min_size=1, max_size=50)))
- @settings(suppress_health_check=[HealthCheck.too_slow])
- def test_replace_user_placeholder_properties(self, history, user_name):
- processed_history = ccl.replace_user_placeholder(history, user_name)
- self.assertEqual(len(processed_history), len(history))
- expected_user = user_name if user_name else "User"
-
- for i, (original_user_msg, original_bot_msg) in enumerate(history):
- processed_user_msg, processed_bot_msg = processed_history[i]
- if original_user_msg is not None:
- self.assertIsInstance(processed_user_msg, str)
- if "{{user}}" in original_user_msg:
- self.assertIn(expected_user, processed_user_msg)
- else:
- self.assertEqual(processed_user_msg, original_user_msg)
- else:
- self.assertIsNone(processed_user_msg)
-
- if original_bot_msg is not None:
- self.assertIsInstance(processed_bot_msg, str)
- if "{{user}}" in original_bot_msg:
- self.assertIn(expected_user, processed_bot_msg)
- else:
- self.assertEqual(processed_bot_msg, original_bot_msg)
- else:
- self.assertIsNone(processed_bot_msg)
-
- # --- extract_character_id_from_ui_choice ---
- @given(name=st.text(alphabet=st.characters(min_codepoint=65, max_codepoint=122), min_size=1, max_size=20).filter(lambda x: '(' not in x and ')' not in x),
- id_val=st.integers(min_value=0, max_value=10**9))
- def test_extract_id_format_name_id(self, name, id_val):
- choice = f"{name} (ID: {id_val})"
- self.assertEqual(ccl.extract_character_id_from_ui_choice(choice), id_val)
-
- @given(id_val=st.integers(min_value=0, max_value=10**9))
- def test_extract_id_format_just_id(self, id_val):
- choice = str(id_val)
- self.assertEqual(ccl.extract_character_id_from_ui_choice(choice), id_val)
-
- @given(text=st.text().filter(lambda x: not re.search(r'\(\s*ID\s*:\s*\d+\s*\)\s*$', x) and not x.isdigit() and x != ""))
- def test_extract_id_invalid_format_raises_valueerror(self, text):
- with self.assertRaises(ValueError):
- ccl.extract_character_id_from_ui_choice(text)
-
- @given(choice=st.just(""))
- def test_extract_id_empty_string_raises_valueerror(self, choice):
- with self.assertRaises(ValueError):
- ccl.extract_character_id_from_ui_choice(choice)
-
- # --- process_db_messages_to_ui_history ---
- # This one is complex for property-based testing due to stateful accumulation.
- # We can test some basic properties.
- @given(db_messages=st.lists(st.fixed_dictionaries({
- 'sender': st.sampled_from(["User", "TestChar", "OtherSender"]),
- 'content': st.text(max_size=100)
- }), max_size=10),
- char_name=st.text(min_size=1, max_size=20),
- user_name=st.one_of(st.none(), st.text(min_size=1, max_size=20)))
- @settings(suppress_health_check=[HealthCheck.too_slow, HealthCheck.filter_too_much])
- def test_process_db_messages_to_ui_history_output_structure(self, db_messages, char_name, user_name):
- if not db_messages: # Avoid issues with empty messages list if logic depends on non-empty
- return
-
- ui_history = ccl.process_db_messages_to_ui_history(db_messages, char_name, user_name,
- actual_char_sender_id_in_db=char_name) # Map TestChar to char_name
- self.assertIsInstance(ui_history, list)
- for item in ui_history:
- self.assertIsInstance(item, tuple)
- self.assertEqual(len(item), 2)
- self.assertTrue(item[0] is None or isinstance(item[0], str))
- self.assertTrue(item[1] is None or isinstance(item[1], str))
-
- # If all messages are from User, bot messages in UI should be None
- if all(msg['sender'] == "User" for msg in db_messages):
- for _, bot_msg in ui_history:
- self.assertIsNone(bot_msg)
-
- # If all messages are from Character, user messages in UI should be None
- if all(msg['sender'] == char_name for msg in db_messages):
- for user_msg, _ in ui_history:
- self.assertIsNone(user_msg)
-
- # --- Card Validation Properties (Example for validate_character_book_entry) ---
- # Strategy for a valid character book entry core
- valid_entry_core_st = st.fixed_dictionaries({
- 'keys': st.lists(st.text(min_size=1, max_size=50), min_size=1, max_size=5),
- 'content': st.text(min_size=1, max_size=200),
- 'enabled': st.booleans(),
- 'insertion_order': st.integers()
- })
-
- @given(entry_core=valid_entry_core_st,
- entry_id_set=st.sets(st.integers(min_value=0, max_value=1000)))
- def test_validate_character_book_entry_valid_core(self, entry_core, entry_id_set):
- is_valid, errors = ccl.validate_character_book_entry(entry_core, 0, entry_id_set)
- self.assertTrue(is_valid, f"Errors for supposedly valid core: {errors}")
- self.assertEqual(len(errors), 0)
-
- @given(entry_core=valid_entry_core_st,
- bad_key_type=st.integers(), # Make keys not a list of strings
- entry_id_set=st.sets(st.integers()))
- def test_validate_character_book_entry_invalid_keys_type(self, entry_core, bad_key_type, entry_id_set):
- invalid_entry = {**entry_core, 'keys': bad_key_type}
- is_valid, errors = ccl.validate_character_book_entry(invalid_entry, 0, entry_id_set)
- self.assertFalse(is_valid)
- self.assertTrue(any("Field 'keys' must be of type 'list'" in e for e in errors))
-
- # More properties can be added for other parsing/validation functions if they
- # have clear invariants that can be tested with generated data.
-
-
-if __name__ == '__main__':
- unittest.main(argv=['first-arg-is-ignored'], exit=False)
\ No newline at end of file
+# test_character_chat_lib.py
+
+import pytest
+import sqlite3
+import json
+import uuid
+import base64
+import io
+from datetime import datetime, timezone
+from pathlib import Path
+from unittest.mock import patch, MagicMock
+
+from PIL import Image
+
+# Local Imports from this project
+from tldw_chatbook.DB.ChaChaNotes_DB import (
+ CharactersRAGDB,
+ CharactersRAGDBError,
+ ConflictError,
+ InputError
+)
+from tldw_chatbook.Character_Chat.Character_Chat_Lib import (
+ create_conversation,
+ get_conversation_details_and_messages,
+ add_message_to_conversation,
+ get_character_list_for_ui,
+ extract_character_id_from_ui_choice,
+ load_character_and_image,
+ process_db_messages_to_ui_history,
+ load_chat_and_character,
+ load_character_wrapper,
+ import_and_save_character_from_file,
+ load_chat_history_from_file_and_save_to_db,
+ start_new_chat_session,
+ list_character_conversations,
+ update_conversation_metadata,
+ post_message_to_conversation,
+ retrieve_conversation_messages_for_ui
+)
+
+
+# --- Standalone Fixtures (No conftest.py) ---
+
+@pytest.fixture
+def client_id():
+ """Provides a consistent client ID for tests."""
+ return "test_lib_client_001"
+
+
+@pytest.fixture
+def db_path(tmp_path):
+ """Provides a temporary path for the database file for each test."""
+ return tmp_path / "test_lib_db.sqlite"
+
+
+@pytest.fixture(scope="function")
+def db_instance(db_path, client_id):
+ """Creates a DB instance for each test, ensuring a fresh database."""
+ db = CharactersRAGDB(db_path, client_id)
+ yield db
+ db.close_connection()
+
+
+# --- Helper Functions ---
+
+def create_dummy_png_bytes() -> bytes:
+ """Creates a 1x1 black PNG image in memory."""
+ img = Image.new('RGB', (1, 1), color='black')
+ byte_arr = io.BytesIO()
+ img.save(byte_arr, format='PNG')
+ return byte_arr.getvalue()
+
+
+def create_dummy_png_with_chara(chara_json_str: str) -> bytes:
+    """Creates a 1x1 PNG with the card JSON embedded in a 'chara' tEXt chunk.
+
+    Character card PNGs conventionally store the base64-encoded card JSON under the
+    'chara' key, which is what the PNG import path is expected to parse.
+    """
+    from PIL.PngImagePlugin import PngInfo
+
+    img = Image.new('RGB', (1, 1), color='red')
+    chara_b64 = base64.b64encode(chara_json_str.encode('utf-8')).decode('utf-8')
+
+    # Pillow only writes text chunks that are passed explicitly via `pnginfo`;
+    # mutating img.info after saving does not persist them.
+    png_info = PngInfo()
+    png_info.add_text("chara", chara_b64)
+
+    byte_arr = io.BytesIO()
+    img.save(byte_arr, format='PNG', pnginfo=png_info)
+    return byte_arr.getvalue()
+
+
+def create_sample_v2_card_json(name: str) -> str:
+ """Returns a V2 character card as a JSON string."""
+ card = {
+ "spec": "chara_card_v2",
+ "spec_version": "2.0",
+ "data": {
+ "name": name,
+ "description": "A test character from a V2 card.",
+ "personality": "Curious",
+ "scenario": "In a test.",
+ "first_mes": "Hello from the V2 card!",
+ "mes_example": "This is an example message.",
+ "creator_notes": "",
+ "system_prompt": "",
+ "post_history_instructions": "",
+ "tags": ["v2", "test"],
+ "creator": "Tester",
+ "character_version": "1.0",
+ "alternate_greetings": ["Hi there!", "Greetings!"]
+ }
+ }
+ return json.dumps(card)
+
+
+# --- Test Classes ---
+
+class TestConversationManagement:
+ def test_create_conversation_with_defaults(self, db_instance: CharactersRAGDB):
+ # Relies on the default character (ID 1) created by the schema
+ conv_id = create_conversation(db_instance)
+ assert conv_id is not None
+
+ details = db_instance.get_conversation_by_id(conv_id)
+ assert details is not None
+ assert details['character_id'] == 1 # DEFAULT_CHARACTER_ID
+ assert "Chat with Default Assistant" in details['title']
+
+ def test_create_conversation_with_initial_messages(self, db_instance: CharactersRAGDB):
+ char_id = db_instance.add_character_card({"name": "Talker"})
+ initial_messages = [
+ {'sender': 'User', 'content': 'Hi there!'},
+ {'sender': 'AI', 'content': 'Hello, User!'}
+ ]
+ conv_id = create_conversation(db_instance, character_id=char_id, initial_messages=initial_messages)
+ assert conv_id is not None
+
+ messages = db_instance.get_messages_for_conversation(conv_id)
+ assert len(messages) == 2
+ assert messages[0]['sender'] == 'User'
+ assert messages[1]['sender'] == 'Talker' # Sender 'AI' is mapped to character name
+
+ def test_get_conversation_details_and_messages(self, db_instance: CharactersRAGDB):
+ char_id = db_instance.add_character_card({"name": "DetailedChar"})
+ conv_id = create_conversation(db_instance, character_id=char_id, title="Test Details")
+ add_message_to_conversation(db_instance, conv_id, "User", "A message")
+
+ details = get_conversation_details_and_messages(db_instance, conv_id)
+ assert details is not None
+ assert details['metadata']['title'] == "Test Details"
+ assert details['character_name'] == "DetailedChar"
+ assert len(details['messages']) == 1
+ assert details['messages'][0]['content'] == "A message"
+
+
+class TestCharacterLoading:
+ def test_extract_character_id_from_ui_choice(self):
+ assert extract_character_id_from_ui_choice("My Character (ID: 123)") == 123
+ assert extract_character_id_from_ui_choice("456") == 456
+ with pytest.raises(ValueError):
+ extract_character_id_from_ui_choice("Invalid Format")
+ with pytest.raises(ValueError):
+ extract_character_id_from_ui_choice("")
+
+ def test_load_character_and_image(self, db_instance: CharactersRAGDB):
+ image_bytes = create_dummy_png_bytes()
+ card_data = {"name": "ImgChar", "first_message": "Hello, {{user}}!", "image": image_bytes}
+ char_id = db_instance.add_character_card(card_data)
+
+ char_data, history, img = load_character_and_image(db_instance, char_id, "Tester")
+
+ assert char_data is not None
+ assert char_data['name'] == "ImgChar"
+ assert len(history) == 1
+ assert history[0] == (None, "Hello, Tester!") # Placeholder replaced
+ assert isinstance(img, Image.Image)
+
+ def test_process_db_messages_to_ui_history(self):
+ db_messages = [
+ {"sender": "User", "content": "Msg 1"},
+ {"sender": "MyChar", "content": "Reply 1"},
+ {"sender": "User", "content": "Msg 2a"},
+ {"sender": "User", "content": "Msg 2b"},
+ {"sender": "MyChar", "content": "Reply 2"},
+ ]
+ history = process_db_messages_to_ui_history(db_messages, "MyChar", "TestUser")
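+        # Pairing rule encoded by `expected`: a character reply closes the most recent
+        # open user message, and back-to-back user messages flush the earlier one as a
+        # (user, None) pair.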
+ expected = [
+ ("Msg 1", "Reply 1"),
+ ("Msg 2a", None),
+ ("Msg 2b", "Reply 2"),
+ ]
+ assert history == expected
+
+ def test_load_chat_and_character(self, db_instance: CharactersRAGDB):
+ char_id = db_instance.add_character_card({"name": "Chatter", "first_message": "Hi"})
+ conv_id = create_conversation(db_instance, character_id=char_id)
+ add_message_to_conversation(db_instance, conv_id, "User", "Hello there")
+ add_message_to_conversation(db_instance, conv_id, "Chatter", "General Kenobi")
+
+ char_data, history, img = load_chat_and_character(db_instance, conv_id, "TestUser")
+
+ assert char_data['name'] == "Chatter"
+ assert len(history) == 2 # Initial 'Hi' plus the two added messages
+ assert history[0] == (None, 'Hi')
+ assert history[1] == ("Hello there", "General Kenobi")
+
+ def test_load_character_wrapper(self, db_instance: CharactersRAGDB):
+ char_id = db_instance.add_character_card({"name": "WrappedChar"})
+ # Test with int ID
+ char_data_int, _, _ = load_character_wrapper(db_instance, char_id, "User")
+ assert char_data_int['id'] == char_id
+ # Test with UI string
+ char_data_str, _, _ = load_character_wrapper(db_instance, f"WrappedChar (ID: {char_id})", "User")
+ assert char_data_str['id'] == char_id
+
+
+class TestCharacterImport:
+ def test_import_from_json_string(self, db_instance: CharactersRAGDB):
+ card_name = "ImportedV2Char"
+ v2_json = create_sample_v2_card_json(card_name)
+
+ file_obj = io.BytesIO(v2_json.encode('utf-8'))
+ char_id = import_and_save_character_from_file(db_instance, file_obj)
+
+ assert char_id is not None
+ retrieved = db_instance.get_character_card_by_id(char_id)
+ assert retrieved['name'] == card_name
+ assert retrieved['description'] == "A test character from a V2 card."
+
+ def test_import_from_png_with_chara_metadata(self, db_instance: CharactersRAGDB):
+ card_name = "PngChar"
+ v2_json = create_sample_v2_card_json(card_name)
+ png_bytes = create_dummy_png_with_chara(v2_json)
+
+ file_obj = io.BytesIO(png_bytes)
+ char_id = import_and_save_character_from_file(db_instance, file_obj)
+
+ assert char_id is not None
+ retrieved = db_instance.get_character_card_by_id(char_id)
+ assert retrieved['name'] == card_name
+ assert retrieved['image'] is not None
+
+ def test_import_chat_history_and_save(self, db_instance: CharactersRAGDB):
+ # 1. Create the character that the chat log refers to
+ char_name = "LogChar"
+ char_id = db_instance.add_character_card({"name": char_name})
+
+ # 2. Create a sample chat log JSON
+ chat_log = {
+ "char_name": char_name,
+ "history": {
+ "internal": [
+ ["Hello", "Hi there"],
+ ["How are you?", "I am a test, I am fine."]
+ ]
+ }
+ }
+ chat_log_json = json.dumps(chat_log)
+ file_obj = io.BytesIO(chat_log_json.encode('utf-8'))
+
+ # 3. Import the log
+ conv_id, new_char_id = load_chat_history_from_file_and_save_to_db(db_instance, file_obj)
+
+ assert conv_id is not None
+ assert new_char_id == char_id
+
+ # 4. Verify messages were saved
+ messages = db_instance.get_messages_for_conversation(conv_id)
+ assert len(messages) == 4
+ assert messages[0]['content'] == "Hello"
+ assert messages[1]['content'] == "Hi there"
+ assert messages[1]['sender'] == char_name
+
+
+class TestHighLevelChatFlow:
+ def test_start_new_chat_session(self, db_instance: CharactersRAGDB):
+ char_id = db_instance.add_character_card({"name": "SessionStart", "first_message": "Greetings!"})
+
+ conv_id, char_data, ui_history, img = start_new_chat_session(db_instance, char_id, "TestUser")
+
+ assert conv_id is not None
+ assert char_data['name'] == "SessionStart"
+ assert ui_history == [(None, "Greetings!")]
+
+ # Verify conversation and first message were saved to DB
+ conv_details = db_instance.get_conversation_by_id(conv_id)
+ assert conv_details['character_id'] == char_id
+
+ messages = db_instance.get_messages_for_conversation(conv_id)
+ assert len(messages) == 1
+ assert messages[0]['content'] == "Greetings!"
+ assert messages[0]['sender'] == "SessionStart"
+
+ def test_post_message_to_conversation(self, db_instance: CharactersRAGDB):
+ char_id = db_instance.add_character_card({"name": "Poster"})
+ conv_id = db_instance.add_conversation({"character_id": char_id})
+
+ # Post user message
+ user_msg_id = post_message_to_conversation(db_instance, conv_id, "Poster", "User msg", is_user_message=True)
+ assert user_msg_id is not None
+ # Post character message
+ char_msg_id = post_message_to_conversation(db_instance, conv_id, "Poster", "Char reply", is_user_message=False)
+ assert char_msg_id is not None
+
+ messages = db_instance.get_messages_for_conversation(conv_id)
+ assert len(messages) == 2
+ assert messages[0]['sender'] == "User"
+ assert messages[1]['sender'] == "Poster"
+
+ def test_retrieve_conversation_messages_for_ui(self, db_instance: CharactersRAGDB):
+ char_id = db_instance.add_character_card({"name": "UIRetriever"})
+ conv_id = start_new_chat_session(db_instance, char_id, "TestUser")[0] # Get conv_id
+ post_message_to_conversation(db_instance, conv_id, "UIRetriever", "User message", True)
+ post_message_to_conversation(db_instance, conv_id, "UIRetriever", "Bot reply", False)
+
+ ui_history = retrieve_conversation_messages_for_ui(db_instance, conv_id, "UIRetriever", "TestUser")
+
+        # One UI pair for the greeting created by start_new_chat_session, plus one pair for the posted user/bot exchange
+ assert len(ui_history) == 2
+ assert ui_history[1] == ("User message", "Bot reply")
\ No newline at end of file
diff --git a/Tests/Chat/test_chat_functions.py b/Tests/Chat/test_chat_functions.py
index 690a6e33..0bec02a6 100644
--- a/Tests/Chat/test_chat_functions.py
+++ b/Tests/Chat/test_chat_functions.py
@@ -1,739 +1,338 @@
-# tests/unit/core/chat/test_chat_functions.py
-# Description: Unit tests for chat functions in the tldw_app.Chat module.
+# test_chat_functions.py
#
# Imports
+import pytest
import base64
-import re
-import os # For tmp_path with post_gen_replacement_dict
-import textwrap
+import io
from unittest.mock import patch, MagicMock
#
# 3rd-party Libraries
-import pytest
-from hypothesis import given, strategies as st, settings, HealthCheck
-from typing import Optional
-import requests # For mocking requests.exceptions
+import requests
+from PIL import Image
#
# Local Imports
-# Ensure these paths are correct for your project structure
+from tldw_chatbook.DB.ChaChaNotes_DB import CharactersRAGDB, ConflictError, InputError
from tldw_chatbook.Chat.Chat_Functions import (
chat_api_call,
chat,
save_chat_history_to_db_wrapper,
- API_CALL_HANDLERS,
- PROVIDER_PARAM_MAP,
- load_characters,
save_character,
- ChatDictionary,
+ load_characters,
+ get_character_names,
parse_user_dict_markdown_file,
- process_user_input, # Added for direct testing if needed
- # Import the actual LLM handler functions if you intend to test them directly (though mostly tested via chat_api_call)
- # e.g., _chat_with_openai_compatible_local_server, chat_with_kobold, etc.
- # For unit testing chat_api_call, we mock these handlers.
- # For unit testing chat, we mock chat_api_call.
+ process_user_input,
+ ChatDictionary,
+ DEFAULT_CHARACTER_NAME
)
from tldw_chatbook.Chat.Chat_Deps import (
- ChatAuthenticationError, ChatRateLimitError, ChatBadRequestError,
- ChatConfigurationError, ChatProviderError, ChatAPIError
+ ChatBadRequestError,
+ ChatAuthenticationError,
+ ChatRateLimitError,
+ ChatProviderError,
+ ChatAPIError
)
-from tldw_chatbook.DB.ChaChaNotes_DB import CharactersRAGDB # For mocking
-
-# Placeholder for load_settings if it's not directly in Chat_Functions or needs specific mocking path
-# from tldw_app.Chat.Chat_Functions import load_settings # Already imported effectively
-
-# Define a common set of known providers for hypothesis strategies
-KNOWN_PROVIDERS = list(API_CALL_HANDLERS.keys())
-if not KNOWN_PROVIDERS: # Should not happen if API_CALL_HANDLERS is populated
- KNOWN_PROVIDERS = ["openai", "anthropic", "ollama"] # Fallback for safety
-
-########################################################################################################################
#
-# Fixtures
-
-class ChatErrorBase(Exception):
- def __init__(self, provider: str, message: str, status_code: Optional[int] = None):
- self.provider = provider
- self.message = message
- self.status_code = status_code
- super().__init__(f"[{provider}] {message}" + (f" (HTTP {status_code})" if status_code else ""))
-
-@pytest.fixture(autouse=True)
-def mock_load_settings_globally():
- """
- Mocks load_settings used by Chat_Functions.py (e.g., in LLM handlers and chat function for post-gen dict).
- This ensures that tests don't rely on actual configuration files.
- """
- # This default mock should provide minimal valid config for most providers
- # to avoid KeyError if a handler tries to access its specific config section.
- default_provider_config = {
- "api_ip": "http://mock.host:1234",
- "api_url": "http://mock.host:1234",
- "api_key": "mock_api_key_from_settings",
- "model": "mock_model_from_settings",
- "temperature": 0.7,
- "streaming": False,
- "max_tokens": 1024,
- "n_predict": 1024, # for llama
- "num_predict": 1024, # for ollama
- "max_length": 150, # for kobold
- "api_timeout": 60,
- "api_retries": 1,
- "api_retry_delay": 1,
- # Add other common keys that might be accessed to avoid KeyErrors in handlers
- }
- mock_settings_data = {
- "chat_dictionaries": { # For `chat` function's post-gen replacement
- "post_gen_replacement": "False", # Default to off unless a test enables it
- "post_gen_replacement_dict": "dummy_path.md"
- },
- # Provide a default config for all known providers to prevent KeyErrors
- # when handlers try to load their specific sections via cfg.get(...)
- **{f"{provider}_api": default_provider_config.copy() for provider in KNOWN_PROVIDERS},
- # Specific overrides if needed by a default test case
- "local_llm": default_provider_config.copy(),
- "llama_api": default_provider_config.copy(),
- "kobold_api": default_provider_config.copy(),
- "ooba_api": default_provider_config.copy(),
- "tabby_api": default_provider_config.copy(),
- "vllm_api": default_provider_config.copy(),
- "aphrodite_api": default_provider_config.copy(),
- "ollama_api": default_provider_config.copy(),
- "custom_openai_api": default_provider_config.copy(),
- "custom_openai_api_2": default_provider_config.copy(),
- "openai": default_provider_config.copy(), # For actual OpenAI if it uses load_settings
- "anthropic": default_provider_config.copy(),
- }
- # Ensure all keys from PROVIDER_PARAM_MAP also have a generic config section if a handler uses it
- for provider_key in KNOWN_PROVIDERS:
- if provider_key not in mock_settings_data: # e.g. if provider is 'openai' not 'openai_api'
- mock_settings_data[provider_key] = default_provider_config.copy()
-
-
- with patch("tldw_app.Chat.Chat_Functions.load_settings", return_value=mock_settings_data) as mock_load:
- yield mock_load
+#######################################################################################################################
+#
+# --- Standalone Fixtures (No conftest.py) ---
+
+@pytest.fixture
+def client_id():
+ """Provides a consistent client ID for tests."""
+ return "test_chat_func_client"
@pytest.fixture
-def mock_llm_handlers(): # Renamed for clarity, same functionality
- original_handlers = API_CALL_HANDLERS.copy()
- mocked_handlers_dict = {}
- for provider_name, original_func in original_handlers.items():
- mock_handler = MagicMock(name=f"mock_{getattr(original_func, '__name__', provider_name)}")
- # Preserve the signature or relevant attributes if necessary, but MagicMock is often enough
- mock_handler.__name__ = getattr(original_func, '__name__', f"mock_{provider_name}_handler")
- mocked_handlers_dict[provider_name] = mock_handler
-
- with patch("tldw_app.Chat.Chat_Functions.API_CALL_HANDLERS", new=mocked_handlers_dict):
- yield mocked_handlers_dict
-
-# --- Tests for chat_api_call ---
-@pytest.mark.unit
-def test_chat_api_call_routing_and_param_mapping_openai(mock_llm_handlers):
- provider = "openai"
- mock_openai_handler = mock_llm_handlers[provider]
- mock_openai_handler.return_value = "OpenAI success"
-
- args = {
- "api_endpoint": provider,
- "messages_payload": [{"role": "user", "content": "Hi OpenAI"}],
- "api_key": "test_openai_key",
- "temp": 0.5,
- "system_message": "Be concise.",
- "streaming": False,
- "maxp": 0.9, # Generic name, maps to top_p for openai
- "model": "gpt-4o-mini",
- "tools": [{"type": "function", "function": {"name": "get_weather"}}],
- "tool_choice": "auto",
- "seed": 123,
- "response_format": {"type": "json_object"},
- "logit_bias": {"123": 10},
- }
- result = chat_api_call(**args)
- assert result == "OpenAI success"
- mock_openai_handler.assert_called_once()
- called_kwargs = mock_openai_handler.call_args.kwargs
-
- param_map_for_provider = PROVIDER_PARAM_MAP[provider]
-
- # Check that generic params were mapped to provider-specific names in the handler call
- assert called_kwargs[param_map_for_provider['messages_payload']] == args["messages_payload"]
- assert called_kwargs[param_map_for_provider['api_key']] == args["api_key"]
- assert called_kwargs[param_map_for_provider['temp']] == args["temp"]
- assert called_kwargs[param_map_for_provider['system_message']] == args["system_message"]
- assert called_kwargs[param_map_for_provider['streaming']] == args["streaming"]
- assert called_kwargs[param_map_for_provider['maxp']] == args["maxp"] # 'maxp' generic maps to 'top_p' (OpenAI specific)
- assert called_kwargs[param_map_for_provider['model']] == args["model"]
- assert called_kwargs[param_map_for_provider['tools']] == args["tools"]
- assert called_kwargs[param_map_for_provider['tool_choice']] == args["tool_choice"]
- assert called_kwargs[param_map_for_provider['seed']] == args["seed"]
- assert called_kwargs[param_map_for_provider['response_format']] == args["response_format"]
- assert called_kwargs[param_map_for_provider['logit_bias']] == args["logit_bias"]
-
-
-@pytest.mark.unit
-def test_chat_api_call_routing_and_param_mapping_anthropic(mock_llm_handlers):
- provider = "anthropic"
- mock_anthropic_handler = mock_llm_handlers[provider]
- mock_anthropic_handler.return_value = "Anthropic success"
-
- args = {
- "api_endpoint": provider,
- "messages_payload": [{"role": "user", "content": "Hi Anthropic"}],
- "api_key": "test_anthropic_key",
- "temp": 0.6,
- "system_message": "Be friendly.",
- "streaming": True,
- "model": "claude-3-opus-20240229",
- "topp": 0.92, # Generic name, maps to top_p for anthropic
- "topk": 50,
- "max_tokens": 100, # Generic, maps to max_tokens for anthropic
- "stop": ["\nHuman:", "\nAssistant:"] # Generic, maps to stop_sequences
- }
- result = chat_api_call(**args)
- assert result == "Anthropic success"
- mock_anthropic_handler.assert_called_once()
- called_kwargs = mock_anthropic_handler.call_args.kwargs
-
- param_map = PROVIDER_PARAM_MAP[provider]
- assert called_kwargs[param_map['messages_payload']] == args["messages_payload"]
- assert called_kwargs[param_map['api_key']] == args["api_key"]
- assert called_kwargs[param_map['temp']] == args["temp"]
- assert called_kwargs[param_map['system_message']] == args["system_message"] # maps to 'system_prompt'
- assert called_kwargs[param_map['streaming']] == args["streaming"]
- assert called_kwargs[param_map['model']] == args["model"]
- assert called_kwargs[param_map['topp']] == args["topp"]
- assert called_kwargs[param_map['topk']] == args["topk"]
- assert called_kwargs[param_map['max_tokens']] == args["max_tokens"]
- assert called_kwargs[param_map['stop']] == args["stop"]
-
-
-@pytest.mark.unit
-def test_chat_api_call_unsupported_provider():
- with pytest.raises(ValueError, match="Unsupported API endpoint: non_existent_provider"):
- chat_api_call(api_endpoint="non_existent_provider", messages_payload=[])
-
-
-@pytest.mark.unit
-@pytest.mark.parametrize("raised_exception, expected_custom_error_type, expected_status_code_in_error", [
- (requests.exceptions.HTTPError(response=MagicMock(status_code=401, text="Auth error text")), ChatAuthenticationError, 401),
- (requests.exceptions.HTTPError(response=MagicMock(status_code=429, text="Rate limit text")), ChatRateLimitError, 429),
- (requests.exceptions.HTTPError(response=MagicMock(status_code=400, text="Bad req text")), ChatBadRequestError, 400),
- (requests.exceptions.HTTPError(response=MagicMock(status_code=503, text="Provider down text")), ChatProviderError, 503),
- (requests.exceptions.ConnectionError("Network fail"), ChatProviderError, 504), # Default for RequestException
- (ValueError("Internal value error"), ChatBadRequestError, None), # Status code might not be set by default for these
- (TypeError("Internal type error"), ChatBadRequestError, None),
- (KeyError("Internal key error"), ChatBadRequestError, None),
- (ChatConfigurationError("config issue", provider="openai"), ChatConfigurationError, None), # Direct raise
- (Exception("Very generic error"), ChatAPIError, 500),
-])
-def test_chat_api_call_exception_mapping(
- mock_llm_handlers,
- raised_exception, expected_custom_error_type, expected_status_code_in_error
-):
- provider_to_test = "openai" # Use any valid provider name that is mocked
- mock_handler = mock_llm_handlers[provider_to_test]
- mock_handler.side_effect = raised_exception
-
- with pytest.raises(expected_custom_error_type) as exc_info:
- chat_api_call(api_endpoint=provider_to_test, messages_payload=[{"role": "user", "content": "test"}])
-
- assert exc_info.value.provider == provider_to_test
- if expected_status_code_in_error is not None and hasattr(exc_info.value, 'status_code'):
- assert exc_info.value.status_code == expected_status_code_in_error
-
- # Check that original error message part is in the custom error message if applicable
- if hasattr(raised_exception, 'response') and hasattr(raised_exception.response, 'text'):
- assert raised_exception.response.text[:100] in exc_info.value.message # Check beginning of text
- elif not isinstance(raised_exception, ChatErrorBase): # Don't double-check message for already custom errors
- assert str(raised_exception) in exc_info.value.message
-
-
-# --- Tests for the `chat` function (multimodal chat coordinator) ---
-
-@pytest.mark.unit
-@patch("tldw_app.Chat.Chat_Functions.chat_api_call")
-@patch("tldw_app.Chat.Chat_Functions.process_user_input", side_effect=lambda text, *args, **kwargs: text)
-# mock_load_settings_globally is active via autouse=True
-def test_chat_function_basic_text_call(mock_process_input, mock_chat_api_call_shim):
- mock_chat_api_call_shim.return_value = "LLM Response from chat function"
-
- response = chat(
- message="Hello LLM",
- history=[],
- media_content=None, selected_parts=[], api_endpoint="test_provider_for_chat",
- api_key="test_key_for_chat", custom_prompt="Be very brief.", temperature=0.1,
- system_message="You are a test bot for chat.",
- llm_seed=42, llm_max_tokens=100, llm_user_identifier="user123"
- )
- assert response == "LLM Response from chat function"
- mock_chat_api_call_shim.assert_called_once()
- call_args = mock_chat_api_call_shim.call_args.kwargs
-
- assert call_args["api_endpoint"] == "test_provider_for_chat"
- assert call_args["api_key"] == "test_key_for_chat"
- assert call_args["temp"] == 0.1
- assert call_args["system_message"] == "You are a test bot for chat."
- assert call_args["seed"] == 42 # Check new llm_param
- assert call_args["max_tokens"] == 100 # Check new llm_param
- assert call_args["user_identifier"] == "user123" # Check new llm_param
-
-
- payload = call_args["messages_payload"]
- assert len(payload) == 1
- assert payload[0]["role"] == "user"
- assert isinstance(payload[0]["content"], list)
- assert len(payload[0]["content"]) == 1
- assert payload[0]["content"][0]["type"] == "text"
- # Custom prompt is prepended to user message
- assert payload[0]["content"][0]["text"] == "Be very brief.\n\nHello LLM"
-
-
-@pytest.mark.unit
-@patch("tldw_app.Chat.Chat_Functions.chat_api_call")
-@patch("tldw_app.Chat.Chat_Functions.process_user_input", side_effect=lambda x, *a, **kw: x)
-def test_chat_function_with_text_history(mock_process_input, mock_chat_api_call_shim):
- mock_chat_api_call_shim.return_value = "LLM Response with history"
- history_for_chat_func = [
- {"role": "user", "content": "Previous question?"}, # Will be wrapped
- {"role": "assistant", "content": [{"type": "text", "text": "Previous answer."}]} # Already wrapped
- ]
- response = chat(
- message="New question", history=history_for_chat_func, media_content=None,
- selected_parts=[], api_endpoint="hist_provider", api_key="hist_key",
- custom_prompt=None, temperature=0.2, system_message="Sys History"
- )
- assert response == "LLM Response with history"
- mock_chat_api_call_shim.assert_called_once()
- payload = mock_chat_api_call_shim.call_args.kwargs["messages_payload"]
- assert len(payload) == 3
- assert payload[0]["content"][0]["type"] == "text"
- assert payload[0]["content"][0]["text"] == "Previous question?"
- assert payload[1]["content"][0]["type"] == "text"
- assert payload[1]["content"][0]["text"] == "Previous answer."
- assert payload[2]["content"][0]["type"] == "text"
- assert payload[2]["content"][0]["text"] == "New question"
-
-
-@pytest.mark.unit
-@patch("tldw_app.Chat.Chat_Functions.chat_api_call")
-@patch("tldw_app.Chat.Chat_Functions.process_user_input", side_effect=lambda x, *a, **kw: x)
-def test_chat_function_with_current_image(mock_process_input, mock_chat_api_call_shim):
- mock_chat_api_call_shim.return_value = "LLM image Response"
- current_image = {"base64_data": "fakeb64imagedata", "mime_type": "image/png"}
-
- response = chat(
- message="What is this image?", history=[], media_content=None, selected_parts=[],
- api_endpoint="img_provider", api_key="img_key", custom_prompt=None, temperature=0.3,
- current_image_input=current_image
- )
- assert response == "LLM image Response"
- mock_chat_api_call_shim.assert_called_once()
- payload = mock_chat_api_call_shim.call_args.kwargs["messages_payload"]
- assert len(payload) == 1
- user_content_parts = payload[0]["content"]
- assert isinstance(user_content_parts, list)
- text_part_found = any(p["type"] == "text" and p["text"] == "What is this image?" for p in user_content_parts)
- image_part_found = any(p["type"] == "image_url" and p["image_url"]["url"] == "data:image/png;base64,fakeb64imagedata" for p in user_content_parts)
- assert text_part_found and image_part_found
- assert len(user_content_parts) == 2 # one text, one image
-
-
-@pytest.mark.unit
-@patch("tldw_app.Chat.Chat_Functions.chat_api_call")
-@patch("tldw_app.Chat.Chat_Functions.process_user_input", side_effect=lambda x, *a, **kw: x)
-def test_chat_function_image_history_tag_past(mock_process_input, mock_chat_api_call_shim):
- mock_chat_api_call_shim.return_value = "Tagged image history response"
- history_with_image = [
- {"role": "user", "content": [
- {"type": "text", "text": "Here is an image."},
- {"type": "image_url", "image_url": {"url": "data:image/jpeg;base64,previmgdata"}}
- ]},
- {"role": "assistant", "content": "I see the image."}
- ]
- response = chat(
- message="What about that previous image?",
- media_content=None,
- selected_parts=[],
- api_endpoint="tag_provider",
- history=history_with_image,
- api_key="tag_key",
- custom_prompt=None,
- temperature=0.4,
- image_history_mode="tag_past"
- )
- payload = mock_chat_api_call_shim.call_args.kwargs["messages_payload"]
- assert len(payload) == 3 # 2 history items processed + 1 current
-
- # First user message from history
- user_hist_content = payload[0]["content"]
- assert isinstance(user_hist_content, list)
- assert {"type": "text", "text": "Here is an image."} in user_hist_content
- assert {"type": "text", "text": ""} in user_hist_content
- assert not any(p["type"] == "image_url" for p in user_hist_content) # Image should be replaced by tag
-
- # Assistant message from history
- assistant_hist_content = payload[1]["content"]
- assert assistant_hist_content[0]["text"] == "I see the image."
-
-
-@pytest.mark.unit
-@patch("tldw_app.Chat.Chat_Functions.chat_api_call")
-@patch("tldw_app.Chat.Chat_Functions.process_user_input", side_effect=lambda x, *args, **kwargs: x)
-def test_chat_function_streaming_passthrough(mock_process_input, mock_chat_api_call_shim):
- def dummy_stream_gen():
- yield "stream chunk 1"
- yield "stream chunk 2"
- mock_chat_api_call_shim.return_value = dummy_stream_gen()
-
- response_gen = chat(
- message="Stream this",
- media_content=None,
- selected_parts=[],
- api_endpoint="stream_provider",
- history=[],
- api_key="key",
- custom_prompt=None,
- temperature=0.1,
- streaming=True
- )
- assert hasattr(response_gen, '__iter__')
- result = list(response_gen)
- assert result == ["stream chunk 1", "stream chunk 2"]
- mock_chat_api_call_shim.assert_called_once()
- assert mock_chat_api_call_shim.call_args.kwargs["streaming"] is True
-
-
-# --- Tests for save_chat_history_to_db_wrapper ---
-# These tests seem okay with the TUI context as they mock the DB.
-
-@pytest.mark.unit
-@patch("tldw_app.Chat.Chat_Functions.DEFAULT_CHARACTER_NAME", "TestDefaultChar")
-def test_save_chat_history_new_conversation_default_char(mock_load_settings_globally): # Ensure settings mock is active if needed by save_chat
- mock_db = MagicMock(spec=CharactersRAGDB)
- mock_db.client_id = "unit_test_client"
- mock_db.get_character_card_by_name.return_value = {"id": 99, "name": "TestDefaultChar", "version":1}
- mock_db.add_conversation.return_value = "new_conv_id_123"
- mock_db.transaction.return_value.__enter__.return_value = None # for 'with db.transaction():'
-
- history_to_save = [
- {"role": "user", "content": "Hello, default character!"},
- {"role": "assistant", "content": [{"type": "text", "text": "Hello, user!"}, {"type": "image_url", "image_url": {
- "url": "data:image/gif;base64,R0lGODlhAQABAIAAAAAAAP///yH5BAEAAAAALAAAAAABAAEAAAIBRAA7"}}]}
- ]
- conv_id, message = save_chat_history_to_db_wrapper(
- db=mock_db,
- chatbot_history=history_to_save,
- conversation_id=None,
- media_content_for_char_assoc=None,
- character_name_for_chat=None
- )
- assert conv_id == "new_conv_id_123"
- assert "success" in message.lower()
- mock_db.get_character_card_by_name.assert_called_once_with("TestDefaultChar")
- # ... (rest of assertions from your original test are likely still valid)
- first_message_call_args = mock_db.add_message.call_args_list[0].args[0]
- assert first_message_call_args["image_data"] is None
-
- second_message_call_args = mock_db.add_message.call_args_list[1].args[0]
- assert second_message_call_args["image_mime_type"] == "image/gif"
- assert isinstance(second_message_call_args["image_data"], bytes)
-
-
-@pytest.mark.unit
-def test_save_chat_history_resave_conversation_specific_char(mock_load_settings_globally):
- mock_db = MagicMock(spec=CharactersRAGDB)
- mock_db.client_id = "unit_test_client_resave"
- existing_conv_id = "existing_conv_456"
- char_id_for_resave = 77
- char_name_for_resave = "SpecificResaveChar"
- mock_db.get_character_card_by_name.return_value = {"id": char_id_for_resave, "name": char_name_for_resave, "version": 1}
- mock_db.get_conversation_by_id.return_value = {"id": existing_conv_id, "character_id": char_id_for_resave, "title": "Old Title", "version": 2}
- mock_db.get_messages_for_conversation.return_value = [{"id": "msg1", "version": 1}, {"id": "msg2", "version": 1}]
- mock_db.transaction.return_value.__enter__.return_value = None
-
- history_to_resave = [{"role": "user", "content": "Updated question for resave."}]
- conv_id, message = save_chat_history_to_db_wrapper(
- db=mock_db,
- chatbot_history=history_to_resave,
- conversation_id=existing_conv_id,
- media_content_for_char_assoc=None,
- character_name_for_chat=char_name_for_resave
- )
- assert conv_id == existing_conv_id
- assert "success" in message.lower()
- # ... (rest of assertions from your original test)
-
-
-# --- Chat Dictionary and Character Save/Load Tests ---
-# These tests seem okay with the TUI context.
-
-@pytest.mark.unit
-def test_parse_user_dict_markdown_file_various_formats(tmp_path):
- # ... (your original test content is good)
- md_content = textwrap.dedent("""
+def db_path(tmp_path):
+ """Provides a temporary path for the database file for each test."""
+ return tmp_path / "test_chat_func_db.sqlite"
+
+
+@pytest.fixture(scope="function")
+def db_instance(db_path, client_id):
+ """Creates a DB instance for each test, ensuring a fresh database."""
+ db = CharactersRAGDB(db_path, client_id)
+ yield db
+ db.close_connection()
+
+
+# --- Helper Functions ---
+
+def create_base64_image():
+ """Creates a dummy 1x1 png and returns its base64 string."""
+ img_bytes = io.BytesIO()
+ Image.new('RGB', (1, 1)).save(img_bytes, format='PNG')
+ return base64.b64encode(img_bytes.getvalue()).decode('utf-8')
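+
+
+# Illustrative only (not used by the tests): a sketch of how the base64 payload above
+# is wrapped into the OpenAI-style image content part that the tests below assert
+# against. The helper name is hypothetical; the key names mirror those assertions.
+def create_image_content_part(b64_data, mime_type="image/png"):
+    """Builds a hypothetical image_url content part from a base64 string (sketch)."""
+    return {"type": "image_url", "image_url": {"url": f"data:{mime_type};base64,{b64_data}"}}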
+
+
+# --- Test Classes ---
+
+@patch('tldw_chatbook.Chat.Chat_Functions.API_CALL_HANDLERS')
+class TestChatApiCall:
+ def test_routes_to_correct_handler(self, mock_handlers, mocker):
+ mock_openai_handler = mocker.MagicMock(return_value="OpenAI response")
+ mock_handlers.get.return_value = mock_openai_handler
+
+ response = chat_api_call(
+ api_endpoint="openai",
+ messages_payload=[{"role": "user", "content": "test"}],
+ model="gpt-4"
+ )
+
+ mock_handlers.get.assert_called_with("openai")
+ mock_openai_handler.assert_called_once()
+ kwargs = mock_openai_handler.call_args.kwargs
+ assert kwargs['input_data'][0]['content'] == "test" # Mapped to 'input_data' for openai
+ assert kwargs['model'] == "gpt-4"
+ assert response == "OpenAI response"
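+        # Note: the 'input_data' key checked above reflects the provider-specific
+        # parameter mapping performed inside chat_api_call (the generic
+        # messages_payload argument is renamed per provider before the handler is
+        # invoked); the exact mapping table is an implementation detail of
+        # Chat_Functions and is only exercised indirectly here.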
+
+ def test_unsupported_endpoint_raises_error(self, mock_handlers):
+ mock_handlers.get.return_value = None
+ with pytest.raises(ValueError, match="Unsupported API endpoint: unsupported"):
+ chat_api_call("unsupported", messages_payload=[])
+
+ def test_http_error_401_raises_auth_error(self, mock_handlers, mocker):
+ mock_response = MagicMock()
+ mock_response.status_code = 401
+ mock_response.text = "Invalid API key"
+ http_error = requests.exceptions.HTTPError(response=mock_response)
+
+ mock_handler = mocker.MagicMock(side_effect=http_error)
+ mock_handlers.get.return_value = mock_handler
+
+ with pytest.raises(ChatAuthenticationError):
+ chat_api_call("openai", messages_payload=[])
+
+
+class TestChatFunction:
+ @patch('tldw_chatbook.Chat.Chat_Functions.chat_api_call')
+ def test_chat_basic_flow(self, mock_chat_api_call):
+ mock_chat_api_call.return_value = "LLM says hi"
+
+ response = chat(
+ message="Hello",
+ history=[],
+ media_content=None,
+ selected_parts=[],
+ api_endpoint="openai",
+ api_key="sk-123",
+ model="gpt-4",
+ temperature=0.7,
+ custom_prompt="Be brief."
+ )
+
+ assert response == "LLM says hi"
+ mock_chat_api_call.assert_called_once()
+ kwargs = mock_chat_api_call.call_args.kwargs
+
+ assert kwargs['api_endpoint'] == 'openai'
+ assert kwargs['model'] == 'gpt-4'
+ payload = kwargs['messages_payload']
+ assert len(payload) == 1
+ assert payload[0]['role'] == 'user'
+ user_content = payload[0]['content']
+ assert isinstance(user_content, list)
+ assert user_content[0]['type'] == 'text'
+ assert user_content[0]['text'] == "Be brief.\n\nHello"
+
+ @patch('tldw_chatbook.Chat.Chat_Functions.chat_api_call')
+ def test_chat_with_image_and_rag(self, mock_chat_api_call):
+ b64_img = create_base64_image()
+
+ chat(
+ message="Describe this.",
+ history=[],
+ media_content={"summary": "This is a summary."},
+ selected_parts=["summary"],
+ api_endpoint="openai",
+ api_key="sk-123",
+ model="gpt-4-vision-preview",
+ temperature=0.5,
+ current_image_input={'base64_data': b64_img, 'mime_type': 'image/png'},
+ custom_prompt=None
+ )
+
+ kwargs = mock_chat_api_call.call_args.kwargs
+ payload = kwargs['messages_payload']
+ user_content_parts = payload[0]['content']
+
+ assert len(user_content_parts) == 2 # RAG text + image
+
+ text_part = next(p for p in user_content_parts if p['type'] == 'text')
+ image_part = next(p for p in user_content_parts if p['type'] == 'image_url')
+
+ assert "Summary: This is a summary." in text_part['text']
+ assert "Describe this." in text_part['text']
+ assert image_part['image_url']['url'].startswith("data:image/png;base64,")
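+        # RAG text sketch (inferred from the assertions above): each selected media
+        # part is rendered as "<Part>: <value>" and prepended to the user's message
+        # inside the single text content part, which is why both the summary and
+        # "Describe this." appear in text_part['text'].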
+
+ @patch('tldw_chatbook.Chat.Chat_Functions.chat_api_call')
+ def test_chat_adapts_payload_for_deepseek(self, mock_chat_api_call):
+ chat(
+ message="Hello",
+ history=[
+ {"role": "user", "content": [{"type": "text", "text": "Old message"},
+ {"type": "image_url", "image_url": {"url": "data:..."}}]},
+ {"role": "assistant", "content": "Old reply"}
+ ],
+ media_content=None,
+ selected_parts=[],
+ api_endpoint="deepseek", # The endpoint that needs adaptation
+ api_key="sk-123",
+ model="deepseek-chat",
+ temperature=0.7,
+ custom_prompt=None,
+ image_history_mode="tag_past"
+ )
+
+ kwargs = mock_chat_api_call.call_args.kwargs
+ adapted_payload = kwargs['messages_payload']
+
+ # Check that all content fields are strings, not lists of parts
+ assert isinstance(adapted_payload[0]['content'], str)
+ assert adapted_payload[0]['content'] == "Old message\n"
+ assert isinstance(adapted_payload[1]['content'], str)
+ assert adapted_payload[1]['content'] == "Old reply"
+ assert isinstance(adapted_payload[2]['content'], str)
+ assert adapted_payload[2]['content'] == "Hello"
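+        # Flattening sketch (assumed, based on the assertions above): for providers
+        # that only accept plain-string content, each multimodal parts list is reduced
+        # to its text parts joined by newlines, with image parts dropped or replaced
+        # by a tag according to image_history_mode, e.g.
+        #   [{"type": "text", "text": "Old message"}, {"type": "image_url", ...}]
+        #   -> "Old message\n"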
+
+
+class TestChatHistorySaving:
+ def test_save_chat_history_to_db_new_conversation(self, db_instance: CharactersRAGDB):
+ # The history format is now OpenAI's message objects
+ chatbot_history = [
+ {"role": "user", "content": "Hello there"},
+ {"role": "assistant", "content": "General Kenobi"}
+ ]
+
+ # Uses default character
+ conv_id, status = save_chat_history_to_db_wrapper(
+ db=db_instance,
+ chatbot_history=chatbot_history,
+ conversation_id=None,
+ media_content_for_char_assoc=None,
+ character_name_for_chat=None
+ )
+
+ assert "success" in status.lower()
+ assert conv_id is not None
+
+ messages = db_instance.get_messages_for_conversation(conv_id)
+ assert len(messages) == 2
+ assert messages[0]['sender'] == 'user'
+ assert messages[1]['sender'] == 'assistant'
+
+ conv_details = db_instance.get_conversation_by_id(conv_id)
+ assert conv_details['character_id'] == 1 # Default character
+
+ def test_save_chat_history_with_image(self, db_instance: CharactersRAGDB):
+ b64_img = create_base64_image()
+ chatbot_history = [
+ {"role": "user", "content": [
+ {"type": "text", "text": "Look at this image"},
+ {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{b64_img}"}}
+ ]},
+ {"role": "assistant", "content": "I see a 1x1 black square."}
+ ]
+
+ conv_id, status = save_chat_history_to_db_wrapper(db_instance, chatbot_history, None, None, None)
+ assert "success" in status.lower()
+
+ messages = db_instance.get_messages_for_conversation(conv_id)
+ assert len(messages) == 2
+ assert messages[0]['content'] == "Look at this image"
+ assert messages[0]['image_data'] is not None
+ assert messages[0]['image_mime_type'] == "image/png"
+ assert messages[1]['image_data'] is None
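+        # Data-URL handling sketch (assumption): the wrapper is expected to split the
+        # "data:<mime>;base64,<payload>" URL on the first comma, base64-decode the
+        # payload, and store the bytes/mime pair asserted above.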
+
+ def test_resave_chat_history(self, db_instance: CharactersRAGDB):
+ char_id = db_instance.add_character_card({"name": "Resaver"})
+ initial_history = [{"role": "user", "content": "First message"}]
+ conv_id, _ = save_chat_history_to_db_wrapper(db_instance, initial_history, None, None, "Resaver")
+
+ updated_history = [
+ {"role": "user", "content": "New first message"},
+ {"role": "assistant", "content": "New reply"}
+ ]
+
+ # Resave with same conv_id
+ resave_id, status = save_chat_history_to_db_wrapper(db_instance, updated_history, conv_id, None, "Resaver")
+ assert "success" in status.lower()
+ assert resave_id == conv_id
+
+ messages = db_instance.get_messages_for_conversation(conv_id)
+ assert len(messages) == 2
+ assert messages[0]['content'] == "New first message"
+
+
+class TestCharacterManagement:
+ def test_save_and_load_character(self, db_instance: CharactersRAGDB):
+ char_data = {
+ "name": "Super Coder",
+ "description": "A character that codes.",
+ "image": create_base64_image()
+ }
+
+ char_id = save_character(db_instance, char_data)
+ assert isinstance(char_id, int)
+
+ loaded_chars = load_characters(db_instance)
+ assert "Super Coder" in loaded_chars
+ loaded_char_data = loaded_chars["Super Coder"]
+ assert loaded_char_data['description'] == "A character that codes."
+ assert loaded_char_data['image_base64'] is not None
+
+ def test_get_character_names(self, db_instance: CharactersRAGDB):
+ save_character(db_instance, {"name": "Beta"})
+ save_character(db_instance, {"name": "Alpha"})
+
+ # Default character is also present
+ names = get_character_names(db_instance)
+ assert names == ["Alpha", "Beta", DEFAULT_CHARACTER_NAME]
+
+
+class TestChatDictionary:
+ def test_parse_user_dict_markdown_file(self, tmp_path):
+ dict_content = """
key1: value1
key2: |
- This is a
- multi-line value for key2.
- It has several lines.
- # This comment line is part of key2's value.
+ This is a
+ multiline value.
---@@@---
- key_after_term: after_terminator_value
- """).strip()
- dict_file = tmp_path / "test_dict.md"
- dict_file.write_text(md_content)
- parsed = parse_user_dict_markdown_file(str(dict_file))
- expected_key2_value = ("This is a\n multi-line value for key2.\n It has several lines.\n# This comment line is part of key2's value.")
- assert parsed.get("key1") == "value1"
- assert parsed.get("key2") == expected_key2_value
- assert parsed.get("key_after_term") == "after_terminator_value"
-
-
-@pytest.mark.unit
-def test_chat_dictionary_class_methods():
- # ... (your original test content is good)
- entry_plain = ChatDictionary(key="hello", content="hi there")
- entry_regex = ChatDictionary(key=r"/\bworld\b/i", content="planet") # Added /i for ignore case in regex
- assert entry_plain.matches("hello world")
- assert entry_regex.matches("Hello World!")
- assert isinstance(entry_regex.key, re.Pattern)
- assert entry_regex.key.flags & re.IGNORECASE
-
-
-@pytest.mark.unit
-@patch("tldw_app.Chat.Chat_Functions.chat_api_call") # Mock the inner chat_api_call used by `chat`
-@patch("tldw_app.Chat.Chat_Functions.parse_user_dict_markdown_file")
-# mock_load_settings_globally is active
-def test_chat_function_with_chat_dictionary_post_replacement(
- mock_parse_dict, mock_chat_api_call_inner_shim, tmp_path, mock_load_settings_globally
-):
- post_gen_dict_path = str(tmp_path / "post_gen.md")
- # Override the global mock for this specific test case
- mock_load_settings_globally.return_value = {
- "chat_dictionaries": {
- "post_gen_replacement": "True",
- "post_gen_replacement_dict": post_gen_dict_path
- },
- # Ensure other necessary default configs for the provider are present if chat_api_call or its handlers need them
- "openai_api": {"api_key": "testkey"} # Example
- }
-
- post_gen_dict_file = tmp_path / "post_gen.md" # Actual file creation
- post_gen_dict_file.write_text("AI: Artificial Intelligence\nLLM: Large Language Model")
- os.path.exists(post_gen_dict_path) # For the check in `chat`
-
- mock_parse_dict.return_value = {"AI": "Artificial Intelligence", "LLM": "Large Language Model"}
- raw_llm_response = "The AI assistant uses an LLM."
- mock_chat_api_call_inner_shim.return_value = raw_llm_response
-
- final_response = chat(
- message="Tell me about AI.",
- media_content=None,
- selected_parts=[],
- api_endpoint="openai",
- api_key="testkey",
- custom_prompt=None,
- history=[],
- temperature=0.7,
- streaming=False
- )
- expected_response = "The Artificial Intelligence assistant uses an Large Language Model."
- assert final_response == expected_response
- mock_parse_dict.assert_called_once_with(post_gen_dict_path)
-
-
-@pytest.mark.unit
-def test_save_character_new_and_update():
- # ... (your original test content is good)
- mock_db = MagicMock(spec=CharactersRAGDB)
- char_data_v1 = {"name": "TestCharacter", "description": "Hero.", "image": "data:image/png;base64,fake"}
- mock_db.get_character_card_by_name.return_value = None
- mock_db.add_character_card.return_value = 1
- save_character(db=mock_db, character_data=char_data_v1)
- mock_db.add_character_card.assert_called_once()
- # ... more assertions ...
-
-@pytest.mark.unit
-def test_load_characters_empty_and_with_data():
- # ... (your original test content is good)
- mock_db = MagicMock(spec=CharactersRAGDB)
- mock_db.list_character_cards.return_value = []
- assert load_characters(db=mock_db) == {}
- # ... more assertions for data ...
-
-# --- Property-Based Tests ---
-
-# Helper strategy for generating message content (simple text or list of parts)
-st_text_content = st.text(min_size=1, max_size=50)
-st_image_url_part = st.fixed_dictionaries({
- "type": st.just("image_url"),
- "image_url": st.fixed_dictionaries({
- "url": st.text(min_size=10, max_size=30).map(lambda s: f"data:image/png;base64,{s}")
- })
-})
-st_text_part = st.fixed_dictionaries({"type": st.just("text"), "text": st_text_content})
-st_message_part = st.one_of(st_text_part, st_image_url_part)
-
-st_message_content_list = st.lists(st_message_part, min_size=1, max_size=3)
-st_valid_message_content = st.one_of(st_text_content, st_message_content_list)
-
-st_message = st.fixed_dictionaries({
- "role": st.sampled_from(["user", "assistant"]),
- "content": st_valid_message_content
-})
-st_history = st.lists(st_message, max_size=5)
-
-# Strategy for optional float parameters like temperature, top_p
-st_optional_float_0_to_1 = st.one_of(st.none(), st.floats(min_value=0.0, max_value=1.0))
-st_optional_float_0_to_2 = st.one_of(st.none(), st.floats(min_value=0.0, max_value=2.0)) # For penalties
-st_optional_int_gt_0 = st.one_of(st.none(), st.integers(min_value=1, max_value=2048)) # For max_tokens, top_k
-
-
-@given(
- api_endpoint=st.sampled_from(KNOWN_PROVIDERS),
- temp=st_optional_float_0_to_1,
- system_message=st.one_of(st.none(), st.text(max_size=50)),
- streaming=st.booleans(),
- max_tokens=st_optional_int_gt_0,
- seed=st.one_of(st.none(), st.integers()),
- # Add more strategies for other chat_api_call params if desired
-)
-@settings(max_examples=20, deadline=None, suppress_health_check=[HealthCheck.function_scoped_fixture])
-def test_property_chat_api_call_param_passing(
- mock_llm_handlers, # Fixture to mock the actual handlers
- api_endpoint, temp, system_message, streaming, max_tokens, seed
-):
- """
- Tests that chat_api_call correctly routes to the mocked handler
- and passes through known parameters, mapping them if necessary.
- """
- mock_handler = mock_llm_handlers[api_endpoint]
- mock_handler.reset_mock()
- mock_handler.return_value = "Property test success"
- param_map_for_provider = PROVIDER_PARAM_MAP.get(api_endpoint, {})
-
- messages = [{"role": "user", "content": "Hypothesis test"}]
- args_to_call = {
- "api_endpoint": api_endpoint,
- "messages_payload": messages,
- "api_key": "prop_test_key", # Assuming all handlers take api_key mapped
- "model": "prop_test_model", # Assuming all handlers take model mapped
- }
- # Add optional params to args_to_call only if they are not None
- if temp is not None: args_to_call["temp"] = temp
- if system_message is not None: args_to_call["system_message"] = system_message
- if streaming is not None: args_to_call["streaming"] = streaming # streaming is not optional in signature, but can be None in map
- if max_tokens is not None: args_to_call["max_tokens"] = max_tokens
- if seed is not None: args_to_call["seed"] = seed
-
- result = chat_api_call(**args_to_call)
- assert result == "Property test success"
- mock_handler.assert_called_once()
- called_kwargs = mock_handler.call_args.kwargs
-
- # Check messages_payload (or its mapped equivalent)
- mapped_messages_key = param_map_for_provider.get('messages_payload', 'messages_payload') # Default if not in map
- assert called_kwargs.get(mapped_messages_key) == messages
-
- # Check other params if they were passed and are in the map
- if temp is not None and 'temp' in param_map_for_provider:
- assert called_kwargs.get(param_map_for_provider['temp']) == temp
- if system_message is not None and 'system_message' in param_map_for_provider:
- assert called_kwargs.get(param_map_for_provider['system_message']) == system_message
- if streaming is not None and 'streaming' in param_map_for_provider:
- assert called_kwargs.get(param_map_for_provider['streaming']) == streaming
- if max_tokens is not None and 'max_tokens' in param_map_for_provider:
- assert called_kwargs.get(param_map_for_provider['max_tokens']) == max_tokens
- if seed is not None and 'seed' in param_map_for_provider:
- assert called_kwargs.get(param_map_for_provider['seed']) == seed
-
-
-@given(
- message=st.text(max_size=100),
- history=st_history,
- custom_prompt=st.one_of(st.none(), st.text(max_size=50)),
- temperature=st.floats(min_value=0.0, max_value=1.0),
- system_message=st.one_of(st.none(), st.text(max_size=50)),
- streaming=st.booleans(),
- llm_max_tokens=st_optional_int_gt_0,
- llm_seed=st.one_of(st.none(), st.integers()),
- image_history_mode=st.sampled_from(["send_all", "send_last_user_image", "tag_past", "ignore_past"]),
- current_image_input=st.one_of(
- st.none(),
- st.fixed_dictionaries({
- "base64_data": st.text(min_size=5, max_size=20).map(lambda s: base64.b64encode(s.encode()).decode()),
- "mime_type": st.sampled_from(["image/png", "image/jpeg"])
- })
- )
-)
-@settings(max_examples=20, deadline=None)
-@patch("tldw_app.Chat.Chat_Functions.chat_api_call")
-@patch("tldw_app.Chat.Chat_Functions.process_user_input", side_effect=lambda x, *a, **kw: x)
-# mock_load_settings_globally is active
-def test_property_chat_function_payload_construction(
- mock_process_input, mock_chat_api_call_shim, # Mocked dependencies first
- message, history, custom_prompt, temperature, system_message, streaming, # Generated inputs
- llm_max_tokens, llm_seed, image_history_mode, current_image_input
-):
- mock_chat_api_call_shim.return_value = "Property LLM Response" if not streaming else (lambda: (yield "Stream"))()
-
- response = chat(
- message=message, history=history, media_content=None, selected_parts=[],
- api_endpoint="prop_provider", api_key="prop_key",
- custom_prompt=custom_prompt, temperature=temperature, system_message=system_message,
- streaming=streaming, llm_max_tokens=llm_max_tokens, llm_seed=llm_seed,
- image_history_mode=image_history_mode, current_image_input=current_image_input
- )
-
- if streaming:
- assert hasattr(response, '__iter__')
- list(response) # Consume
- else:
- assert response == "Property LLM Response"
-
- mock_chat_api_call_shim.assert_called_once()
- call_args = mock_chat_api_call_shim.call_args.kwargs
-
- assert call_args["api_endpoint"] == "prop_provider"
- assert call_args["temp"] == temperature
- if system_message is not None:
- assert call_args["system_message"] == system_message
- assert call_args["streaming"] == streaming
- if llm_max_tokens is not None:
- assert call_args["max_tokens"] == llm_max_tokens
- if llm_seed is not None:
- assert call_args["seed"] == llm_seed
-
- payload = call_args["messages_payload"]
- assert isinstance(payload, list)
- if not payload: # Should not happen if message is non-empty, but good to check
- assert not message and not history # Only if input is truly empty
- return
-
- # Verify structure of the last message (current user input)
- last_message_in_payload = payload[-1]
- assert last_message_in_payload["role"] == "user"
- assert isinstance(last_message_in_payload["content"], list)
-
- # Check if custom_prompt is prepended
- expected_current_text = message
- if custom_prompt:
- expected_current_text = f"{custom_prompt}\n\n{expected_current_text}"
-
- text_part_found = any(p["type"] == "text" and p["text"] == expected_current_text.strip() for p in last_message_in_payload["content"])
- if not message and not custom_prompt and not current_image_input: # if no user text and no image
- assert any(p["type"] == "text" and "(No user input for this turn)" in p["text"] for p in last_message_in_payload["content"])
- elif expected_current_text.strip() or (not expected_current_text.strip() and not current_image_input): # if only text or no text and no image
- assert text_part_found or (not expected_current_text.strip() and not any(p["type"] == "text" for p in last_message_in_payload["content"]))
-
-
- if current_image_input:
- expected_image_url = f"data:{current_image_input['mime_type']};base64,{current_image_input['base64_data']}"
- image_part_found = any(p["type"] == "image_url" and p["image_url"]["url"] == expected_image_url for p in last_message_in_payload["content"])
- assert image_part_found
-
- # Further checks on history processing (e.g., image_history_mode effects) could be added here,
- # but they become complex for property tests. Unit tests are better for those specifics.
+ /key3/i: value3
+ """
+ dict_file = tmp_path / "test_dict.md"
+ dict_file.write_text(dict_content)
+
+ parsed = parse_user_dict_markdown_file(str(dict_file))
+ assert parsed["key1"] == "value1"
+ assert parsed["key2"] == "This is a\nmultiline value."
+ assert parsed["/key3/i"] == "value3"
+
+ def test_process_user_input_simple_replacement(self):
+ entries = [ChatDictionary(key="hello", content="GREETING")]
+ user_input = "I said hello to the world."
+ result = process_user_input(user_input, entries)
+ assert result == "I said GREETING to the world."
+
+ def test_process_user_input_regex_replacement(self):
+ entries = [ChatDictionary(key=r"/h[aeiou]llo/i", content="GREETING")]
+ user_input = "I said hallo and heLlo."
+ # It replaces only the first match
+ result = process_user_input(user_input, entries)
+ assert result == "I said GREETING and heLlo."
+
+ def test_process_user_input_token_budget(self):
+ # Content is 4 tokens, budget is 3. Should not replace.
+ entries = [ChatDictionary(key="long", content="this is too long")]
+ user_input = "This is a long test."
+ result = process_user_input(user_input, entries, max_tokens=3)
+ assert result == "This is a long test."
+
+ # Content is 3 tokens, budget is 3. Should replace.
+ entries = [ChatDictionary(key="short", content="this is fine")]
+ user_input = "This is a short test."
+ result = process_user_input(user_input, entries, max_tokens=3)
+ assert result == "This is a this is fine test."
#
# End of test_chat_functions.py
diff --git a/Tests/Chat/test_prompt_template_manager.py b/Tests/Chat/test_prompt_template_manager.py
index f5e71015..922411e0 100644
--- a/Tests/Chat/test_prompt_template_manager.py
+++ b/Tests/Chat/test_prompt_template_manager.py
@@ -1,142 +1,163 @@
-# Tests/Chat/test_prompt_template_manager.py
+# test_prompt_template_manager.py
+
import pytest
import json
from pathlib import Path
-from unittest.mock import patch, mock_open
+# Local Imports from this project
from tldw_chatbook.Chat.prompt_template_manager import (
PromptTemplate,
load_template,
apply_template_to_string,
get_available_templates,
- DEFAULT_RAW_PASSTHROUGH_TEMPLATE,
- _loaded_templates # For clearing cache in tests
+ _loaded_templates,
+ DEFAULT_RAW_PASSTHROUGH_TEMPLATE
)
-# Fixture to clear the template cache before each test
+# --- Test Setup ---
+
@pytest.fixture(autouse=True)
def clear_template_cache():
+ """Fixture to clear the template cache before each test."""
+ original_templates = _loaded_templates.copy()
_loaded_templates.clear()
- # Re-add the default passthrough because it's normally added at module load
+ # Ensure the default is always there for tests that might rely on it
_loaded_templates["raw_passthrough"] = DEFAULT_RAW_PASSTHROUGH_TEMPLATE
+ yield
+    # Restore the original cache contents so later tests see the module defaults
+ _loaded_templates.clear()
+ _loaded_templates.update(original_templates)
@pytest.fixture
-def mock_templates_dir(tmp_path: Path):
- templates_dir = tmp_path / "prompt_templates_test"
+def mock_templates_dir(tmp_path, monkeypatch):
+ """Creates a temporary directory for prompt templates and patches the module-level constant."""
+ templates_dir = tmp_path / "prompt_templates"
templates_dir.mkdir()
- # Create a valid template file
- valid_template_data = {
- "name": "test_valid",
- "description": "A valid test template",
- "system_message_template": "System: {sys_var}",
- "user_message_content_template": "User: {user_var} - {message_content}"
+ # Patch the PROMPT_TEMPLATES_DIR constant in the target module
+ monkeypatch.setattr('tldw_chatbook.Chat.prompt_template_manager.PROMPT_TEMPLATES_DIR', templates_dir)
+
+ # Create some dummy template files
+ template1_data = {
+ "name": "test_template_1",
+ "description": "A simple test template.",
+ "user_message_content_template": "User said: {{message_content}}"
}
- with open(templates_dir / "test_valid.json", "w") as f:
- json.dump(valid_template_data, f)
+ (templates_dir / "test_template_1.json").write_text(json.dumps(template1_data))
- # Create an invalid JSON template file
- with open(templates_dir / "test_invalid_json.json", "w") as f:
- f.write("{'name': 'invalid', 'description': 'bad json'") # Invalid JSON
+ template2_data = {
+ "name": "test_template_2",
+ "system_message_template": "System context: {{system_context}}",
+ "user_message_content_template": "{{message_content}}"
+ }
+ (templates_dir / "test_template_2.json").write_text(json.dumps(template2_data))
- # Create an empty template file (valid JSON but might be handled as error by Pydantic)
- with open(templates_dir / "test_empty.json", "w") as f:
- json.dump({}, f)
+ # Create a malformed JSON file
+ (templates_dir / "malformed.json").write_text("{'invalid': 'json'")
return templates_dir
-@pytest.mark.unit
-def test_load_template_success(mock_templates_dir):
- with patch("tldw_app.Chat.prompt_template_manager.PROMPT_TEMPLATES_DIR", mock_templates_dir):
- template = load_template("test_valid")
- assert template is not None
- assert template.name == "test_valid"
- assert template.system_message_template == "System: {sys_var}"
- # Test caching
- template_cached = load_template("test_valid")
- assert template_cached is template # Should be the same object from cache
+# --- Test Cases ---
+class TestPromptTemplateManager:
-@pytest.mark.unit
-def test_load_template_not_found(mock_templates_dir):
- with patch("tldw_app.Chat.prompt_template_manager.PROMPT_TEMPLATES_DIR", mock_templates_dir):
+ def test_load_template_success(self, mock_templates_dir):
+ """Test successfully loading a valid template file."""
+ template = load_template("test_template_1")
+ assert template is not None
+ assert isinstance(template, PromptTemplate)
+ assert template.name == "test_template_1"
+ assert template.user_message_content_template == "User said: {{message_content}}"
+
+ def test_load_template_not_found(self, mock_templates_dir):
+ """Test loading a template that does not exist."""
template = load_template("non_existent_template")
assert template is None
-
-@pytest.mark.unit
-def test_load_template_invalid_json(mock_templates_dir):
- with patch("tldw_app.Chat.prompt_template_manager.PROMPT_TEMPLATES_DIR", mock_templates_dir):
- template = load_template("test_invalid_json")
- assert template is None # Should fail to parse
-
-
-@pytest.mark.unit
-def test_load_template_empty_json_fails_validation(mock_templates_dir):
- with patch("tldw_app.Chat.prompt_template_manager.PROMPT_TEMPLATES_DIR", mock_templates_dir):
- template = load_template("test_empty")
- # Pydantic will raise validation error because 'name' is missing,
- # load_template should catch this and return None.
+ def test_load_template_malformed_json(self, mock_templates_dir):
+ """Test loading a template from a file with invalid JSON."""
+ template = load_template("malformed")
assert template is None
+ def test_load_template_caching(self, mock_templates_dir):
+ """Test that a loaded template is cached and not re-read from disk."""
+ template1 = load_template("test_template_1")
+ assert "test_template_1" in _loaded_templates
-@pytest.mark.unit
-# Inside test_apply_template_to_string():
-def test_apply_template_to_string():
- template_str_jinja = "Hello {{name}}, welcome to {{place}}." # Use Jinja
- data_full = {"name": "Alice", "place": "Wonderland"}
- assert apply_template_to_string(template_str_jinja, data_full) == "Hello Alice, welcome to Wonderland."
+ # Modify the file on disk
+ (mock_templates_dir / "test_template_1.json").write_text(json.dumps({"name": "modified"}))
- template_partial_jinja = "Hello {{name}}." # Use Jinja
- data_partial = {"name": "Bob"}
- assert apply_template_to_string(template_partial_jinja, data_partial) == "Hello Bob."
+ # Load again - should return the cached version
+ template2 = load_template("test_template_1")
+ assert template2 is not None
+ assert template2.name == "test_template_1" # Original name from cache
+ assert template2 == template1
- # Test with missing data - Jinja renders empty for missing by default if not strict
- assert apply_template_to_string(template_partial_jinja, {}) == "Hello ."
-
- # Test with None template string
- assert apply_template_to_string(None, data_full) == ""
-
-
-@pytest.mark.unit
-def test_get_available_templates(mock_templates_dir):
- with patch("tldw_app.Chat.prompt_template_manager.PROMPT_TEMPLATES_DIR", mock_templates_dir):
+ def test_get_available_templates(self, mock_templates_dir):
+ """Test discovering available templates from the directory."""
available = get_available_templates()
assert isinstance(available, list)
- assert "test_valid" in available
- assert "test_invalid_json" in available
- assert "test_empty" in available
- assert len(available) == 3
-
-
-@pytest.mark.unit
-def test_get_available_templates_no_dir():
- with patch("tldw_app.Chat.prompt_template_manager.PROMPT_TEMPLATES_DIR", Path("/non/existent/dir")):
- available = get_available_templates()
- assert available == []
+ assert set(available) == {"test_template_1", "test_template_2", "malformed"}
+ def test_get_available_templates_no_dir(self, tmp_path, monkeypatch):
+ """Test getting templates when the directory doesn't exist."""
+ non_existent_dir = tmp_path / "non_existent_dir"
+ monkeypatch.setattr('tldw_chatbook.Chat.prompt_template_manager.PROMPT_TEMPLATES_DIR', non_existent_dir)
+ assert get_available_templates() == []
-@pytest.mark.unit
-def test_default_raw_passthrough_template():
- assert DEFAULT_RAW_PASSTHROUGH_TEMPLATE is not None
- assert DEFAULT_RAW_PASSTHROUGH_TEMPLATE.name == "raw_passthrough"
- data = {"message_content": "test content", "original_system_message_from_request": "system content"}
-
- # User message template (is "{{message_content}}")
- assert apply_template_to_string(DEFAULT_RAW_PASSTHROUGH_TEMPLATE.user_message_content_template,
- data) == "test content"
- # System message template (is "{{original_system_message_from_request}}")
- assert apply_template_to_string(DEFAULT_RAW_PASSTHROUGH_TEMPLATE.system_message_template,
- data) == "system content"
-
- data_empty_sys = {"original_system_message_from_request": ""}
- assert apply_template_to_string(DEFAULT_RAW_PASSTHROUGH_TEMPLATE.system_message_template,
- data_empty_sys) == ""
-
- data_missing_sys = {"message_content": "some_content"} # original_system_message_from_request is missing
- assert apply_template_to_string(DEFAULT_RAW_PASSTHROUGH_TEMPLATE.system_message_template,
- data_missing_sys) == "" # Jinja renders missing as empty
-
+
+ def test_default_passthrough_template_is_available(self):
+ """Test that the default 'raw_passthrough' template is loaded."""
+ template = load_template("raw_passthrough")
+ assert template is not None
+ assert template.name == "raw_passthrough"
+ assert template.user_message_content_template == "{{message_content}}"
+
+
+class TestTemplateRendering:
+
+ def test_apply_template_to_string_success(self):
+ """Test basic successful rendering."""
+ template_str = "Hello, {{ name }}!"
+ data = {"name": "World"}
+ result = apply_template_to_string(template_str, data)
+ assert result == "Hello, World!"
+
+ def test_apply_template_to_string_missing_placeholder(self):
+ """Test rendering when a placeholder in the template is not in the data."""
+ template_str = "Hello, {{ name }}! Your age is {{ age }}."
+ data = {"name": "World"} # 'age' is missing
+ result = apply_template_to_string(template_str, data)
+ assert result == "Hello, World! Your age is ." # Jinja renders missing variables as empty strings
+
+ def test_apply_template_with_none_input_string(self):
+ """Test that a None template string returns an empty string."""
+ data = {"name": "World"}
+ result = apply_template_to_string(None, data)
+ assert result == ""
+
+ def test_apply_template_with_complex_data(self):
+ """Test rendering with more complex data structures like lists and dicts."""
+ template_str = "User: {{ user.name }}. Items: {% for item in items %}{{ item }}{% if not loop.last %}, {% endif %}{% endfor %}."
+ data = {
+ "user": {"name": "Alice"},
+ "items": ["apple", "banana", "cherry"]
+ }
+ result = apply_template_to_string(template_str, data)
+ assert result == "User: Alice. Items: apple, banana, cherry."
+
+ def test_safe_render_prevents_unsafe_operations(self):
+ """Test that the sandboxed environment prevents access to unsafe attributes."""
+ # Attempt to access a private attribute or a method that could be unsafe
+ template_str = "Unsafe access: {{ my_obj.__class__ }}"
+
+ class MyObj: pass
+
+ data = {"my_obj": MyObj()}
+
+ # In a sandboxed environment, this should raise a SecurityError, which our wrapper catches.
+ # The wrapper then returns the original string.
+ result = apply_template_to_string(template_str, data)
+ assert result == template_str
\ No newline at end of file
diff --git a/Tests/DB/test_sqlite_db.py b/Tests/DB/test_sqlite_db.py
deleted file mode 100644
index 5eabfd8b..00000000
--- a/Tests/DB/test_sqlite_db.py
+++ /dev/null
@@ -1,553 +0,0 @@
-# tests/test_sqlite_db.py
-# Description: Unit tests for SQLite database operations, including CRUD, transactions, and sync log management.
-#
-# Imports:
-import json
-import os
-import pytest
-import time
-import sqlite3
-from datetime import datetime, timezone, timedelta
-#
-# 3rd-Party Imports:
-#
-# Local imports
-from tldw_cli.tldw_app.DB.Client_Media_DB_v2 import Database, ConflictError
-# Import from src using adjusted sys.path in conftest
-#
-#######################################################################################################################
-#
-# Functions:
-
-# Helper to get sync log entries for assertions
-def get_log_count(db: Database, entity_uuid: str) -> int:
- cursor = db.execute_query("SELECT COUNT(*) FROM sync_log WHERE entity_uuid = ?", (entity_uuid,))
- return cursor.fetchone()[0]
-
-def get_latest_log(db: Database, entity_uuid: str) -> dict | None:
- cursor = db.execute_query(
- "SELECT * FROM sync_log WHERE entity_uuid = ? ORDER BY change_id DESC LIMIT 1",
- (entity_uuid,)
- )
- row = cursor.fetchone()
- return dict(row) if row else None
-
-def get_entity_version(db: Database, entity_table: str, uuid: str) -> int | None:
- cursor = db.execute_query(f"SELECT version FROM {entity_table} WHERE uuid = ?", (uuid,))
- row = cursor.fetchone()
- return row['version'] if row else None
-
-class TestDatabaseInitialization:
- def test_memory_db_creation(self, memory_db_factory):
- """Test creating an in-memory database."""
- db = memory_db_factory("client_mem")
- assert db.is_memory_db
- assert db.client_id == "client_mem"
- # Check if a table exists (schema creation check)
- cursor = db.execute_query("SELECT name FROM sqlite_master WHERE type='table' AND name='Media'")
- assert cursor.fetchone() is not None
- db.close_connection()
-
- def test_file_db_creation(self, file_db, temp_db_path):
- """Test creating a file-based database."""
- assert not file_db.is_memory_db
- assert file_db.client_id == "file_client"
- assert os.path.exists(temp_db_path)
- cursor = file_db.execute_query("SELECT name FROM sqlite_master WHERE type='table' AND name='Media'")
- assert cursor.fetchone() is not None
- # file_db fixture handles closure
-
- def test_missing_client_id(self):
- """Test that ValueError is raised if client_id is missing."""
- with pytest.raises(ValueError, match="Client ID cannot be empty"):
- Database(db_path=":memory:", client_id="")
- with pytest.raises(ValueError, match="Client ID cannot be empty"):
- Database(db_path=":memory:", client_id=None)
-
-
-class TestDatabaseTransactions:
- def test_transaction_commit(self, memory_db_factory):
- db = memory_db_factory()
- keyword = "commit_test"
- with db.transaction():
- # Use internal method _add_keyword_internal or simplified version for test
- kw_id, kw_uuid = db.add_keyword(keyword) # add_keyword uses transaction internally too, nested is ok
- # Verify outside transaction
- cursor = db.execute_query("SELECT keyword FROM Keywords WHERE id = ?", (kw_id,))
- assert cursor.fetchone()['keyword'] == keyword
-
- def test_transaction_rollback(self, memory_db_factory):
- db = memory_db_factory()
- keyword = "rollback_test"
- initial_count_cursor = db.execute_query("SELECT COUNT(*) FROM Keywords")
- initial_count = initial_count_cursor.fetchone()[0]
- try:
- with db.transaction():
- # Simplified insert for test clarity
- new_uuid = db._generate_uuid()
- db.execute_query(
- "INSERT INTO Keywords (keyword, uuid, last_modified, version, client_id, deleted) VALUES (?, ?, ?, 1, ?, 0)",
- (keyword, new_uuid, db._get_current_utc_timestamp_str(), db.client_id),
- commit=False # Important: commit=False inside transaction block
- )
- # Check *inside* transaction
- cursor_inside = db.execute_query("SELECT COUNT(*) FROM Keywords")
- assert cursor_inside.fetchone()[0] == initial_count + 1
- raise ValueError("Simulating error to trigger rollback") # Force rollback
- except ValueError:
- pass # Expected error
- except Exception as e:
- pytest.fail(f"Unexpected exception during rollback test: {e}")
-
- # Verify outside transaction (count should be back to initial)
- final_count_cursor = db.execute_query("SELECT COUNT(*) FROM Keywords")
- assert final_count_cursor.fetchone()[0] == initial_count
-
-
-class TestDatabaseCRUDAndSync:
-
- @pytest.fixture
- def db_instance(self, memory_db_factory):
- """Provides a fresh in-memory DB for each test in this class."""
- return memory_db_factory("crud_client")
-
- def test_add_keyword(self, db_instance):
- keyword = " test keyword "
- expected_keyword = "test keyword"
- kw_id, kw_uuid = db_instance.add_keyword(keyword)
-
- assert kw_id is not None
- assert kw_uuid is not None
-
- # Verify DB state
- cursor = db_instance.execute_query("SELECT * FROM Keywords WHERE id = ?", (kw_id,))
- row = cursor.fetchone()
- assert row['keyword'] == expected_keyword
- assert row['uuid'] == kw_uuid
- assert row['version'] == 1
- assert row['client_id'] == db_instance.client_id
- assert not row['deleted']
-
- # Verify Sync Log
- log_entry = get_latest_log(db_instance, kw_uuid)
- assert log_entry['operation'] == 'create'
- assert log_entry['entity'] == 'Keywords'
- assert log_entry['version'] == 1
- assert log_entry['client_id'] == db_instance.client_id
- payload = json.loads(log_entry['payload'])
- assert payload['keyword'] == expected_keyword
- assert payload['uuid'] == kw_uuid
-
- def test_add_existing_keyword(self, db_instance):
- keyword = "existing"
- kw_id1, kw_uuid1 = db_instance.add_keyword(keyword)
- log_count1 = get_log_count(db_instance, kw_uuid1)
-
- kw_id2, kw_uuid2 = db_instance.add_keyword(keyword) # Add again
- log_count2 = get_log_count(db_instance, kw_uuid1)
-
- assert kw_id1 == kw_id2
- assert kw_uuid1 == kw_uuid2
- assert log_count1 == log_count2 # No new log entry
-
- def test_soft_delete_keyword(self, db_instance):
- keyword = "to_delete"
- kw_id, kw_uuid = db_instance.add_keyword(keyword)
- initial_version = get_entity_version(db_instance, "Keywords", kw_uuid)
-
- deleted = db_instance.soft_delete_keyword(keyword)
- assert deleted is True
-
- # Verify DB state
- cursor = db_instance.execute_query("SELECT deleted, version FROM Keywords WHERE id = ?", (kw_id,))
- row = cursor.fetchone()
- assert row['deleted'] == 1
- assert row['version'] == initial_version + 1
-
- # Verify Sync Log
- log_entry = get_latest_log(db_instance, kw_uuid)
- assert log_entry['operation'] == 'delete'
- assert log_entry['entity'] == 'Keywords'
- assert log_entry['version'] == initial_version + 1
- payload = json.loads(log_entry['payload'])
- assert payload['uuid'] == kw_uuid # Delete payload is minimal
-
- def test_undelete_keyword(self, db_instance):
- keyword = "to_undelete"
- kw_id, kw_uuid = db_instance.add_keyword(keyword)
- db_instance.soft_delete_keyword(keyword) # Delete it first
- deleted_version = get_entity_version(db_instance, "Keywords", kw_uuid)
-
- # Adding it again should undelete it
- undelete_id, undelete_uuid = db_instance.add_keyword(keyword)
-
- assert undelete_id == kw_id
- assert undelete_uuid == kw_uuid
-
- # Verify DB state
- cursor = db_instance.execute_query("SELECT deleted, version FROM Keywords WHERE id = ?", (kw_id,))
- row = cursor.fetchone()
- assert row['deleted'] == 0
- assert row['version'] == deleted_version + 1
-
- # Verify Sync Log
- log_entry = get_latest_log(db_instance, kw_uuid)
- # Undelete is logged as an 'update'
- assert log_entry['operation'] == 'update'
- assert log_entry['entity'] == 'Keywords'
- assert log_entry['version'] == deleted_version + 1
- payload = json.loads(log_entry['payload'])
- assert payload['uuid'] == kw_uuid
- assert payload['deleted'] == 0 # Payload shows undeleted state
-
- def test_add_media_with_keywords_create(self, db_instance):
- title = "Test Media Create"
- content = "Some unique content for create."
- keywords = ["create_kw1", "create_kw2"]
-
- media_id, media_uuid, msg = db_instance.add_media_with_keywords(
- title=title,
- media_type="article",
- content=content,
- keywords=keywords,
- author="Tester"
- )
-
- assert media_id is not None
- assert media_uuid is not None
- assert f"Media '{title}' added." == msg # NEW (Exact match)
-
- # Verify Media DB state
- cursor = db_instance.execute_query("SELECT * FROM Media WHERE id = ?", (media_id,))
- media_row = cursor.fetchone()
- assert media_row['title'] == title
- assert media_row['uuid'] == media_uuid
- assert media_row['version'] == 1 # Initial version
- assert not media_row['deleted']
-
- # Verify Keywords exist
- cursor = db_instance.execute_query("SELECT COUNT(*) FROM Keywords WHERE keyword IN (?, ?)", tuple(keywords))
- assert cursor.fetchone()[0] == 2
-
- # Verify MediaKeywords links
- cursor = db_instance.execute_query("SELECT COUNT(*) FROM MediaKeywords WHERE media_id = ?", (media_id,))
- assert cursor.fetchone()[0] == 2
-
- # Verify DocumentVersion creation
- cursor = db_instance.execute_query("SELECT version_number, content FROM DocumentVersions WHERE media_id = ? ORDER BY version_number DESC LIMIT 1", (media_id,))
- version_row = cursor.fetchone()
- assert version_row['version_number'] == 1
- assert version_row['content'] == content
-
- # Verify Sync Log for Media
- log_entry = get_latest_log(db_instance, media_uuid)
- assert log_entry['operation'] == 'create'
- assert log_entry['entity'] == 'Media'
- # Note: MediaKeywords triggers might log *after* the media create trigger
-
- def test_add_media_with_keywords_update(self, db_instance):
- title = "Test Media Update"
- content1 = "Initial content."
- content2 = "Updated content."
- keywords1 = ["update_kw1"]
- keywords2 = ["update_kw2", "update_kw3"]
-
- # Add initial version
- media_id, media_uuid, _ = db_instance.add_media_with_keywords(
- title=title, media_type="text", content=content1, keywords=keywords1
- )
- initial_version = get_entity_version(db_instance, "Media", media_uuid)
-
- # --- FIX: Fetch the hash AFTER creation ---
- cursor_check_initial = db_instance.execute_query("SELECT content_hash FROM Media WHERE id = ?", (media_id,))
- initial_hash_row = cursor_check_initial.fetchone()
- assert initial_hash_row is not None # Ensure fetch worked
- initial_content_hash = initial_hash_row['content_hash']
- # --- End Fix ---
-
- # --- Attempt 1: Update using a generated URL with the initial hash ---
- # (This test might be slightly less relevant if your primary update mechanism
- # relies on UUID or finding via hash internally when URL is None)
- generated_url = f"local://text/{initial_content_hash}"
- media_id_up1, media_uuid_up1, msg1 = db_instance.add_media_with_keywords(
- title=title + " Updated Via URL",
- media_type="text",
- content=content2, # Update content (changes hash)
- keywords=["url_update_kw"],
- overwrite=True,
- url=generated_url # Use the generated URL
- )
-
- # Assertions for the first update attempt (if you keep it)
- assert media_id_up1 == media_id
- assert media_uuid_up1 == media_uuid
- assert f"Media '{title + ' Updated Via URL'}' updated." == msg1
- # Check version incremented after first update
- version_after_update1 = get_entity_version(db_instance, "Media", media_uuid)
- assert version_after_update1 == initial_version + 1
-
- # --- Attempt 2: Simulate finding by hash (URL=None) ---
- # Update again, changing keywords
- media_id_up2, media_uuid_up2, msg2 = db_instance.add_media_with_keywords(
- title=title + " Updated Via Hash", # Change title again
- media_type="text",
- content=content2, # Keep content same as first update
- keywords=keywords2, # Use the final keyword set
- overwrite=True,
- url=None # Force lookup by hash (which is now hash of content2)
- )
-
- # Assertions for the second update attempt
- assert media_id_up2 == media_id
- assert media_uuid_up2 == media_uuid
- assert f"Media '{title + ' Updated Via Hash'}' updated." == msg2
-
- # Verify Final Media DB state
- cursor = db_instance.execute_query("SELECT title, content, version FROM Media WHERE id = ?", (media_id,))
- media_row = cursor.fetchone() # Now media_row is correctly defined for assertions
- assert media_row['title'] == title + " Updated Via Hash"
- assert media_row['content'] == content2
- # Version should have incremented again from the second update
- assert media_row['version'] == version_after_update1 + 1
-
- # Verify Keywords links updated to the final set
- cursor = db_instance.execute_query("""
- SELECT k.keyword
- FROM MediaKeywords mk
- JOIN Keywords k ON mk.keyword_id = k.id
- WHERE mk.media_id = ?
- ORDER BY k.keyword
- """, (media_id,))
- current_keywords = [r['keyword'] for r in cursor.fetchall()]
- assert current_keywords == sorted(keywords2)
-
- # Verify latest DocumentVersion reflects the last content state (content2)
- cursor = db_instance.execute_query(
- "SELECT version_number, content FROM DocumentVersions WHERE media_id = ? ORDER BY version_number DESC LIMIT 1",
- (media_id,))
- version_row = cursor.fetchone()
- # There should be 3 versions now (initial create, update 1, update 2)
- assert version_row['version_number'] == 3
- assert version_row['content'] == content2
-
- # Verify Sync Log for the *last* Media update
- log_entry = get_latest_log(db_instance, media_uuid)
- assert log_entry['operation'] == 'update'
- assert log_entry['entity'] == 'Media'
- assert log_entry['version'] == version_after_update1 + 1
-
- def test_soft_delete_media_cascade(self, db_instance):
- # 1. Setup complex item
- media_id, media_uuid, _ = db_instance.add_media_with_keywords(
- title="Cascade Test", content="Cascade content", media_type="article",
- keywords=["cascade1", "cascade2"], author="Cascade Author"
- )
- # Add a transcript manually (assuming no direct add_transcript method)
- t_uuid = db_instance._generate_uuid()
- db_instance.execute_query(
- """INSERT INTO Transcripts (media_id, whisper_model, transcription, uuid, last_modified, version, client_id, deleted)
- VALUES (?, ?, ?, ?, ?, 1, ?, 0)""",
- (media_id, "model_xyz", "Transcript text", t_uuid, db_instance._get_current_utc_timestamp_str(), db_instance.client_id),
- commit=True
- )
- # Add a chunk manually
- c_uuid = db_instance._generate_uuid()
- db_instance.execute_query(
- """INSERT INTO MediaChunks (media_id, chunk_text, uuid, last_modified, version, client_id, deleted)
- VALUES (?, ?, ?, ?, 1, ?, 0)""",
- (media_id, "Chunk text", c_uuid, db_instance._get_current_utc_timestamp_str(), db_instance.client_id),
- commit=True
- )
- media_version = get_entity_version(db_instance, "Media", media_uuid)
- transcript_version = get_entity_version(db_instance, "Transcripts", t_uuid)
- chunk_version = get_entity_version(db_instance, "MediaChunks", c_uuid)
-
-
- # 2. Perform soft delete with cascade
- deleted = db_instance.soft_delete_media(media_id, cascade=True)
- assert deleted is True
-
- # 3. Verify parent and children are marked deleted and versioned
- cursor = db_instance.execute_query("SELECT deleted, version FROM Media WHERE id = ?", (media_id,))
- assert dict(cursor.fetchone()) == {'deleted': 1, 'version': media_version + 1}
-
- cursor = db_instance.execute_query("SELECT deleted, version FROM Transcripts WHERE uuid = ?", (t_uuid,))
- assert dict(cursor.fetchone()) == {'deleted': 1, 'version': transcript_version + 1}
-
- cursor = db_instance.execute_query("SELECT deleted, version FROM MediaChunks WHERE uuid = ?", (c_uuid,))
- assert dict(cursor.fetchone()) == {'deleted': 1, 'version': chunk_version + 1}
-
- # 4. Verify keywords are unlinked
- cursor = db_instance.execute_query("SELECT COUNT(*) FROM MediaKeywords WHERE media_id = ?", (media_id,))
- assert cursor.fetchone()[0] == 0
-
- # 5. Verify Sync Logs
- media_log = get_latest_log(db_instance, media_uuid)
- assert media_log['operation'] == 'delete'
- assert media_log['version'] == media_version + 1
-
- transcript_log = get_latest_log(db_instance, t_uuid)
- assert transcript_log['operation'] == 'delete'
- assert transcript_log['version'] == transcript_version + 1
-
- chunk_log = get_latest_log(db_instance, c_uuid)
- assert chunk_log['operation'] == 'delete'
- assert chunk_log['version'] == chunk_version + 1
-
- # Check MediaKeywords unlink logs (tricky to get exact UUIDs, check count)
- cursor = db_instance.execute_query("SELECT COUNT(*) FROM sync_log WHERE entity = 'MediaKeywords' AND operation = 'unlink' AND payload LIKE ?", (f'%{media_uuid}%',))
- assert cursor.fetchone()[0] == 2 # Should be 2 unlink events
-
- def test_optimistic_locking_prevents_update_with_stale_version(self, db_instance):
- """Test that an UPDATE with a stale version number fails (rowcount 0)."""
- keyword = "conflict_test"
- kw_id, kw_uuid = db_instance.add_keyword(keyword)
- original_version = get_entity_version(db_instance, "Keywords", kw_uuid) # Should be 1
- assert original_version == 1, "Initial version should be 1"
-
- # Simulate external update incrementing version
- db_instance.execute_query(
- "UPDATE Keywords SET version = ?, client_id = ? WHERE id = ?",
- (original_version + 1, "external_client", kw_id),
- commit=True
- )
- version_after_external_update = get_entity_version(db_instance, "Keywords", kw_uuid) # Should be 2
- assert version_after_external_update == original_version + 1, "Version after external update should be 2"
-
- # Now, manually attempt an update using the *original stale version* (version=1)
- # This mimics what would happen if a process read version 1, then tried
- # to update after the external process bumped it to version 2.
- current_time = db_instance._get_current_utc_timestamp_str()
- client_id = db_instance.client_id
- cursor = db_instance.execute_query(
- "UPDATE Keywords SET keyword='stale_update', last_modified=?, version=?, client_id=? WHERE id=? AND version=?",
- (current_time, original_version + 1, client_id, kw_id, original_version), # <<< WHERE version = 1 (stale)
- commit=True # Commit needed to actually perform the check
- )
-
- # Assert that the update failed because the WHERE clause (version=1) didn't match any rows
- assert cursor.rowcount == 0, "Update with stale version should affect 0 rows"
-
- # Verify DB state is unchanged by the failed update (still shows external update's state)
- cursor_check = db_instance.execute_query("SELECT keyword, version, client_id FROM Keywords WHERE id = ?",
- (kw_id,))
- row = cursor_check.fetchone()
- assert row is not None, "Keyword should still exist"
- assert row['keyword'] == keyword, "Keyword text should not have changed to 'stale_update'"
- assert row['version'] == original_version + 1, "Version should remain 2 from the external update"
- assert row['client_id'] == "external_client", "Client ID should remain from the external update"
-
- def test_version_validation_trigger(self, db_instance):
- """Test trigger preventing non-sequential version updates."""
- kw_id, kw_uuid = db_instance.add_keyword("validation_test")
- current_version = get_entity_version(db_instance, "Keywords", kw_uuid)
-
- # Try to update version incorrectly (skipping a version)
- with pytest.raises(sqlite3.IntegrityError, match="Sync Error \(Keywords\): Version must increment by exactly 1"):
- db_instance.execute_query(
- "UPDATE Keywords SET version = ?, keyword = ? WHERE id = ?",
- (current_version + 2, "bad version", kw_id),
- commit=True
- )
-
- # Try to update version incorrectly (same version)
- with pytest.raises(sqlite3.IntegrityError, match="Sync Error \(Keywords\): Version must increment by exactly 1"):
- db_instance.execute_query(
- "UPDATE Keywords SET version = ?, keyword = ? WHERE id = ?",
- (current_version, "same version", kw_id),
- commit=True
- )
-
- def test_client_id_validation_trigger(self, db_instance):
- """Test trigger preventing null/empty client_id on update."""
- kw_id, kw_uuid = db_instance.add_keyword("clientid_test")
- current_version = get_entity_version(db_instance, "Keywords", kw_uuid)
-
- # Try to update with NULL client_id
- with pytest.raises(sqlite3.IntegrityError, match="Sync Error \(Keywords\): Client ID cannot be NULL or empty"):
- db_instance.execute_query(
- "UPDATE Keywords SET version = ?, client_id = NULL WHERE id = ?",
- (current_version + 1, kw_id),
- commit=True
- )
-
- # Try to update with empty client_id
- with pytest.raises(sqlite3.IntegrityError, match="Sync Error \(Keywords\): Client ID cannot be NULL or empty"):
- db_instance.execute_query(
- "UPDATE Keywords SET version = ?, client_id = '' WHERE id = ?",
- (current_version + 1, kw_id),
- commit=True
- )
-
-
-class TestSyncLogManagement:
-
- @pytest.fixture
- def db_instance(self, memory_db_factory):
- db = memory_db_factory("log_client")
- # Add some initial data to generate logs
- db.add_keyword("log_kw_1")
- time.sleep(0.01) # Ensure timestamp difference
- db.add_keyword("log_kw_2")
- time.sleep(0.01)
- db.add_keyword("log_kw_3")
- db.soft_delete_keyword("log_kw_2")
- return db
-
- def test_get_sync_log_entries_all(self, db_instance):
- logs = db_instance.get_sync_log_entries()
- # Expect 3 creates + 1 delete = 4 entries
- assert len(logs) == 4
- assert logs[0]['change_id'] == 1
- assert logs[-1]['change_id'] == 4
-
- def test_get_sync_log_entries_since(self, db_instance):
- logs = db_instance.get_sync_log_entries(since_change_id=2) # Get 3 and 4
- assert len(logs) == 2
- assert logs[0]['change_id'] == 3
- assert logs[1]['change_id'] == 4
-
- def test_get_sync_log_entries_limit(self, db_instance):
- logs = db_instance.get_sync_log_entries(limit=2) # Get 1 and 2
- assert len(logs) == 2
- assert logs[0]['change_id'] == 1
- assert logs[1]['change_id'] == 2
-
- def test_get_sync_log_entries_since_and_limit(self, db_instance):
- logs = db_instance.get_sync_log_entries(since_change_id=1, limit=2) # Get 2 and 3
- assert len(logs) == 2
- assert logs[0]['change_id'] == 2
- assert logs[1]['change_id'] == 3
-
- def test_delete_sync_log_entries_specific(self, db_instance):
- initial_logs = db_instance.get_sync_log_entries()
- initial_count = len(initial_logs) # Should be 4
- ids_to_delete = [initial_logs[1]['change_id'], initial_logs[2]['change_id']] # Delete 2 and 3
-
- deleted_count = db_instance.delete_sync_log_entries(ids_to_delete)
- assert deleted_count == 2
-
- remaining_logs = db_instance.get_sync_log_entries()
- assert len(remaining_logs) == initial_count - 2
- remaining_ids = {log['change_id'] for log in remaining_logs}
- assert remaining_ids == {initial_logs[0]['change_id'], initial_logs[3]['change_id']} # 1 and 4 should remain
-
- def test_delete_sync_log_entries_before(self, db_instance):
- initial_logs = db_instance.get_sync_log_entries()
- initial_count = len(initial_logs) # Should be 4
- threshold_id = initial_logs[2]['change_id'] # Delete up to and including ID 3
-
- deleted_count = db_instance.delete_sync_log_entries_before(threshold_id)
- assert deleted_count == 3 # Deleted 1, 2, 3
-
- remaining_logs = db_instance.get_sync_log_entries()
- assert len(remaining_logs) == 1
- assert remaining_logs[0]['change_id'] == initial_logs[3]['change_id'] # Only 4 remains
-
- def test_delete_sync_log_entries_empty(self, db_instance):
- deleted_count = db_instance.delete_sync_log_entries([])
- assert deleted_count == 0
-
- def test_delete_sync_log_entries_invalid_id(self, db_instance):
- with pytest.raises(ValueError):
- db_instance.delete_sync_log_entries([1, "two", 3])
\ No newline at end of file
diff --git a/Tests/MediaDB2/__init__.py b/Tests/Event_Handlers/Chat_Events/__init__.py
similarity index 100%
rename from Tests/MediaDB2/__init__.py
rename to Tests/Event_Handlers/Chat_Events/__init__.py
diff --git a/Tests/Event_Handlers/Chat_Events/test_chat_events.py b/Tests/Event_Handlers/Chat_Events/test_chat_events.py
new file mode 100644
index 00000000..9a2b3433
--- /dev/null
+++ b/Tests/Event_Handlers/Chat_Events/test_chat_events.py
@@ -0,0 +1,301 @@
+# /tests/Event_Handlers/Chat_Events/test_chat_events.py
+
+import pytest
+from unittest.mock import AsyncMock, MagicMock, patch, call
+
+from rich.text import Text
+# Mock Textual UI elements before they are imported by the module under test
+from textual.widgets import (
+ Button, Input, TextArea, Static, Select, Checkbox, ListView, ListItem, Label
+)
+from textual.containers import VerticalScroll
+from textual.css.query import QueryError
+
+# Mock DB Errors
+from tldw_chatbook.DB.ChaChaNotes_DB import ConflictError, CharactersRAGDBError, InputError
+
+# Functions to test
+from tldw_chatbook.Event_Handlers.Chat_Events.chat_events import (
+ handle_chat_send_button_pressed,
+ handle_chat_action_button_pressed,
+ handle_chat_new_conversation_button_pressed,
+ handle_chat_save_current_chat_button_pressed,
+ handle_chat_load_character_button_pressed,
+ handle_chat_clear_active_character_button_pressed,
+ # ... import other handlers as you write tests for them
+)
+from tldw_chatbook.Widgets.chat_message import ChatMessage
+
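+# The module-wide marker below lets pytest-asyncio collect and run every async test coroutine in this file.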
+pytestmark = pytest.mark.asyncio
+
+
+# A very comprehensive mock app fixture is needed here
+@pytest.fixture
+def mock_app():
+ app = AsyncMock()
+
+ # Mock services and DBs
+ app.chachanotes_db = MagicMock()
+ app.notes_service = MagicMock()
+ app.notes_service._get_db.return_value = app.chachanotes_db
+ app.media_db = MagicMock()
+
+ # Mock core app properties
+ app.API_IMPORTS_SUCCESSFUL = True
+ app.app_config = {
+ "api_settings": {
+ "openai": {"streaming": True, "api_key_env_var": "OPENAI_API_KEY"},
+ "anthropic": {"streaming": False, "api_key": "xyz-key"}
+ },
+ "chat_defaults": {"system_prompt": "Default system prompt."},
+ "USERS_NAME": "Tester"
+ }
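+ # The send handler resolves per-provider streaming flags and API-key settings from this config.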
+
+ # Mock app state
+ app.current_chat_conversation_id = None
+ app.current_chat_is_ephemeral = True
+ app.current_chat_active_character_data = None
+ app.current_ai_message_widget = None
+
+ # Mock app methods
+ app.query_one = MagicMock()
+ app.notify = AsyncMock()
+ app.copy_to_clipboard = MagicMock()
+ app.set_timer = MagicMock()
+ app.run_worker = MagicMock()
+ app.chat_wrapper = AsyncMock()
+
+ # Timers
+ app._conversation_search_timer = None
+
+ # --- Set up mock widgets ---
+ # This is complex; a helper function simplifies it.
+ def setup_mock_widgets(q_one_mock):
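+ # Map each widget selector the chat handlers query to a pre-configured mock.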
+ widgets = {
+ "#chat-input": MagicMock(spec=TextArea, text="User message", is_mounted=True),
+ "#chat-log": AsyncMock(spec=VerticalScroll, is_mounted=True),
+ "#chat-api-provider": MagicMock(spec=Select, value="OpenAI"),
+ "#chat-api-model": MagicMock(spec=Select, value="gpt-4"),
+ "#chat-system-prompt": MagicMock(spec=TextArea, text="UI system prompt"),
+ "#chat-temperature": MagicMock(spec=Input, value="0.7"),
+ "#chat-top-p": MagicMock(spec=Input, value="0.9"),
+ "#chat-min-p": MagicMock(spec=Input, value="0.1"),
+ "#chat-top-k": MagicMock(spec=Input, value="40"),
+ "#chat-llm-max-tokens": MagicMock(spec=Input, value="1024"),
+ "#chat-llm-seed": MagicMock(spec=Input, value=""),
+ "#chat-llm-stop": MagicMock(spec=Input, value=""),
+ "#chat-llm-response-format": MagicMock(spec=Select, value="text"),
+ "#chat-llm-n": MagicMock(spec=Input, value="1"),
+ "#chat-llm-user-identifier": MagicMock(spec=Input, value=""),
+ "#chat-llm-logprobs": MagicMock(spec=Checkbox, value=False),
+ "#chat-llm-top-logprobs": MagicMock(spec=Input, value=""),
+ "#chat-llm-logit-bias": MagicMock(spec=TextArea, text="{}"),
+ "#chat-llm-presence-penalty": MagicMock(spec=Input, value="0.0"),
+ "#chat-llm-frequency-penalty": MagicMock(spec=Input, value="0.0"),
+ "#chat-llm-tools": MagicMock(spec=TextArea, text="[]"),
+ "#chat-llm-tool-choice": MagicMock(spec=Input, value=""),
+ "#chat-llm-fixed-tokens-kobold": MagicMock(spec=Checkbox, value=False),
+ "#chat-strip-thinking-tags-checkbox": MagicMock(spec=Checkbox, value=True),
+ "#chat-character-search-results-list": AsyncMock(spec=ListView),
+ "#chat-character-name-edit": MagicMock(spec=Input),
+ "#chat-character-description-edit": MagicMock(spec=TextArea),
+ "#chat-character-personality-edit": MagicMock(spec=TextArea),
+ "#chat-character-scenario-edit": MagicMock(spec=TextArea),
+ "#chat-character-system-prompt-edit": MagicMock(spec=TextArea),
+ "#chat-character-first-message-edit": MagicMock(spec=TextArea),
+ "#chat-right-sidebar": MagicMock(), # Mock container
+ }
+
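+ # Dispatch query_one(selector) calls to the matching mock above; unknown selectors raise QueryError.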
+ def query_one_side_effect(selector, _type=None):
+ # Special case for querying within the sidebar
+ if isinstance(selector, MagicMock) and hasattr(selector, 'query_one'):
+ return selector.query_one(selector, _type)
+
+ if selector in widgets:
+ return widgets[selector]
+
+ # Allow querying for sub-widgets inside a container like the right sidebar
+ if widgets["#chat-right-sidebar"].query_one.call_args:
+ inner_selector = widgets["#chat-right-sidebar"].query_one.call_args[0][0]
+ if inner_selector in widgets:
+ return widgets[inner_selector]
+
+ raise QueryError(f"Widget not found by mock: {selector}")
+
+ q_one_mock.side_effect = query_one_side_effect
+
+ # Make the sidebar mock also use the main query_one logic
+ widgets["#chat-right-sidebar"].query_one.side_effect = lambda sel, _type: widgets[sel]
+
+ setup_mock_widgets(app.query_one)
+
+ return app
+
+
+# Mock external dependencies used in chat_events.py
+@patch('tldw_chatbook.Event_Handlers.Chat_Events.chat_events.ccl')
+@patch('tldw_chatbook.Event_Handlers.Chat_Events.chat_events.os')
+@patch('tldw_chatbook.Event_Handlers.Chat_Events.chat_events.ChatMessage', new_callable=AsyncMock)
+async def test_handle_chat_send_button_pressed_basic(mock_chat_message_class, mock_os, mock_ccl, mock_app):
+ """Test a basic message send operation."""
+ mock_os.environ.get.return_value = "fake-key"
+
+ await handle_chat_send_button_pressed(mock_app, MagicMock())
+
+ # Assert UI updates
+ mock_app.query_one("#chat-input").clear.assert_called_once()
+ mock_app.query_one("#chat-log").mount.assert_any_call(mock_chat_message_class.return_value) # Mounts user message
+ mock_app.query_one("#chat-log").mount.assert_any_call(mock_app.current_ai_message_widget) # Mounts AI placeholder
+
+ # Assert worker is called
+ mock_app.run_worker.assert_called_once()
+
+ # Assert chat_wrapper is called with correct parameters by the worker
+ worker_lambda = mock_app.run_worker.call_args[0][0]
+ worker_lambda() # Execute the lambda to trigger the call to chat_wrapper
+
+ mock_app.chat_wrapper.assert_called_once()
+ wrapper_kwargs = mock_app.chat_wrapper.call_args.kwargs
+ assert wrapper_kwargs['message'] == "User message"
+ assert wrapper_kwargs['api_endpoint'] == "OpenAI"
+ assert wrapper_kwargs['api_key'] == "fake-key"
+ assert wrapper_kwargs['system_message'] == "UI system prompt"
+ assert wrapper_kwargs['streaming'] is True # From config
+
+
+@patch('tldw_chatbook.Event_Handlers.Chat_Events.chat_events.ccl')
+@patch('tldw_chatbook.Event_Handlers.Chat_Events.chat_events.os')
+@patch('tldw_chatbook.Event_Handlers.Chat_Events.chat_events.ChatMessage', new_callable=AsyncMock)
+async def test_handle_chat_send_with_active_character(mock_chat_message_class, mock_os, mock_ccl, mock_app):
+ """Test that an active character's system prompt overrides the UI."""
+ mock_os.environ.get.return_value = "fake-key"
+ mock_app.current_chat_active_character_data = {
+ 'name': 'TestChar',
+ 'system_prompt': 'You are TestChar.'
+ }
+
+ await handle_chat_send_button_pressed(mock_app, MagicMock())
+
+ worker_lambda = mock_app.run_worker.call_args[0][0]
+ worker_lambda()
+
+ wrapper_kwargs = mock_app.chat_wrapper.call_args.kwargs
+ assert wrapper_kwargs['system_message'] == "You are TestChar."
+
+
+async def test_handle_new_conversation_button_pressed(mock_app):
+ """Test that the new chat button clears state and UI."""
+ # Set some state to ensure it's cleared
+ mock_app.current_chat_conversation_id = "conv_123"
+ mock_app.current_chat_is_ephemeral = False
+ mock_app.current_chat_active_character_data = {'name': 'char'}
+
+ await handle_chat_new_conversation_button_pressed(mock_app, MagicMock())
+
+ mock_app.query_one("#chat-log").remove_children.assert_called_once()
+ assert mock_app.current_chat_conversation_id is None
+ assert mock_app.current_chat_is_ephemeral is True
+ assert mock_app.current_chat_active_character_data is None
+ # Check that a UI field was reset
+ assert mock_app.query_one("#chat-system-prompt").text == "Default system prompt."
+
+
+@patch('tldw_chatbook.Event_Handlers.Chat_Events.chat_events.ccl')
+@patch('tldw_chatbook.Event_Handlers.Chat_Events.chat_events.display_conversation_in_chat_tab_ui',
+ new_callable=AsyncMock)
+async def test_handle_save_current_chat_button_pressed(mock_display_conv, mock_ccl, mock_app):
+ """Test saving an ephemeral chat."""
+ mock_app.current_chat_is_ephemeral = True
+ mock_app.current_chat_conversation_id = None
+
+ # Setup mock messages in the chat log
+ mock_msg1 = MagicMock(spec=ChatMessage, role="User", message_text="Hello", generation_complete=True,
+ image_data=None, image_mime_type=None)
+ mock_msg2 = MagicMock(spec=ChatMessage, role="AI", message_text="Hi", generation_complete=True, image_data=None,
+ image_mime_type=None)
+ mock_app.query_one("#chat-log").query.return_value = [mock_msg1, mock_msg2]
+
+ mock_ccl.create_conversation.return_value = "new_conv_id"
+
+ await handle_chat_save_current_chat_button_pressed(mock_app, MagicMock())
+
+ mock_ccl.create_conversation.assert_called_once()
+ create_kwargs = mock_ccl.create_conversation.call_args.kwargs
+ assert create_kwargs['title'].startswith("Chat: Hello...")
+ assert len(create_kwargs['initial_messages']) == 2
+ assert create_kwargs['initial_messages'][0]['content'] == "Hello"
+
+ assert mock_app.current_chat_conversation_id == "new_conv_id"
+ assert mock_app.current_chat_is_ephemeral is False
+ mock_app.notify.assert_called_with("Chat saved successfully!", severity="information")
+ mock_display_conv.assert_called_once_with(mock_app, "new_conv_id")
+
+
+@patch('tldw_chatbook.Event_Handlers.Chat_Events.chat_events.ccl')
+async def test_handle_chat_action_button_pressed_edit_and_save(mock_ccl, mock_app):
+ """Test the edit->save workflow for a chat message."""
+ mock_button = MagicMock(spec=Button, classes=["edit-button"])
+ mock_action_widget = AsyncMock(spec=ChatMessage)
+ mock_action_widget.message_text = "Original text"
+ mock_action_widget.message_id_internal = "msg_123"
+ mock_action_widget.message_version_internal = 0
+ mock_action_widget._editing = False # Start in non-editing mode
+ mock_static_text = mock_action_widget.query_one.return_value
+
+ # --- 1. First press: Start editing ---
+ await handle_chat_action_button_pressed(mock_app, mock_button, mock_action_widget)
+
+ mock_action_widget.mount.assert_called_once() # Mounts the TextArea
+ assert mock_action_widget._editing is True
+ assert "💾" in mock_button.label # Check for save emoji
+
+ # --- 2. Second press: Save edit ---
+ mock_action_widget._editing = True # Simulate being in editing mode
+ mock_edit_area = MagicMock(spec=TextArea, text="New edited text")
+ mock_action_widget.query_one.return_value = mock_edit_area
+ mock_ccl.edit_message_content.return_value = True
+
+ await handle_chat_action_button_pressed(mock_app, mock_button, mock_action_widget)
+
+ mock_edit_area.remove.assert_called_once()
+ assert mock_action_widget.message_text == "New edited text"
+ assert isinstance(mock_static_text.update.call_args[0][0], Text)
+ assert mock_static_text.update.call_args[0][0].plain == "New edited text"
+
+ mock_ccl.edit_message_content.assert_called_with(
+ mock_app.chachanotes_db, "msg_123", "New edited text", 0
+ )
+ assert mock_action_widget.message_version_internal == 1 # Version incremented
+ assert "✏️" in mock_button.label # Check for edit emoji
+
+
+@patch('tldw_chatbook.Event_Handlers.Chat_Events.chat_events.load_character_and_image')
+async def test_handle_chat_load_character_with_greeting(mock_load_char, mock_app):
+ """Test that loading a character into an empty, ephemeral chat posts a greeting."""
+ mock_app.current_chat_is_ephemeral = True
+ mock_app.query_one("#chat-log").query.return_value = [] # Empty chat log
+
+ char_data = {
+ 'id': 'char_abc', 'name': 'Greeter', 'first_message': 'Hello, adventurer!'
+ }
+ mock_load_char.return_value = (char_data, None, None)
+
+ # Mock the list item from the character search list
+ mock_list_item = MagicMock(spec=ListItem)
+ mock_list_item.character_id = 'char_abc'
+ mock_app.query_one("#chat-character-search-results-list").highlighted_child = mock_list_item
+
+ with patch('tldw_chatbook.Event_Handlers.Chat_Events.chat_events.ChatMessage',
+ new_callable=AsyncMock) as mock_chat_msg_class:
+ await handle_chat_load_character_button_pressed(mock_app, MagicMock())
+
+ # Assert character data is loaded
+ assert mock_app.current_chat_active_character_data == char_data
+
+ # Assert greeting message was created and mounted
+ mock_chat_msg_class.assert_called_with(
+ message='Hello, adventurer!',
+ role='Greeter',
+ generation_complete=True
+ )
+ mock_app.query_one("#chat-log").mount.assert_called_once_with(mock_chat_msg_class.return_value)
\ No newline at end of file
diff --git a/Tests/Event_Handlers/Chat_Events/test_chat_events_sidebar.py b/Tests/Event_Handlers/Chat_Events/test_chat_events_sidebar.py
new file mode 100644
index 00000000..6cec77c1
--- /dev/null
+++ b/Tests/Event_Handlers/Chat_Events/test_chat_events_sidebar.py
@@ -0,0 +1,227 @@
+# /tests/Event_Handlers/Chat_Events/test_chat_events_sidebar.py
+
+import pytest
+from unittest.mock import AsyncMock, MagicMock, patch
+
+from textual.widgets import Button, Input, ListView, TextArea, ListItem, Label
+from textual.css.query import QueryError
+
+# Functions to test
+from tldw_chatbook.Event_Handlers.Chat_Events.chat_events_sidebar import (
+ _disable_media_copy_buttons,
+ perform_media_sidebar_search,
+ handle_chat_media_search_input_changed,
+ handle_chat_media_load_selected_button_pressed,
+ handle_chat_media_copy_title_button_pressed,
+ handle_chat_media_copy_content_button_pressed,
+ handle_chat_media_copy_author_button_pressed,
+ handle_chat_media_copy_url_button_pressed,
+)
+
+pytestmark = pytest.mark.asyncio
+
+
+@pytest.fixture
+def mock_app():
+ """Provides a comprehensive mock of the TldwCli app."""
+ app = AsyncMock()
+
+ # Mock UI components
+ app.query_one = MagicMock()
+ mock_results_list = AsyncMock(spec=ListView)
+ mock_review_display = AsyncMock(spec=TextArea)
+ mock_copy_title_btn = MagicMock(spec=Button)
+ mock_copy_content_btn = MagicMock(spec=Button)
+ mock_copy_author_btn = MagicMock(spec=Button)
+ mock_copy_url_btn = MagicMock(spec=Button)
+ mock_search_input = MagicMock(spec=Input)
+
+ # Configure query_one to return the correct mock widget
+ def query_one_side_effect(selector, _type):
+ if selector == "#chat-media-search-results-listview":
+ return mock_results_list
+ if selector == "#chat-media-content-display":
+ return mock_review_display
+ if selector == "#chat-media-copy-title-button":
+ return mock_copy_title_btn
+ if selector == "#chat-media-copy-content-button":
+ return mock_copy_content_btn
+ if selector == "#chat-media-copy-author-button":
+ return mock_copy_author_btn
+ if selector == "#chat-media-copy-url-button":
+ return mock_copy_url_btn
+ if selector == "#chat-media-search-input":
+ return mock_search_input
+ raise QueryError(f"Widget not found: {selector}")
+
+ app.query_one.side_effect = query_one_side_effect
+
+ # Mock DB and state
+ app.media_db = MagicMock()
+ app.current_sidebar_media_item = None
+
+ # Mock app methods
+ app.notify = AsyncMock()
+ app.copy_to_clipboard = MagicMock()
+ app.set_timer = MagicMock()
+ app.run_worker = MagicMock()
+
+ # For debouncing timer
+ app._media_sidebar_search_timer = None
+
+ return app
+
+
+async def test_disable_media_copy_buttons(mock_app):
+ """Test that all copy buttons are disabled and the current item is cleared."""
+ await _disable_media_copy_buttons(mock_app)
+
+ assert mock_app.current_sidebar_media_item is None
+ assert mock_app.query_one("#chat-media-copy-title-button", Button).disabled is True
+ assert mock_app.query_one("#chat-media-copy-content-button", Button).disabled is True
+ assert mock_app.query_one("#chat-media-copy-author-button", Button).disabled is True
+ assert mock_app.query_one("#chat-media-copy-url-button", Button).disabled is True
+
+
+async def test_perform_media_sidebar_search_with_results(mock_app):
+ """Test searching with a term that returns results."""
+ mock_media_items = [
+ {'title': 'Test Title 1', 'media_id': 'id12345678'},
+ {'title': 'Test Title 2', 'media_id': 'id87654321'},
+ ]
+ mock_app.media_db.search_media_db.return_value = mock_media_items
+ mock_results_list = mock_app.query_one("#chat-media-search-results-listview", ListView)
+
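+ # side_effect=ListItem makes the patched class record constructor calls while still returning real ListItem instances.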
+ with patch('tldw_chatbook.Event_Handlers.Chat_Events.chat_events_sidebar.ListItem',
+ side_effect=ListItem) as mock_list_item_class:
+ await perform_media_sidebar_search(mock_app, "test term")
+
+ mock_results_list.clear.assert_called_once()
+ mock_app.query_one("#chat-media-content-display", TextArea).clear.assert_called_once()
+ mock_app.media_db.search_media_db.assert_called_once()
+ assert mock_results_list.append.call_count == 2
+
+ # Check that ListItem was called with a Label containing the correct text
+ first_call_args = mock_list_item_class.call_args_list[0].args
+ assert isinstance(first_call_args[0], Label)
+ assert "Test Title 1" in first_call_args[0].renderable
+
+
+async def test_perform_media_sidebar_search_no_results(mock_app):
+ """Test searching with a term that returns no results."""
+ mock_app.media_db.search_media_db.return_value = []
+ mock_results_list = mock_app.query_one("#chat-media-search-results-listview", ListView)
+
+ await perform_media_sidebar_search(mock_app, "no results term")
+
+ mock_results_list.append.assert_called_once()
+ # The call argument is a ListItem, which contains a Label. We check the Label's content.
+ call_arg = mock_results_list.append.call_args[0][0]
+ assert isinstance(call_arg, ListItem)
+ assert call_arg.children[0].renderable == "No media found."
+
+
+async def test_perform_media_sidebar_search_empty_term(mock_app):
+ """Test that an empty search term clears results and does not search."""
+ await perform_media_sidebar_search(mock_app, "")
+ mock_app.media_db.search_media_db.assert_not_called()
+ mock_app.query_one("#chat-media-search-results-listview", ListView).clear.assert_called_once()
+
+
+async def test_handle_chat_media_search_input_changed_debouncing(mock_app):
+ """Test that input changes are debounced via set_timer."""
+ mock_timer = MagicMock()
+ mock_app._media_sidebar_search_timer = mock_timer
+ mock_input_widget = MagicMock(spec=Input, value=" new search ")
+
+ await handle_chat_media_search_input_changed(mock_app, mock_input_widget)
+
+ mock_timer.stop.assert_called_once()
+ mock_app.set_timer.assert_called_once()
+ # The debounced callback should call perform_media_sidebar_search (scheduled through run_worker).
+ callback_lambda = mock_app.set_timer.call_args[0][1]
+ # We can't easily execute the lambda here, but we can verify it's set.
+ assert callable(callback_lambda)
+
+
+async def test_handle_chat_media_load_selected_button_pressed(mock_app):
+ """Test loading a selected media item into the display."""
+ media_data = {
+ 'title': 'Loaded Title', 'author': 'Author Name', 'media_type': 'Article',
+ 'url': 'http://example.com', 'content': 'This is the full content.'
+ }
+ mock_list_item = MagicMock(spec=ListItem)
+ mock_list_item.media_data = media_data
+
+ mock_results_list = mock_app.query_one("#chat-media-search-results-listview", ListView)
+ mock_results_list.highlighted_child = mock_list_item
+
+ await handle_chat_media_load_selected_button_pressed(mock_app, MagicMock())
+
+ assert mock_app.current_sidebar_media_item == media_data
+ mock_app.query_one("#chat-media-content-display", TextArea).load_text.assert_called_once()
+ loaded_text = mock_app.query_one("#chat-media-content-display", TextArea).load_text.call_args[0][0]
+ assert "Title: Loaded Title" in loaded_text
+ assert "Author: Author Name" in loaded_text
+ assert "This is the full content." in loaded_text
+
+ assert mock_app.query_one("#chat-media-copy-title-button", Button).disabled is False
+
+
+async def test_handle_chat_media_load_selected_no_selection(mock_app):
+ """Test load button when nothing is selected."""
+ mock_results_list = mock_app.query_one("#chat-media-search-results-listview", ListView)
+ mock_results_list.highlighted_child = None
+
+ await handle_chat_media_load_selected_button_pressed(mock_app, MagicMock())
+
+ mock_app.notify.assert_called_with("No media item selected.", severity="warning")
+ mock_app.query_one("#chat-media-content-display", TextArea).clear.assert_called_once()
+ assert mock_app.query_one("#chat-media-copy-title-button", Button).disabled is True
+
+
+async def test_handle_copy_buttons_with_data(mock_app):
+ """Test all copy buttons when data is available."""
+ media_data = {'title': 'Copy Title', 'content': 'Copy Content', 'author': 'Copy Author', 'url': 'http://copy.url'}
+ mock_app.current_sidebar_media_item = media_data
+
+ # Test copy title
+ await handle_chat_media_copy_title_button_pressed(mock_app, MagicMock())
+ mock_app.copy_to_clipboard.assert_called_with('Copy Title')
+ mock_app.notify.assert_called_with("Title copied to clipboard.")
+
+ # Test copy content
+ await handle_chat_media_copy_content_button_pressed(mock_app, MagicMock())
+ mock_app.copy_to_clipboard.assert_called_with('Copy Content')
+ mock_app.notify.assert_called_with("Content copied to clipboard.")
+
+ # Test copy author
+ await handle_chat_media_copy_author_button_pressed(mock_app, MagicMock())
+ mock_app.copy_to_clipboard.assert_called_with('Copy Author')
+ mock_app.notify.assert_called_with("Author copied to clipboard.")
+
+ # Test copy URL
+ await handle_chat_media_copy_url_button_pressed(mock_app, MagicMock())
+ mock_app.copy_to_clipboard.assert_called_with('http://copy.url')
+ mock_app.notify.assert_called_with("URL copied to clipboard.")
+
+
+async def test_handle_copy_buttons_no_data(mock_app):
+ """Test copy buttons when data is not available."""
+ mock_app.current_sidebar_media_item = None
+
+ # Test copy title
+ await handle_chat_media_copy_title_button_pressed(mock_app, MagicMock())
+ mock_app.notify.assert_called_with("No media title to copy.", severity="warning")
+
+ # Test copy content
+ await handle_chat_media_copy_content_button_pressed(mock_app, MagicMock())
+ mock_app.notify.assert_called_with("No media content to copy.", severity="warning")
+
+ # Test copy author
+ await handle_chat_media_copy_author_button_pressed(mock_app, MagicMock())
+ mock_app.notify.assert_called_with("No media author to copy.", severity="warning")
+
+ # Test copy URL
+ await handle_chat_media_copy_url_button_pressed(mock_app, MagicMock())
+ mock_app.notify.assert_called_with("No media URL to copy.", severity="warning")
\ No newline at end of file
diff --git a/Tests/Event_Handlers/Chat_Events/test_chat_streaming_events.py b/Tests/Event_Handlers/Chat_Events/test_chat_streaming_events.py
new file mode 100644
index 00000000..39fe0f46
--- /dev/null
+++ b/Tests/Event_Handlers/Chat_Events/test_chat_streaming_events.py
@@ -0,0 +1,158 @@
+import pytest
+from unittest.mock import AsyncMock, MagicMock, patch
+
+from rich.text import Text
+from textual.containers import VerticalScroll
+from textual.widgets import Static, TextArea
+
+from tldw_chatbook.Event_Handlers.worker_events import StreamingChunk, StreamDone
+from tldw_chatbook.Constants import TAB_CHAT, TAB_CCP
+
+# Functions to test (they are methods on the app, so we test them as such)
+from tldw_chatbook.Event_Handlers.Chat_Events.chat_streaming_events import (
+ handle_streaming_chunk,
+ handle_stream_done
+)
+
+pytestmark = pytest.mark.asyncio
+
+
+@pytest.fixture
+def mock_app():
+ """Provides a mock app instance ('self' for the handlers)."""
+ app = AsyncMock()
+
+ # Mock logger and config
+ app.loguru_logger = MagicMock()
+ app.app_config = {"chat_defaults": {"strip_thinking_tags": True}}
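+ # strip_thinking_tags is enabled here so the tag-stripping path in handle_stream_done is active by default.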
+
+ # Mock UI state and components
+ app.current_tab = TAB_CHAT
+ mock_static_text = AsyncMock(spec=Static)
+ mock_chat_message_widget = AsyncMock()
+ mock_chat_message_widget.is_mounted = True
+ mock_chat_message_widget.message_text = ""
+ mock_chat_message_widget.query_one.return_value = mock_static_text
+ app.current_ai_message_widget = mock_chat_message_widget
+
+ mock_chat_log = AsyncMock(spec=VerticalScroll)
+ mock_chat_input = AsyncMock(spec=TextArea)
+
+ app.query_one = MagicMock(side_effect=lambda sel, type: mock_chat_log if sel == "#chat-log" else mock_chat_input)
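+ # "#chat-log" resolves to the scroll container; any other selector falls back to the chat input mock.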
+
+ # Mock DB and state
+ app.chachanotes_db = MagicMock()
+ app.current_chat_conversation_id = "conv_123"
+ app.current_chat_is_ephemeral = False
+
+ # Mock app methods
+ app.notify = AsyncMock()
+
+ return app
+
+
+async def test_handle_streaming_chunk_appends_text(mock_app):
+ """Test that a streaming chunk appends text and updates the widget."""
+ event = StreamingChunk(text_chunk="Hello, ")
+ mock_app.current_ai_message_widget.message_text = "Initial."
+
+ await handle_streaming_chunk(mock_app, event)
+
+ assert mock_app.current_ai_message_widget.message_text == "Initial.Hello, "
+
+ # Check that update is called with the full, escaped text
+ with patch('tldw_chatbook.Event_Handlers.Chat_Events.chat_streaming_events.escape_markup',
+ return_value="Escaped: Initial.Hello, ") as mock_escape:
+ await handle_streaming_chunk(mock_app, event)
+ mock_escape.assert_called_with("Initial.Hello, Hello, ")
+ mock_app.current_ai_message_widget.query_one().update.assert_called_with("Escaped: Initial.Hello, ")
+
+ # Check that scroll_end is called
+ mock_app.query_one.assert_called_with("#chat-log", VerticalScroll)
+ mock_app.query_one().scroll_end.assert_called()
+
+
+async def test_handle_stream_done_success_and_save(mock_app):
+ """Test successful stream completion and saving to DB."""
+ event = StreamDone(full_text="This is the final response.", error=None)
+ mock_app.current_ai_message_widget.role = "AI"
+
+ with patch('tldw_chatbook.Event_Handlers.Chat_Events.chat_streaming_events.ccl') as mock_ccl:
+ # Mock DB returns for getting the saved message details
+ mock_ccl.add_message_to_conversation.return_value = "msg_abc"
+ mock_app.chachanotes_db.get_message_by_id.return_value = {'id': 'msg_abc', 'version': 0}
+
+ await handle_stream_done(mock_app, event)
+
+ # Assert UI update
+ mock_app.current_ai_message_widget.query_one().update.assert_called_with("This is the final response.")
+ mock_app.current_ai_message_widget.mark_generation_complete.assert_called_once()
+
+ # Assert DB call
+ mock_ccl.add_message_to_conversation.assert_called_once_with(
+ mock_app.chachanotes_db, "conv_123", "AI", "This is the final response."
+ )
+ assert mock_app.current_ai_message_widget.message_id_internal == 'msg_abc'
+ assert mock_app.current_ai_message_widget.message_version_internal == 0
+
+ # Assert state reset
+ assert mock_app.current_ai_message_widget is None
+ mock_app.query_one().focus.assert_called_once()
+
+
+async def test_handle_stream_done_with_tag_stripping(mock_app):
+ """Test that tags are stripped from the final text before saving."""
+ full_text = "I should start.This is the actual response.I am done now."
+ expected_text = "This is the actual response."
+ event = StreamDone(full_text=full_text, error=None)
+ mock_app.app_config["chat_defaults"]["strip_thinking_tags"] = True
+
+ with patch('tldw_chatbook.Event_Handlers.Chat_Events.chat_streaming_events.ccl') as mock_ccl:
+ await handle_stream_done(mock_app, event)
+
+ # Check that the saved text is the stripped version
+ mock_ccl.add_message_to_conversation.assert_called_once()
+ saved_text = mock_ccl.add_message_to_conversation.call_args[0][3]
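+ # The message text is the fourth positional argument (after the DB handle, conversation id, and role).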
+ assert saved_text == expected_text
+
+ # Check that the displayed text is also the stripped version (escaped)
+ mock_app.current_ai_message_widget.query_one().update.assert_called_with(expected_text)
+
+
+async def test_handle_stream_done_with_error(mock_app):
+ """Test stream completion when an error occurred."""
+ event = StreamDone(full_text="Partial response.", error="API limit reached")
+
+ with patch('tldw_chatbook.Event_Handlers.Chat_Events.chat_streaming_events.ccl') as mock_ccl:
+ await handle_stream_done(mock_app, event)
+
+ # Assert UI is updated with error message
+ mock_static_widget = mock_app.current_ai_message_widget.query_one()
+ mock_static_widget.update.assert_called_once()
+ update_call_arg = mock_static_widget.update.call_args[0][0]
+ assert isinstance(update_call_arg, Text)
+ assert "Partial response." in update_call_arg.plain
+ assert "Stream Error" in update_call_arg.plain
+ assert "API limit reached" in update_call_arg.plain
+
+ # Assert role is changed and DB is NOT called
+ assert mock_app.current_ai_message_widget.role == "System"
+ mock_ccl.add_message_to_conversation.assert_not_called()
+
+ # Assert state reset
+ assert mock_app.current_ai_message_widget is None
+
+
+async def test_handle_stream_done_no_widget(mock_app):
+ """Test graceful handling when the AI widget is missing."""
+ mock_app.current_ai_message_widget = None
+ event = StreamDone(full_text="Some text", error="Some error")
+
+ await handle_stream_done(mock_app, event)
+
+ # Just ensure it doesn't crash and notifies about the error
+ mock_app.notify.assert_called_once_with(
+ "Stream error (display widget missing): Some error",
+ severity="error",
+ timeout=10
+ )
\ No newline at end of file
diff --git a/Tests/Sync/__init__.py b/Tests/Event_Handlers/__init__.py
similarity index 100%
rename from Tests/Sync/__init__.py
rename to Tests/Event_Handlers/__init__.py
diff --git a/Tests/LLM_Management/__init__.py b/Tests/LLM_Management/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/Tests/MediaDB2/conftest.py b/Tests/MediaDB2/conftest.py
deleted file mode 100644
index f6e4c1dc..00000000
--- a/Tests/MediaDB2/conftest.py
+++ /dev/null
@@ -1,82 +0,0 @@
-# tests/conftest.py
-import pytest
-import tempfile
-import os
-import shutil
-from pathlib import Path
-import sys
-
-# Add src directory to sys.path to allow importing library/engine code
-src_path = Path(__file__).parent.parent / "src"
-sys.path.insert(0, str(src_path))
-
-# Now import after modifying path
-try:
- from tldw_Server_API.app.core.DB_Management.Media_DB_v2 import MediaDatabase
-except ImportError as e:
- print(f"ERROR in conftest: Could not import Database from sqlite_db. Error: {e}")
- # Define dummy class if import fails to avoid crashing pytest collection
- class Database:
- def __init__(self, *args, **kwargs): pass
- def close_connection(self): pass
- def get_sync_log_entries(self, *args, **kwargs): return []
- def execute_query(self, *args, **kwargs):
- class MockCursor:
- rowcount = 0
- def fetchone(self): return None
- def fetchall(self): return []
- def execute(self, *a, **k): pass
- return MockCursor()
- def transaction(self):
- class MockTransaction:
- def __enter__(self): return None # Return a mock connection/cursor if needed
- def __exit__(self, *args): pass
- return MockTransaction()
-
-
-# --- Database Fixtures ---
-
-@pytest.fixture(scope="function")
-def temp_db_path():
- """Creates a temporary directory and returns a unique DB path within it."""
- temp_dir = tempfile.mkdtemp()
- db_file = Path(temp_dir) / "test_db.sqlite"
- yield str(db_file) # Provide the path to the test function
- # Teardown: remove the directory after the test function finishes
- shutil.rmtree(temp_dir, ignore_errors=True)
-
-@pytest.fixture(scope="function")
-def memory_db_factory():
- """Factory fixture to create in-memory Database instances."""
- created_dbs = []
- def _create_db(client_id="test_client"):
- db = MediaDatabase(db_path=":memory:", client_id=client_id)
- created_dbs.append(db)
- return db
- yield _create_db
- # Teardown: close connections for all created in-memory DBs
- for db in created_dbs:
- try:
- db.close_connection()
- except: # Ignore errors during cleanup
- pass
-
-@pytest.fixture(scope="function")
-def file_db(temp_db_path):
- """Creates a file-based Database instance using a temporary path."""
- db = MediaDatabase(db_path=temp_db_path, client_id="file_client")
- yield db
- db.close_connection() # Ensure connection is closed
-
-# --- Sync Engine State Fixtures ---
-
-@pytest.fixture(scope="function")
-def temp_state_file():
- """Provides a path to a temporary file for sync state."""
- # Use NamedTemporaryFile which handles deletion automatically
- with tempfile.NamedTemporaryFile(delete=False, mode='w', suffix=".json") as tf:
- state_path = tf.name
- yield state_path # Provide the path
- # Teardown: Ensure the file is deleted even if the test fails mid-write
- if os.path.exists(state_path):
- os.remove(state_path)
\ No newline at end of file
diff --git a/Tests/MediaDB2/test_sqlite_db.py b/Tests/MediaDB2/test_sqlite_db.py
deleted file mode 100644
index 73049081..00000000
--- a/Tests/MediaDB2/test_sqlite_db.py
+++ /dev/null
@@ -1,668 +0,0 @@
-# tests/test_sqlite_db.py
-# Description: Unit tests for SQLite database operations, including CRUD, transactions, and sync log management.
-#
-# Imports:
-import json
-import os
-import shutil
-import tempfile
-from pathlib import Path
-
-import pytest
-import time
-import sqlite3
-from datetime import datetime, timezone, timedelta
-
-from tldw_Server_API.app.core.DB_Management.Media_DB_v2 import MediaDatabase, ConflictError, DatabaseError
-
-
-#
-# 3rd-Party Imports:
-#
-# Local imports
-# Import from src using adjusted sys.path in conftest
-#
-#######################################################################################################################
-#
-# Functions:
-
-# Helper to get sync log entries for assertions
-def get_log_count(db: MediaDatabase, entity_uuid: str) -> int:
- cursor = db.execute_query("SELECT COUNT(*) FROM sync_log WHERE entity_uuid = ?", (entity_uuid,))
- return cursor.fetchone()[0]
-
-def get_latest_log(db: MediaDatabase, entity_uuid: str) -> dict | None:
- cursor = db.execute_query(
- "SELECT * FROM sync_log WHERE entity_uuid = ? ORDER BY change_id DESC LIMIT 1",
- (entity_uuid,)
- )
- row = cursor.fetchone()
- return dict(row) if row else None
-
-def get_entity_version(db: MediaDatabase, entity_table: str, uuid: str) -> int | None:
- cursor = db.execute_query(f"SELECT version FROM {entity_table} WHERE uuid = ?", (uuid,))
- row = cursor.fetchone()
- return row['version'] if row else None
-
-class TestDatabaseInitialization:
- def test_memory_db_creation(self, memory_db_factory):
- """Test creating an in-memory database."""
- db = memory_db_factory("client_mem")
- assert db.is_memory_db
- assert db.client_id == "client_mem"
- # Check if a table exists (schema creation check)
- cursor = db.execute_query("SELECT name FROM sqlite_master WHERE type='table' AND name='Media'")
- assert cursor.fetchone() is not None
- db.close_connection()
-
- def test_file_db_creation(self, file_db, temp_db_path):
- """Test creating a file-based database."""
- assert not file_db.is_memory_db
- assert file_db.client_id == "file_client"
- assert os.path.exists(temp_db_path)
- cursor = file_db.execute_query("SELECT name FROM sqlite_master WHERE type='table' AND name='Media'")
- assert cursor.fetchone() is not None
- # file_db fixture handles closure
-
- def test_missing_client_id(self):
- """Test that ValueError is raised if client_id is missing."""
- with pytest.raises(ValueError, match="Client ID cannot be empty"):
- MediaDatabase(db_path=":memory:", client_id="")
- with pytest.raises(ValueError, match="Client ID cannot be empty"):
- MediaDatabase(db_path=":memory:", client_id=None)
-
-def test_schema_versioning_new_file_db(file_db): # Use the file_db fixture
- """Test that a new file DB gets the correct schema version."""
- # Initialization happened in the fixture
- cursor = file_db.execute_query("SELECT version FROM schema_version")
- version = cursor.fetchone()['version']
- assert version == MediaDatabase._CURRENT_SCHEMA_VERSION
-
-class TestDatabaseTransactions:
- def test_transaction_commit(self, memory_db_factory):
- db = memory_db_factory()
- keyword = "commit_test"
- with db.transaction():
- # Use internal method _add_keyword_internal or simplified version for test
- kw_id, kw_uuid = db.add_keyword(keyword) # add_keyword uses transaction internally too, nested is ok
- # Verify outside transaction
- cursor = db.execute_query("SELECT keyword FROM Keywords WHERE id = ?", (kw_id,))
- assert cursor.fetchone()['keyword'] == keyword
-
- def test_transaction_rollback(self, memory_db_factory):
- db = memory_db_factory()
- keyword = "rollback_test"
- initial_count_cursor = db.execute_query("SELECT COUNT(*) FROM Keywords")
- initial_count = initial_count_cursor.fetchone()[0]
- try:
- with db.transaction():
- # Simplified insert for test clarity
- new_uuid = db._generate_uuid()
- db.execute_query(
- "INSERT INTO Keywords (keyword, uuid, last_modified, version, client_id, deleted) VALUES (?, ?, ?, 1, ?, 0)",
- (keyword, new_uuid, db._get_current_utc_timestamp_str(), db.client_id),
- commit=False # Important: commit=False inside transaction block
- )
- # Check *inside* transaction
- cursor_inside = db.execute_query("SELECT COUNT(*) FROM Keywords")
- assert cursor_inside.fetchone()[0] == initial_count + 1
- raise ValueError("Simulating error to trigger rollback") # Force rollback
- except ValueError:
- pass # Expected error
- except Exception as e:
- pytest.fail(f"Unexpected exception during rollback test: {e}")
-
- # Verify outside transaction (count should be back to initial)
- final_count_cursor = db.execute_query("SELECT COUNT(*) FROM Keywords")
- assert final_count_cursor.fetchone()[0] == initial_count
-
-
-class TestDatabaseCRUDAndSync:
-
- @pytest.fixture
- def db_instance(self, memory_db_factory):
- """Provides a fresh in-memory DB for each test in this class."""
- return memory_db_factory("crud_client")
-
- def test_add_keyword(self, db_instance):
- keyword = " test keyword "
- expected_keyword = "test keyword"
- kw_id, kw_uuid = db_instance.add_keyword(keyword)
-
- assert kw_id is not None
- assert kw_uuid is not None
-
- # Verify DB state
- cursor = db_instance.execute_query("SELECT * FROM Keywords WHERE id = ?", (kw_id,))
- row = cursor.fetchone()
- assert row['keyword'] == expected_keyword
- assert row['uuid'] == kw_uuid
- assert row['version'] == 1
- assert row['client_id'] == db_instance.client_id
- assert not row['deleted']
-
- # Verify Sync Log
- log_entry = get_latest_log(db_instance, kw_uuid)
- assert log_entry['operation'] == 'create'
- assert log_entry['entity'] == 'Keywords'
- assert log_entry['version'] == 1
- assert log_entry['client_id'] == db_instance.client_id
- payload = json.loads(log_entry['payload'])
- assert payload['keyword'] == expected_keyword
- assert payload['uuid'] == kw_uuid
-
- def test_add_existing_keyword(self, db_instance):
- keyword = "existing"
- kw_id1, kw_uuid1 = db_instance.add_keyword(keyword)
- log_count1 = get_log_count(db_instance, kw_uuid1)
-
- kw_id2, kw_uuid2 = db_instance.add_keyword(keyword) # Add again
- log_count2 = get_log_count(db_instance, kw_uuid1)
-
- assert kw_id1 == kw_id2
- assert kw_uuid1 == kw_uuid2
- assert log_count1 == log_count2 # No new log entry
-
- def test_soft_delete_keyword(self, db_instance):
- keyword = "to_delete"
- kw_id, kw_uuid = db_instance.add_keyword(keyword)
- initial_version = get_entity_version(db_instance, "Keywords", kw_uuid)
-
- deleted = db_instance.soft_delete_keyword(keyword)
- assert deleted is True
-
- # Verify DB state
- cursor = db_instance.execute_query("SELECT deleted, version FROM Keywords WHERE id = ?", (kw_id,))
- row = cursor.fetchone()
- assert row['deleted'] == 1
- assert row['version'] == initial_version + 1
-
- # Verify Sync Log
- log_entry = get_latest_log(db_instance, kw_uuid)
- assert log_entry['operation'] == 'delete'
- assert log_entry['entity'] == 'Keywords'
- assert log_entry['version'] == initial_version + 1
- payload = json.loads(log_entry['payload'])
- assert payload['uuid'] == kw_uuid # Delete payload is minimal
-
- def test_undelete_keyword(self, db_instance):
- keyword = "to_undelete"
- kw_id, kw_uuid = db_instance.add_keyword(keyword)
- db_instance.soft_delete_keyword(keyword) # Delete it first
- deleted_version = get_entity_version(db_instance, "Keywords", kw_uuid)
-
- # Adding it again should undelete it
- undelete_id, undelete_uuid = db_instance.add_keyword(keyword)
-
- assert undelete_id == kw_id
- assert undelete_uuid == kw_uuid
-
- # Verify DB state
- cursor = db_instance.execute_query("SELECT deleted, version FROM Keywords WHERE id = ?", (kw_id,))
- row = cursor.fetchone()
- assert row['deleted'] == 0
- assert row['version'] == deleted_version + 1
-
- # Verify Sync Log
- log_entry = get_latest_log(db_instance, kw_uuid)
- # Undelete is logged as an 'update'
- assert log_entry['operation'] == 'update'
- assert log_entry['entity'] == 'Keywords'
- assert log_entry['version'] == deleted_version + 1
- payload = json.loads(log_entry['payload'])
- assert payload['uuid'] == kw_uuid
- assert payload['deleted'] == 0 # Payload shows undeleted state
-
- def test_add_media_with_keywords_create(self, db_instance):
- title = "Test Media Create"
- content = "Some unique content for create."
- keywords = ["create_kw1", "create_kw2"]
-
- media_id, media_uuid, msg = db_instance.add_media_with_keywords(
- title=title,
- media_type="article",
- content=content,
- keywords=keywords,
- author="Tester"
- )
-
- assert media_id is not None
- assert media_uuid is not None
- # FIX: Adjust assertion to match actual return message
- assert msg == f"Media '{title}' added."
-
- # Verify DB state (unchanged)
- cursor = db_instance.execute_query("SELECT * FROM Media WHERE id = ?", (media_id,))
- media_row = cursor.fetchone()
- assert media_row['title'] == title
- assert media_row['uuid'] == media_uuid
- assert media_row['version'] == 1 # Initial version
- assert not media_row['deleted']
-
- # Verify Keywords exist (unchanged)
- cursor = db_instance.execute_query("SELECT COUNT(*) FROM Keywords WHERE keyword IN (?, ?)", tuple(keywords))
- assert cursor.fetchone()[0] == 2
-
- # Verify MediaKeywords links (unchanged)
- cursor = db_instance.execute_query("SELECT COUNT(*) FROM MediaKeywords WHERE media_id = ?", (media_id,))
- assert cursor.fetchone()[0] == 2
-
- # Verify DocumentVersion creation (unchanged)
- cursor = db_instance.execute_query("SELECT version_number, content FROM DocumentVersions WHERE media_id = ? ORDER BY version_number DESC LIMIT 1", (media_id,))
- version_row = cursor.fetchone()
- assert version_row['version_number'] == 1
- assert version_row['content'] == content
-
- # Verify Sync Log for Media (Now Python generated)
- log_entry = get_latest_log(db_instance, media_uuid)
- # The *last* log might be DocumentVersion or MediaKeywords link depending on order.
- # Find the Media create log specifically.
- cursor_log = db_instance.execute_query("SELECT * FROM sync_log WHERE entity_uuid = ? AND operation = 'create' AND entity = 'Media'", (media_uuid,))
- log_entry = dict(cursor_log.fetchone())
-
- assert log_entry['operation'] == 'create'
- assert log_entry['entity'] == 'Media'
- assert log_entry['version'] == 1 # Check version
- payload = json.loads(log_entry['payload'])
- assert payload['uuid'] == media_uuid
- assert payload['title'] == title
-
-
- def test_add_media_with_keywords_update(self, db_instance):
- title = "Test Media Update"
- content1 = "Initial content."
- content2 = "Updated content."
- keywords1 = ["update_kw1"]
- keywords2 = ["update_kw2", "update_kw3"]
-
- media_id, media_uuid, _ = db_instance.add_media_with_keywords(
- title=title, media_type="text", content=content1, keywords=keywords1
- )
- initial_version = get_entity_version(db_instance, "Media", media_uuid)
- cursor_check_initial = db_instance.execute_query("SELECT content_hash FROM Media WHERE id = ?", (media_id,))
- initial_hash_row = cursor_check_initial.fetchone()
- assert initial_hash_row is not None
- initial_content_hash = initial_hash_row['content_hash']
-
- # Update 1: Using explicit URL (optional part of test)
- generated_url = f"local://text/{initial_content_hash}"
- media_id_up1, media_uuid_up1, msg1 = db_instance.add_media_with_keywords(
- title=title + " Updated Via URL", media_type="text", content=content2,
- keywords=["url_update_kw"], overwrite=True, url=generated_url
- )
- assert media_id_up1 == media_id
- assert media_uuid_up1 == media_uuid
- # FIX: Adjust assertion
- assert msg1 == f"Media '{title + ' Updated Via URL'}' updated."
- version_after_update1 = get_entity_version(db_instance, "Media", media_uuid)
- assert version_after_update1 == initial_version + 1
-
- # Update 2: Simulate finding by hash (URL=None)
- media_id_up2, media_uuid_up2, msg2 = db_instance.add_media_with_keywords(
- title=title + " Updated Via Hash", media_type="text", content=content2,
- keywords=keywords2, overwrite=True, url=None
- )
- assert media_id_up2 == media_id
- assert media_uuid_up2 == media_uuid
- # FIX: Adjust assertion
- assert msg2 == f"Media '{title + ' Updated Via Hash'}' updated."
-
- # Verify Final State (unchanged checks for DB content)
- cursor = db_instance.execute_query("SELECT title, content, version FROM Media WHERE id = ?", (media_id,))
- media_row = cursor.fetchone()
- assert media_row['title'] == title + " Updated Via Hash"
- assert media_row['content'] == content2
- assert media_row['version'] == version_after_update1 + 1
-
- # Verify Keywords links updated (unchanged)
- cursor = db_instance.execute_query("SELECT k.keyword FROM MediaKeywords mk JOIN Keywords k ON mk.keyword_id = k.id WHERE mk.media_id = ? ORDER BY k.keyword", (media_id,))
- current_keywords = [r['keyword'] for r in cursor.fetchall()]
- assert current_keywords == sorted(keywords2)
-
- # Verify latest DocumentVersion (unchanged)
- cursor = db_instance.execute_query("SELECT version_number, content FROM DocumentVersions WHERE media_id = ? ORDER BY version_number DESC LIMIT 1", (media_id,))
- version_row = cursor.fetchone(); assert version_row['version_number'] == 3; assert version_row['content'] == content2
-
- # Verify Sync Log for the *last* Media update
- log_entry = get_latest_log(db_instance, media_uuid) # Should be the Media update
- assert log_entry['operation'] == 'update'
- assert log_entry['entity'] == 'Media'
- assert log_entry['version'] == version_after_update1 + 1
- payload = json.loads(log_entry['payload'])
- assert payload['title'] == title + " Updated Via Hash"
-
- def test_soft_delete_media_cascade(self, db_instance):
- # 1. Setup complex item
- media_id, media_uuid, _ = db_instance.add_media_with_keywords(
- title="Cascade Test", content="Cascade content", media_type="article",
- keywords=["cascade1", "cascade2"], author="Cascade Author"
- )
- # Add a transcript manually (assuming no direct add_transcript method)
- t_uuid = db_instance._generate_uuid()
- db_instance.execute_query(
- """INSERT INTO Transcripts (media_id, whisper_model, transcription, uuid, last_modified, version, client_id, deleted)
- VALUES (?, ?, ?, ?, ?, 1, ?, 0)""",
- (media_id, "model_xyz", "Transcript text", t_uuid, db_instance._get_current_utc_timestamp_str(), db_instance.client_id),
- commit=True
- )
- # Add a chunk manually
- c_uuid = db_instance._generate_uuid()
- db_instance.execute_query(
- """INSERT INTO MediaChunks (media_id, chunk_text, uuid, last_modified, version, client_id, deleted)
- VALUES (?, ?, ?, ?, 1, ?, 0)""",
- (media_id, "Chunk text", c_uuid, db_instance._get_current_utc_timestamp_str(), db_instance.client_id),
- commit=True
- )
- media_version = get_entity_version(db_instance, "Media", media_uuid)
- transcript_version = get_entity_version(db_instance, "Transcripts", t_uuid)
- chunk_version = get_entity_version(db_instance, "MediaChunks", c_uuid)
-
-
- # 2. Perform soft delete with cascade
- deleted = db_instance.soft_delete_media(media_id, cascade=True)
- assert deleted is True
-
- # 3. Verify parent and children are marked deleted and versioned
- cursor = db_instance.execute_query("SELECT deleted, version FROM Media WHERE id = ?", (media_id,))
- assert dict(cursor.fetchone()) == {'deleted': 1, 'version': media_version + 1}
-
- cursor = db_instance.execute_query("SELECT deleted, version FROM Transcripts WHERE uuid = ?", (t_uuid,))
- assert dict(cursor.fetchone()) == {'deleted': 1, 'version': transcript_version + 1}
-
- cursor = db_instance.execute_query("SELECT deleted, version FROM MediaChunks WHERE uuid = ?", (c_uuid,))
- assert dict(cursor.fetchone()) == {'deleted': 1, 'version': chunk_version + 1}
-
- # 4. Verify keywords are unlinked
- cursor = db_instance.execute_query("SELECT COUNT(*) FROM MediaKeywords WHERE media_id = ?", (media_id,))
- assert cursor.fetchone()[0] == 0
-
- # 5. Verify Sync Logs
- media_log = get_latest_log(db_instance, media_uuid)
- assert media_log['operation'] == 'delete'
- assert media_log['version'] == media_version + 1
-
- transcript_log = get_latest_log(db_instance, t_uuid)
- assert transcript_log['operation'] == 'delete'
- assert transcript_log['version'] == transcript_version + 1
-
- chunk_log = get_latest_log(db_instance, c_uuid)
- assert chunk_log['operation'] == 'delete'
- assert chunk_log['version'] == chunk_version + 1
-
- # Check MediaKeywords unlink logs (tricky to get exact UUIDs, check count)
- cursor = db_instance.execute_query("SELECT COUNT(*) FROM sync_log WHERE entity = 'MediaKeywords' AND operation = 'unlink' AND payload LIKE ?", (f'%{media_uuid}%',))
- assert cursor.fetchone()[0] == 2 # Should be 2 unlink events
-
- def test_optimistic_locking_prevents_update_with_stale_version(self, db_instance):
- """Test that an UPDATE with a stale version number fails (rowcount 0)."""
- keyword = "conflict_test"
- kw_id, kw_uuid = db_instance.add_keyword(keyword)
- original_version = get_entity_version(db_instance, "Keywords", kw_uuid) # Should be 1
- assert original_version == 1, "Initial version should be 1"
-
- # Simulate external update incrementing version
- db_instance.execute_query(
- "UPDATE Keywords SET version = ?, client_id = ? WHERE id = ?",
- (original_version + 1, "external_client", kw_id),
- commit=True
- )
- version_after_external_update = get_entity_version(db_instance, "Keywords", kw_uuid) # Should be 2
- assert version_after_external_update == original_version + 1, "Version after external update should be 2"
-
- # Now, manually attempt an update using the *original stale version* (version=1)
- # This mimics what would happen if a process read version 1, then tried
- # to update after the external process bumped it to version 2.
- current_time = db_instance._get_current_utc_timestamp_str()
- client_id = db_instance.client_id
- cursor = db_instance.execute_query(
- "UPDATE Keywords SET keyword='stale_update', last_modified=?, version=?, client_id=? WHERE id=? AND version=?",
- (current_time, original_version + 1, client_id, kw_id, original_version), # <<< WHERE version = 1 (stale)
- commit=True # Commit needed to actually perform the check
- )
-
- # Assert that the update failed because the WHERE clause (version=1) didn't match any rows
- assert cursor.rowcount == 0, "Update with stale version should affect 0 rows"
-
- # Verify DB state is unchanged by the failed update (still shows external update's state)
- cursor_check = db_instance.execute_query("SELECT keyword, version, client_id FROM Keywords WHERE id = ?",
- (kw_id,))
- row = cursor_check.fetchone()
- assert row is not None, "Keyword should still exist"
- assert row['keyword'] == keyword, "Keyword text should not have changed to 'stale_update'"
- assert row['version'] == original_version + 1, "Version should remain 2 from the external update"
- assert row['client_id'] == "external_client", "Client ID should remain from the external update"
-
- def test_version_validation_trigger(self, db_instance):
- """Test trigger preventing non-sequential version updates."""
- kw_id, kw_uuid = db_instance.add_keyword("validation_test")
- current_version = get_entity_version(db_instance, "Keywords", kw_uuid)
-
- # Try to update version incorrectly (skipping a version)
- with pytest.raises(sqlite3.IntegrityError,
- match=r"Sync Error \(Keywords\): Version must increment by exactly 1"):
- # Provide client_id to prevent the *other* validation trigger firing
- client_id = db_instance.client_id
- db_instance.execute_query(
- "UPDATE Keywords SET version = ?, keyword = ?, client_id = ? WHERE id = ?",
- (current_version + 2, "bad version", client_id, kw_id),
- commit=True
- )
-
- # Try to update version incorrectly (same version)
- with pytest.raises(sqlite3.IntegrityError, match=r"Sync Error \(Keywords\): Version must increment by exactly 1"):
- client_id = db_instance.client_id
- db_instance.execute_query(
- "UPDATE Keywords SET version = ?, keyword = ?, client_id = ? WHERE id = ?",
- (current_version + 2, "bad version", client_id, kw_id),
- commit=True
- )
-
- def test_client_id_validation_trigger(self, db_instance):
- """Test trigger preventing null/empty client_id on update."""
- kw_id, kw_uuid = db_instance.add_keyword("clientid_test")
- current_version = get_entity_version(db_instance, "Keywords", kw_uuid)
-
- # Test the EMPTY STRING case handled by the trigger
- # Use raw string for regex match safety
- with pytest.raises(sqlite3.IntegrityError, match=r"Sync Error \(Keywords\): Client ID cannot be NULL or empty"):
- db_instance.execute_query(
- "UPDATE Keywords SET version = ?, client_id = '' WHERE id = ?",
- (current_version + 1, kw_id),
- commit=True
- )
-
- # Optional: Test the NULL case separately, expecting the NOT NULL constraint error
- # This confirms the underlying table constraint works, though not the trigger message.
- with pytest.raises(sqlite3.IntegrityError, match=r"Sync Error \(Keywords\): Client ID cannot be NULL or empty"):
- db_instance.execute_query(
- "UPDATE Keywords SET version = ?, client_id = NULL WHERE id = ?",
- (current_version + 1, kw_id), # Increment version correctly
- commit=True
- )
-
-
-class TestSyncLogManagement:
-
- @pytest.fixture
- def db_instance(self, memory_db_factory):
- db = memory_db_factory("log_client")
- # Add some initial data to generate logs
- db.add_keyword("log_kw_1")
- time.sleep(0.01) # Ensure timestamp difference
- db.add_keyword("log_kw_2")
- time.sleep(0.01)
- db.add_keyword("log_kw_3")
- db.soft_delete_keyword("log_kw_2")
- return db
-
- def test_get_sync_log_entries_all(self, db_instance):
- logs = db_instance.get_sync_log_entries()
- # Expect 3 creates + 1 delete = 4 entries
- assert len(logs) == 4
- assert logs[0]['change_id'] == 1
- assert logs[-1]['change_id'] == 4
-
- def test_get_sync_log_entries_since(self, db_instance):
- logs = db_instance.get_sync_log_entries(since_change_id=2) # Get 3 and 4
- assert len(logs) == 2
- assert logs[0]['change_id'] == 3
- assert logs[1]['change_id'] == 4
-
- def test_get_sync_log_entries_limit(self, db_instance):
- logs = db_instance.get_sync_log_entries(limit=2) # Get 1 and 2
- assert len(logs) == 2
- assert logs[0]['change_id'] == 1
- assert logs[1]['change_id'] == 2
-
- def test_get_sync_log_entries_since_and_limit(self, db_instance):
- logs = db_instance.get_sync_log_entries(since_change_id=1, limit=2) # Get 2 and 3
- assert len(logs) == 2
- assert logs[0]['change_id'] == 2
- assert logs[1]['change_id'] == 3
-
- def test_delete_sync_log_entries_specific(self, db_instance):
- initial_logs = db_instance.get_sync_log_entries()
- initial_count = len(initial_logs) # Should be 4
- ids_to_delete = [initial_logs[1]['change_id'], initial_logs[2]['change_id']] # Delete 2 and 3
-
- deleted_count = db_instance.delete_sync_log_entries(ids_to_delete)
- assert deleted_count == 2
-
- remaining_logs = db_instance.get_sync_log_entries()
- assert len(remaining_logs) == initial_count - 2
- remaining_ids = {log['change_id'] for log in remaining_logs}
- assert remaining_ids == {initial_logs[0]['change_id'], initial_logs[3]['change_id']} # 1 and 4 should remain
-
- def test_delete_sync_log_entries_before(self, db_instance):
- initial_logs = db_instance.get_sync_log_entries()
- initial_count = len(initial_logs) # Should be 4
- threshold_id = initial_logs[2]['change_id'] # Delete up to and including ID 3
-
- deleted_count = db_instance.delete_sync_log_entries_before(threshold_id)
- assert deleted_count == 3 # Deleted 1, 2, 3
-
- remaining_logs = db_instance.get_sync_log_entries()
- assert len(remaining_logs) == 1
- assert remaining_logs[0]['change_id'] == initial_logs[3]['change_id'] # Only 4 remains
-
- def test_delete_sync_log_entries_empty(self, db_instance):
- deleted_count = db_instance.delete_sync_log_entries([])
- assert deleted_count == 0
-
- def test_delete_sync_log_entries_invalid_id(self, db_instance):
- with pytest.raises(ValueError):
- db_instance.delete_sync_log_entries([1, "two", 3])
-
-
-# Add FTS specific tests
-class TestDatabaseFTS:
- @pytest.fixture
- def db_instance(self, memory_db_factory):
- # Use file DB for FTS tests if memory DB proves unstable
- # return memory_db_factory("fts_client")
- temp_dir = tempfile.mkdtemp()
- db_file = Path(temp_dir) / "fts_test_db.sqlite"
- db = MediaDatabase(db_path=str(db_file), client_id="fts_client")
- yield db
- db.close_connection()
- shutil.rmtree(temp_dir, ignore_errors=True)
-
- def test_fts_media_create_search(self, db_instance):
- """Test searching media via FTS after creation."""
- title = "FTS Test Alpha"
- content = "Unique content string omega gamma beta."
- media_id, media_uuid, _ = db_instance.add_media_with_keywords(title=title, content=content, media_type="fts_test")
-
- # Search by title fragment
- results, total = MediaDatabase.search_media_db(db_instance, search_query="Alpha", search_fields=["title"])
- assert total == 1
- assert len(results) == 1
- assert results[0]['id'] == media_id
-
- # Search by content fragment
- results, total = MediaDatabase.search_media_db(db_instance, search_query="omega", search_fields=["content"])
- assert total == 1
- assert len(results) == 1
- assert results[0]['id'] == media_id
-
- # Search by content phrase
- results, total = MediaDatabase.search_media_db(db_instance, search_query='"omega gamma"', search_fields=["content"])
- assert total == 1
-
- # Search non-existent term
- results, total = MediaDatabase.search_media_db(db_instance, search_query="nonexistent", search_fields=["content", "title"])
- assert total == 0
-
- def test_fts_media_update_search(self, db_instance):
- """Test searching media via FTS after update."""
- title1 = "FTS Update Initial"
- content1 = "Original text epsilon."
- title2 = "FTS Update Final Zeta"
- content2 = "Replacement stuff delta."
-
- media_id, media_uuid, _ = db_instance.add_media_with_keywords(title=title1, content=content1,
- media_type="fts_update")
-
- # Verify initial search works
- results, total = MediaDatabase.search_media_db(db_instance, search_query="epsilon", search_fields=["content"])
- assert total == 1
- initial_url = results[0]['url'] # Get URL for update lookup
-
- # Update the media
- db_instance.add_media_with_keywords(title=title2, content=content2, media_type="fts_update", overwrite=True,
- url=initial_url)
-
- # Search for OLD content - REMOVE this assertion as immediate consistency isn't guaranteed
- # results, total = search_media_db(db_instance, search_query="epsilon", search_fields=["content"])
- # assert total == 0
-
- # Search for NEW content should work
- results, total = MediaDatabase.search_media_db(db_instance, search_query="delta", search_fields=["content"])
- assert total == 1
- assert results[0]['id'] == media_id
-
- # Search for NEW title should work
- results, total = MediaDatabase.search_media_db(db_instance, search_query="Zeta", search_fields=["title"])
- assert total == 1
- assert results[0]['id'] == media_id
-
- def test_fts_media_delete_search(self, db_instance):
- """Test searching media via FTS after soft deletion."""
- title = "FTS To Delete"
- content = "Content will vanish theta."
- media_id, media_uuid, _ = db_instance.add_media_with_keywords(title=title, content=content, media_type="fts_delete")
-
- # Verify initial search works
- results, total = MediaDatabase.search_media_db(db_instance, search_query="theta", search_fields=["content"])
- assert total == 1
-
- # Soft delete the media
- deleted = db_instance.soft_delete_media(media_id)
- assert deleted is True
-
- # Search should now fail
- results, total = MediaDatabase.search_media_db(db_instance, search_query="theta", search_fields=["content"])
- assert total == 0
-
- def test_fts_keyword_search(self, db_instance):
- """Test searching keywords via FTS."""
- kw1_id, kw1_uuid = db_instance.add_keyword("fts_keyword_apple")
- kw2_id, kw2_uuid = db_instance.add_keyword("fts_keyword_banana")
-
- # Search keyword FTS directly (not typically done, but tests population)
- # NOTE: search_media_db doesn't search keyword_fts directly, this is just to test population
- cursor = db_instance.execute_query("SELECT rowid, keyword FROM keyword_fts WHERE keyword_fts MATCH ?", ("apple",))
- fts_results = cursor.fetchall()
- assert len(fts_results) == 1
- assert fts_results[0]['rowid'] == kw1_id
-
- # Soft delete keyword 1
- db_instance.soft_delete_keyword("fts_keyword_apple")
-
- # Search should now fail for apple
- cursor = db_instance.execute_query("SELECT rowid FROM keyword_fts WHERE keyword_fts MATCH ?", ("apple",))
- assert cursor.fetchone() is None
-
- # Search for banana should still work
- cursor = db_instance.execute_query("SELECT rowid FROM keyword_fts WHERE keyword_fts MATCH ?", ("banana",))
- assert cursor.fetchone()['rowid'] == kw2_id
\ No newline at end of file
diff --git a/Tests/Media_DB/__init__.py b/Tests/Media_DB/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/Tests/DB/conftest.py b/Tests/Media_DB/conftest.py
similarity index 100%
rename from Tests/DB/conftest.py
rename to Tests/Media_DB/conftest.py
diff --git a/Tests/Media_DB/test_media_db_properties.py b/Tests/Media_DB/test_media_db_properties.py
new file mode 100644
index 00000000..3fcc3c37
--- /dev/null
+++ b/Tests/Media_DB/test_media_db_properties.py
@@ -0,0 +1,348 @@
+# test_media_db_properties.py
+#
+# Property-based tests for the Media_DB_v2 library using Hypothesis.
+# These tests verify the logical correctness and invariants of the database
+# operations across a wide range of generated data.
+#
+# Imports
+from datetime import datetime, timezone, timedelta
+from typing import Iterator, Callable, Any, Generator
+import pytest
+import uuid
+from pathlib import Path
+#
+# Third-Party Imports
+from hypothesis import given, strategies as st, settings, HealthCheck, assume
+#
+# Local Imports
+# Adjust the import path based on your project structure
+from tldw_chatbook.DB.Client_Media_DB_v2 import (
+ MediaDatabase,
+ InputError,
+ DatabaseError,
+ ConflictError, fetch_keywords_for_media, empty_trash
+)
+#
+#######################################################################################################################
+#
+# --- Hypothesis Settings ---
+
+# A custom profile for database-intensive tests.
+# It increases the deadline and suppresses health checks that are common
+# but expected in I/O-heavy testing scenarios.
+settings.register_profile(
+ "db_test_suite",
+ deadline=2000,
+ suppress_health_check=[
+ HealthCheck.too_slow,
+ HealthCheck.function_scoped_fixture,
+ HealthCheck.data_too_large,
+ ]
+)
+settings.load_profile("db_test_suite")
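+#
+# Individual tests can still override this profile locally; for example, a
+# write-heavy property can drop the deadline entirely, as the idempotency test
+# further below does (sketch only; the test name here is hypothetical):
+#
+#     @settings(deadline=None)
+#     @given(media_data=st_media_data())
+#     def test_write_heavy_property(db_instance, media_data): ...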
+
+
+# --- Pytest Fixtures ---
+
+
+@pytest.fixture
+def db_factory(tmp_path: Path) -> Generator[Callable[[], MediaDatabase], Any, None]:
+ """
+ A factory that creates fresh, isolated MediaDatabase instances on demand.
+ Manages cleanup of all created instances.
+ """
+ created_dbs = []
+
+ def _create_db_instance() -> MediaDatabase:
+ db_file = tmp_path / f"prop_test_{uuid.uuid4().hex}.db"
+ client_id = f"client_{uuid.uuid4().hex[:8]}"
+ db = MediaDatabase(db_path=db_file, client_id=client_id)
+ created_dbs.append(db)
+ return db
+
+ yield _create_db_instance
+
+ # Teardown: close all connections that were created by the factory
+ for db in created_dbs:
+ db.close_connection()
+
+@pytest.fixture
+def db_instance(db_factory: Callable[[], MediaDatabase]) -> MediaDatabase:
+ """
+ Provides a single, fresh MediaDatabase instance for a test function.
+ This fixture uses the `db_factory` to create and manage the instance.
+ """
+ return db_factory()
+
+# --- Hypothesis Strategies ---
+
+# Strategy for generating text that is guaranteed to have non-whitespace content.
+st_required_text = st.text(min_size=1, max_size=50).map(lambda s: s.strip()).filter(lambda s: len(s) > 0)
+
+# Strategy for a single, clean keyword.
+st_keyword_text = st.text(
+ alphabet=st.characters(whitelist_categories=["L", "N", "S", "P"]),
+ min_size=2,
+ max_size=20
+).map(lambda s: s.strip()).filter(lambda s: len(s) > 0)
+
+# Strategy for generating a list of unique, case-insensitive keywords.
+st_keywords_list = st.lists(
+ st_keyword_text,
+ min_size=1,
+ max_size=5,
+ unique_by=lambda s: s.lower()
+).filter(lambda l: len(l) > 0) # Ensure list is not empty after filtering
+
+
+# A composite strategy to generate a valid dictionary of media data for creation.
+@st.composite
+def st_media_data(draw):
+ """Generates a dictionary of plausible data for a new media item."""
+ return {
+ "title": draw(st_required_text),
+ "content": draw(st.text(min_size=10, max_size=500)),
+ "media_type": draw(st.sampled_from(['article', 'video', 'obsidian_note', 'pdf'])),
+ "author": draw(st.one_of(st.none(), st.text(min_size=3, max_size=30))),
+ "keywords": draw(st_keywords_list)
+ }
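+# An example draw from st_media_data() (illustrative values only) might be:
+#   {"title": "Alpha", "content": "some generated text ...", "media_type": "article",
+#    "author": None, "keywords": ["tagA", "tagB"]}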
+
+
+# --- Property Test Classes ---
+
+class TestMediaItemProperties:
+ """Property-based tests for the core Media item lifecycle."""
+
+ @given(media_data=st_media_data())
+ def test_media_item_roundtrip(self, db_instance: MediaDatabase, media_data: dict):
+ """
+ Property: A media item, once added, should be retrievable with the same data.
+ """
+ media_data["content"] += f" {uuid.uuid4().hex}"
+
+ media_id, media_uuid, msg = db_instance.add_media_with_keywords(**media_data)
+
+ assert "added" in msg
+ assert media_id is not None
+ assert media_uuid is not None
+
+ retrieved = db_instance.get_media_by_id(media_id)
+ assert retrieved is not None
+
+ assert retrieved['title'] == media_data['title']
+ assert retrieved['content'] == media_data['content']
+ assert retrieved['type'] == media_data['media_type']
+ assert retrieved['author'] == media_data['author']
+ assert retrieved['version'] == 1
+ assert not retrieved['deleted']
+
+ linked_keywords = {kw.lower().strip() for kw in fetch_keywords_for_media(media_id, db_instance)}
+ expected_keywords = {kw.lower().strip() for kw in media_data['keywords']}
+ assert linked_keywords == expected_keywords
+
+        # get_all_document_versions defaults to excluding content, so request it
+        # explicitly for the content assertion below to work.
+ doc_versions = db_instance.get_all_document_versions(media_id, include_content=True)
+ assert len(doc_versions) == 1
+ assert doc_versions[0]['version_number'] == 1
+ assert doc_versions[0]['content'] == media_data['content']
+
+ @given(initial_media=st_media_data(), update_media=st_media_data())
+ def test_update_increments_version_and_changes_data(self, db_instance: MediaDatabase, initial_media: dict,
+ update_media: dict):
+ initial_media["content"] += f" initial_{uuid.uuid4().hex}"
+ update_media["content"] += f" update_{uuid.uuid4().hex}"
+ media_id, media_uuid, _ = db_instance.add_media_with_keywords(**initial_media)
+ original = db_instance.get_media_by_id(media_id)
+ media_id_up, media_uuid_up, msg = db_instance.add_media_with_keywords(
+ url=original['url'],
+ overwrite=True,
+ **update_media
+ )
+ assert media_id_up == media_id
+ assert media_uuid_up == media_uuid
+ assert "updated" in msg
+ updated = db_instance.get_media_by_id(media_id)
+ assert updated is not None
+ assert updated['version'] == original['version'] + 1
+ assert updated['title'] == update_media['title']
+ assert updated['content'] == update_media['content']
+ doc_versions = db_instance.get_all_document_versions(media_id)
+ assert len(doc_versions) == 2
+
+ @given(media_data=st_media_data())
+ def test_soft_delete_makes_item_unfindable_by_default(self, db_instance: MediaDatabase, media_data: dict):
+ unique_word = f"hypothesis_{uuid.uuid4().hex}"
+ media_data["content"] = f"{media_data['content']} {unique_word}"
+ media_id, _, _ = db_instance.add_media_with_keywords(**media_data)
+ original = db_instance.get_media_by_id(media_id)
+ assert original is not None
+ db_instance.soft_delete_media(media_id)
+ assert db_instance.get_media_by_id(media_id) is None
+ results, total = db_instance.search_media_db(search_query=unique_word)
+ assert total == 0
+ raw_record = db_instance.get_media_by_id(media_id, include_deleted=True)
+ assert raw_record is not None
+ assert raw_record['deleted'] == 1
+ assert raw_record['version'] == original['version'] + 1
+
+
+class TestSearchProperties:
+ @given(media_data=st_media_data())
+ def test_search_finds_item_by_its_properties(self, db_instance: MediaDatabase, media_data: dict):
+ unique_word = f"hypothesis_{uuid.uuid4().hex}"
+ media_data["content"] = f"{media_data['content']} {unique_word}"
+ media_id, _, _ = db_instance.add_media_with_keywords(**media_data)
+ results, total = db_instance.search_media_db(search_query=unique_word, search_fields=['content'])
+ assert total == 1
+ assert results[0]['id'] == media_id
+ keyword_to_find = media_data["keywords"][0]
+ results, total = db_instance.search_media_db(search_query=None, must_have_keywords=[keyword_to_find],
+ media_ids_filter=[media_id])
+ assert total == 1
+ assert results[0]['id'] == media_id
+ results, total = db_instance.search_media_db(search_query=None, media_types=[media_data['media_type']],
+ media_ids_filter=[media_id])
+ assert total == 1
+ assert results[0]['id'] == media_id
+
+ @given(item1=st_media_data(), item2=st_media_data())
+ def test_search_isolates_results_correctly(self, db_instance: MediaDatabase, item1: dict, item2: dict):
+ item1_kws = set(kw.lower() for kw in item1['keywords'])
+ item2_kws = set(kw.lower() for kw in item2['keywords'])
+ assume(item1_kws.isdisjoint(item2_kws))
+ item1["content"] += f" item1_{uuid.uuid4().hex}"
+ item2["content"] += f" item2_{uuid.uuid4().hex}"
+ id1, _, _ = db_instance.add_media_with_keywords(**item1)
+ id2, _, _ = db_instance.add_media_with_keywords(**item2)
+ keyword_to_find = item1['keywords'][0]
+ results, total = db_instance.search_media_db(search_query=None, must_have_keywords=[keyword_to_find],
+ media_ids_filter=[id1, id2])
+ assert total == 1
+ assert results[0]['id'] == id1
+
+ @given(media_data=st_media_data())
+ def test_soft_deleted_item_is_not_in_fts_search(self, db_instance: MediaDatabase, media_data: dict):
+ unique_term = f"fts_{uuid.uuid4().hex}"
+ media_data['title'] = f"{media_data['title']} {unique_term}"
+ media_data['content'] += f" {uuid.uuid4().hex}"
+ media_id, _, _ = db_instance.add_media_with_keywords(**media_data)
+ results, total = db_instance.search_media_db(search_query=unique_term)
+ assert total == 1
+ was_deleted = db_instance.soft_delete_media(media_id)
+ assert was_deleted is True
+ results, total = db_instance.search_media_db(search_query=unique_term)
+ assert total == 0
+
+
+class TestIdempotencyAndConstraints:
+ """Tests for idempotency of operations and enforcement of DB constraints."""
+
+ @settings(deadline=None)
+ @given(media_data=st_media_data())
+ def test_mark_as_trash_is_idempotent(self, db_instance: MediaDatabase, media_data: dict):
+ """
+ Property: Marking an item as trash multiple times has the same effect as
+ marking it once. The version should only increment on the first call.
+ """
+ media_data["content"] += f" {uuid.uuid4().hex}"
+ media_id, _, _ = db_instance.add_media_with_keywords(**media_data)
+
+ assert db_instance.mark_as_trash(media_id) is True
+ item_v2 = db_instance.get_media_by_id(media_id, include_trash=True)
+ assert item_v2['version'] == 2
+
+ assert db_instance.mark_as_trash(media_id) is False
+ item_still_v2 = db_instance.get_media_by_id(media_id, include_trash=True)
+ assert item_still_v2['version'] == 2
+
+ @given(
+ media1=st_media_data(),
+ media2=st_media_data(),
+ url_part1=st.uuids().map(str),
+ url_part2=st.uuids().map(str),
+ )
+ def test_add_media_with_conflicting_hash_is_handled(self,
+ db_instance: MediaDatabase,
+ media1: dict,
+ media2: dict,
+ url_part1: str,
+ url_part2: str):
+        # Ensure the URLs differ; identical UUID draws are a vanishingly unlikely edge case.
+ assume(url_part1 != url_part2)
+ # Ensure titles are different to test a metadata-only update.
+ assume(media1['title'] != media2['title'])
+
+ # Make content identical to trigger a content hash conflict.
+ media2['content'] = media1['content']
+
+ # Use the deterministic UUIDs from Hypothesis to build the URLs.
+ media1['url'] = f"http://example.com/{url_part1}"
+ media2['url'] = f"http://example.com/{url_part2}"
+
+ id1, _, _ = db_instance.add_media_with_keywords(**media1)
+
+ # 1. Test with overwrite=False. Should fail due to conflict.
+ id2, _, msg2 = db_instance.add_media_with_keywords(**media2, overwrite=False)
+ assert id2 is None
+ assert "already exists. Overwrite not enabled." in msg2
+
+ # 2. Test with overwrite=True. Should update the existing item's metadata.
+ id3, _, msg3 = db_instance.add_media_with_keywords(**media2, overwrite=True)
+ assert id3 == id1
+ assert "updated" in msg3
+
+ # 3. Verify the metadata was actually updated in the database.
+ final_item = db_instance.get_media_by_id(id1)
+ assert final_item is not None
+ assert final_item['title'] == media2['title']
+
+
+class TestTimeBasedAndSearchQueries:
+
+ @given(days=st.integers(min_value=1, max_value=365))
+ def test_empty_trash_respects_time_threshold(self, db_instance: MediaDatabase, days: int):
+ """
+ Property: `empty_trash` should only soft-delete items whose `trash_date`
+ is older than the specified threshold.
+ """
+ media_id, _, _ = db_instance.add_media_with_keywords(
+ title="Trash Test", content=f"...{uuid.uuid4().hex}", media_type="article", keywords=["test"])
+
+ # This call handles versioning correctly, bumping version to 2
+ db_instance.mark_as_trash(media_id)
+ item_v2 = db_instance.get_media_by_id(media_id, include_trash=True)
+
+ past_date = datetime.now(timezone.utc) - timedelta(days=days + 1)
+
+        # The manual UPDATE must satisfy the database sync triggers: increment the
+        # version and supply a client_id, otherwise this setup step itself would fail.
+ with db_instance.transaction():
+ db_instance.execute_query(
+ "UPDATE Media SET trash_date = ?, version = ?, client_id = ?, last_modified = ? WHERE id = ?",
+ (
+ past_date.strftime('%Y-%m-%dT%H:%M:%S.%f')[:-3] + 'Z',
+ item_v2['version'] + 1, # Manually increment version for this setup step
+ 'test_setup_client',
+ db_instance._get_current_utc_timestamp_str(),
+ media_id
+ )
+ )
+
+ # Now the item is at version 3.
+ # `empty_trash` will find this item and call `soft_delete_media`,
+ # which will correctly read version 3 and update to version 4.
+ processed_count, _ = empty_trash(db_instance=db_instance, days_threshold=days)
+ assert processed_count == 1
+
+ final_item = db_instance.get_media_by_id(media_id, include_trash=True, include_deleted=True)
+ assert final_item['deleted'] == 1
+ assert final_item['version'] == 4 # Initial: 1, Trash: 2, Manual Date Change: 3, Delete: 4
+
+
+#
+# End of test_media_db_properties.py
+#######################################################################################################################
diff --git a/Tests/Media_DB/test_media_db_v2.py b/Tests/Media_DB/test_media_db_v2.py
new file mode 100644
index 00000000..a5fc26ae
--- /dev/null
+++ b/Tests/Media_DB/test_media_db_v2.py
@@ -0,0 +1,460 @@
+# Tests/Media_DB/test_media_db_v2.py
+# Description: Unit tests for SQLite database operations, including CRUD, transactions, and sync log management.
+# This version is self-contained and does not require a conftest.py file.
+#
+# Standard Library Imports:
+import json
+import os
+import pytest
+import shutil
+import sys
+import time
+import sqlite3
+from datetime import datetime, timezone, timedelta
+from pathlib import Path
+#
+# --- Path Setup (Replaces conftest.py logic) ---
+# Add the project root to the Python path to allow importing the library.
+# This file lives in 'Tests/Media_DB/', two levels below the project root.
+try:
+    project_root = Path(__file__).resolve().parents[2]
+ sys.path.insert(0, str(project_root))
+ # If your source code is in a 'src' directory, you might need:
+ # sys.path.insert(0, str(project_root / "src"))
+except (NameError, IndexError):
+ # Fallback for environments where __file__ is not defined
+ pass
+#
+# Local imports (from the main project)
+from tldw_chatbook.DB.Client_Media_DB_v2 import MediaDatabase as Database, ConflictError
+#
+#######################################################################################################################
+#
+# Helper Functions (for use in tests)
+#
+
+def get_log_count(db: Database, entity_uuid: str) -> int:
+ """Helper to get sync log entries for assertions."""
+ cursor = db.execute_query("SELECT COUNT(*) FROM sync_log WHERE entity_uuid = ?", (entity_uuid,))
+ return cursor.fetchone()[0]
+
+
+def get_latest_log(db: Database, entity_uuid: str) -> dict | None:
+ """Helper to get the most recent sync log for an entity."""
+ cursor = db.execute_query(
+ "SELECT * FROM sync_log WHERE entity_uuid = ? ORDER BY change_id DESC LIMIT 1",
+ (entity_uuid,)
+ )
+ row = cursor.fetchone()
+ return dict(row) if row else None
+
+
+def get_entity_version(db: Database, entity_table: str, uuid: str) -> int | None:
+ """Helper to get the current version of an entity."""
+ cursor = db.execute_query(f"SELECT version FROM {entity_table} WHERE uuid = ?", (uuid,))
+ row = cursor.fetchone()
+ return row['version'] if row else None
+
+
+#######################################################################################################################
+#
+# Pytest Fixtures (Moved from conftest.py)
+#
+
+@pytest.fixture(scope="function")
+def memory_db_factory():
+ """Factory fixture to create in-memory Database instances with automatic connection closing."""
+ created_dbs = []
+
+ def _create_db(client_id="test_client"):
+ db = Database(db_path=":memory:", client_id=client_id)
+ created_dbs.append(db)
+ return db
+
+ yield _create_db
+ # Teardown: close connections for all created in-memory DBs
+ for db in created_dbs:
+ try:
+ db.close_connection()
+ except Exception: # Ignore errors during cleanup
+ pass
+
+
+@pytest.fixture(scope="function")
+def temp_db_path(tmp_path: Path) -> str:
+ """Creates a temporary directory and returns a unique DB path string within it."""
+ # The built-in tmp_path fixture handles directory creation and cleanup.
+ return str(tmp_path / "test_db.sqlite")
+
+
+@pytest.fixture(scope="function")
+def file_db(temp_db_path: str):
+ """Creates a file-based Database instance using a temporary path with automatic connection closing."""
+ db = Database(db_path=temp_db_path, client_id="file_client")
+ yield db
+ db.close_connection()
+
+
+@pytest.fixture(scope="function")
+def db_instance(memory_db_factory):
+ """Provides a fresh, isolated in-memory DB for a single test."""
+ return memory_db_factory("crud_client")
+
+
+@pytest.fixture(scope="class")
+def search_db(tmp_path_factory):
+ """Sets up a single DB with predictable data for all search tests in a class."""
+ db_path = tmp_path_factory.mktemp("search_tests") / "search.db"
+ db = Database(db_path, "search_client")
+
+ # Add a predictable set of media items
+ db.add_media_with_keywords(
+ title="Alpha One", content="Content about Python and programming.", media_type="article",
+ keywords=["python", "programming"], ingestion_date="2023-01-15T12:00:00Z"
+ ) # ID 1
+ db.add_media_with_keywords(
+ title="Beta Two", content="A video about data science.", media_type="video",
+ keywords=["python", "data science"], ingestion_date="2023-02-20T12:00:00Z"
+ ) # ID 2
+ db.add_media_with_keywords(
+ title="Gamma Three (TRASH)", content="Old news.", media_type="article",
+ keywords=["news"], ingestion_date="2023-03-10T12:00:00Z"
+ ) # ID 3
+ db.mark_as_trash(3)
+
+ yield db
+ db.close_connection()
+
+
+#######################################################################################################################
+#
+# Test Classes
+#
+
+class TestDatabaseInitialization:
+ def test_memory_db_creation(self, memory_db_factory):
+ """Test creating an in-memory database."""
+ db = memory_db_factory("client_mem")
+ assert db.is_memory_db
+ assert db.client_id == "client_mem"
+ cursor = db.execute_query("SELECT name FROM sqlite_master WHERE type='table' AND name='Media'")
+ assert cursor.fetchone() is not None
+ db.close_connection()
+
+ def test_file_db_creation(self, file_db, temp_db_path):
+ """Test creating a file-based database."""
+ assert not file_db.is_memory_db
+ assert file_db.client_id == "file_client"
+ assert os.path.exists(temp_db_path)
+ cursor = file_db.execute_query("SELECT name FROM sqlite_master WHERE type='table' AND name='Media'")
+ assert cursor.fetchone() is not None
+ # file_db fixture handles closure
+
+ def test_missing_client_id(self):
+ """Test that ValueError is raised if client_id is missing."""
+ with pytest.raises(ValueError, match="Client ID cannot be empty"):
+ Database(db_path=":memory:", client_id="")
+ with pytest.raises(ValueError, match="Client ID cannot be empty"):
+ Database(db_path=":memory:", client_id=None)
+
+
+class TestDatabaseTransactions:
+ def test_transaction_commit(self, memory_db_factory):
+ db = memory_db_factory()
+ keyword = "commit_test"
+ with db.transaction():
+ db.add_keyword(keyword)
+ cursor = db.execute_query("SELECT keyword FROM Keywords WHERE keyword = ?", (keyword,))
+ assert cursor.fetchone()['keyword'] == keyword
+
+ def test_transaction_rollback(self, memory_db_factory):
+ db = memory_db_factory()
+ keyword = "rollback_test"
+ initial_count_cursor = db.execute_query("SELECT COUNT(*) FROM Keywords")
+ initial_count = initial_count_cursor.fetchone()[0]
+ try:
+ with db.transaction():
+ new_uuid = db._generate_uuid()
+ db.execute_query(
+ "INSERT INTO Keywords (keyword, uuid, last_modified, version, client_id, deleted) VALUES (?, ?, ?, 1, ?, 0)",
+ (keyword, new_uuid, db._get_current_utc_timestamp_str(), db.client_id),
+ commit=False
+ )
+ cursor_inside = db.execute_query("SELECT COUNT(*) FROM Keywords")
+ assert cursor_inside.fetchone()[0] == initial_count + 1
+ raise ValueError("Simulating error to trigger rollback")
+ except ValueError:
+ pass # Expected error
+ except Exception as e:
+ pytest.fail(f"Unexpected exception during rollback test: {e}")
+
+ final_count_cursor = db.execute_query("SELECT COUNT(*) FROM Keywords")
+ assert final_count_cursor.fetchone()[0] == initial_count
+
+
+class TestSearchFunctionality:
+ # The 'search_db' fixture is now defined at the module level
+ # and provides a shared database for all tests in this class.
+
+    def test_some_search_function(self, search_db):
+        """A placeholder test demonstrating usage of the search_db fixture."""
+        # Example: search for items tagged "python" (media items 1 and 2 above):
+        # results, total = search_db.search_media_db(search_query=None, must_have_keywords=["python"])
+        # assert total == 2
+        pass  # Add actual search tests here
+
+
+class TestDatabaseCRUDAndSync:
+ # The 'db_instance' fixture is now defined at the module level
+ # and provides a fresh in-memory DB for each test in this class.
+
+ def test_add_keyword(self, db_instance):
+ keyword = " test keyword "
+ expected_keyword = "test keyword"
+ kw_id, kw_uuid = db_instance.add_keyword(keyword)
+
+ assert kw_id is not None
+ assert kw_uuid is not None
+
+ cursor = db_instance.execute_query("SELECT * FROM Keywords WHERE id = ?", (kw_id,))
+ row = cursor.fetchone()
+ assert row['keyword'] == expected_keyword
+ assert row['uuid'] == kw_uuid
+
+ log_entry = get_latest_log(db_instance, kw_uuid)
+ assert log_entry['operation'] == 'create'
+ assert log_entry['entity'] == 'Keywords'
+
+ def test_add_existing_keyword(self, db_instance):
+ keyword = "existing"
+ kw_id1, kw_uuid1 = db_instance.add_keyword(keyword)
+ log_count1 = get_log_count(db_instance, kw_uuid1)
+ kw_id2, kw_uuid2 = db_instance.add_keyword(keyword)
+ log_count2 = get_log_count(db_instance, kw_uuid1)
+
+ assert kw_id1 == kw_id2
+ assert kw_uuid1 == kw_uuid2
+ assert log_count1 == log_count2
+
+ def test_soft_delete_keyword(self, db_instance):
+ keyword = "to_delete"
+ kw_id, kw_uuid = db_instance.add_keyword(keyword)
+ initial_version = get_entity_version(db_instance, "Keywords", kw_uuid)
+
+ assert db_instance.soft_delete_keyword(keyword) is True
+
+ cursor = db_instance.execute_query("SELECT deleted, version FROM Keywords WHERE id = ?", (kw_id,))
+ row = cursor.fetchone()
+ assert row['deleted'] == 1
+ assert row['version'] == initial_version + 1
+
+ log_entry = get_latest_log(db_instance, kw_uuid)
+ assert log_entry['operation'] == 'delete'
+ assert log_entry['version'] == initial_version + 1
+
+ def test_undelete_keyword(self, db_instance):
+ keyword = "to_undelete"
+ kw_id, kw_uuid = db_instance.add_keyword(keyword)
+ db_instance.soft_delete_keyword(keyword)
+ deleted_version = get_entity_version(db_instance, "Keywords", kw_uuid)
+
+ undelete_id, undelete_uuid = db_instance.add_keyword(keyword)
+
+ assert undelete_id == kw_id
+ cursor = db_instance.execute_query("SELECT deleted, version FROM Keywords WHERE id = ?", (kw_id,))
+ row = cursor.fetchone()
+ assert row['deleted'] == 0
+ assert row['version'] == deleted_version + 1
+
+ log_entry = get_latest_log(db_instance, kw_uuid)
+ assert log_entry['operation'] == 'update'
+ assert log_entry['version'] == deleted_version + 1
+
+ def test_add_media_with_keywords_create(self, db_instance):
+ title = "Test Media Create"
+ content = "Some unique content for create."
+ keywords = ["create_kw1", "create_kw2"]
+
+ media_id, media_uuid, msg = db_instance.add_media_with_keywords(
+ title=title, media_type="article", content=content, keywords=keywords
+ )
+ assert media_id is not None
+ assert f"Media '{title}' added." in msg
+
+ cursor = db_instance.execute_query("SELECT uuid, version FROM Media WHERE id = ?", (media_id,))
+ media_row = cursor.fetchone()
+ assert media_row['uuid'] == media_uuid
+ assert media_row['version'] == 1
+
+ log_entry = get_latest_log(db_instance, media_uuid)
+ assert log_entry['operation'] == 'create'
+
+ def test_add_media_with_keywords_update(self, db_instance):
+ """Test updating a media item with new content, title, and keywords."""
+ # Initial media setup
+ title1 = "Test Media Original"
+ title2 = "Test Media Updated"
+ content1 = "Initial content."
+ content2 = "Updated content."
+ keywords1 = ["update_kw1"]
+ keywords2 = ["update_kw2", "update_kw3"]
+ media_type = "article" # Use consistent media_type across tests
+
+ # Create initial media item
+ media_id, media_uuid, msg1 = db_instance.add_media_with_keywords(
+ title=title1, media_type=media_type, content=content1, keywords=keywords1
+ )
+ assert "added" in msg1.lower(), f"Expected 'added' in message, got: {msg1}"
+ initial_version = get_entity_version(db_instance, "Media", media_uuid)
+ assert initial_version == 1, f"Expected initial version 1, got {initial_version}"
+
+ # Fetch the created media to get its URL (stable identifier)
+ created_media = db_instance.get_media_by_id(media_id)
+ assert created_media is not None, "Failed to retrieve created media item"
+ url_to_update = created_media['url']
+
+ # Update the media item
+ updated_id, updated_uuid, msg2 = db_instance.add_media_with_keywords(
+ title=title2,
+ media_type=media_type,
+ content=content2,
+ keywords=keywords2,
+ overwrite=True,
+ url=url_to_update
+ )
+
+ # Verify update operation returned correct values
+ assert updated_id == media_id, "Update returned different media ID"
+ assert updated_uuid == media_uuid, "Update returned different UUID"
+ assert "updated" in msg2.lower(), f"Expected 'updated' in message, got: {msg2}"
+
+ # Verify content was updated
+ cursor = db_instance.execute_query("SELECT content, title, version FROM Media WHERE id = ?", (media_id,))
+ media_row = cursor.fetchone()
+ assert media_row['content'] == content2, "Content was not updated"
+ assert media_row['title'] == title2, "Title was not updated"
+ assert media_row['version'] == initial_version + 1, f"Version not incremented, expected {initial_version + 1}, got {media_row['version']}"
+
+ # Verify keywords were updated
+ cursor = db_instance.execute_query(
+ """
+ SELECT k.keyword FROM MediaKeywords mk
+ JOIN Keywords k ON mk.keyword_id = k.id
+ WHERE mk.media_id = ?
+ """,
+ (media_id,)
+ )
+ linked_keywords = [row['keyword'] for row in cursor.fetchall()]
+ assert set(kw.lower() for kw in linked_keywords) == set(kw.lower() for kw in keywords2), "Keywords were not updated correctly"
+
+ # Verify sync log was created for the update
+ log_entry = get_latest_log(db_instance, media_uuid)
+ assert log_entry['operation'] == 'update', f"Expected 'update' operation, got {log_entry['operation']}"
+ assert log_entry['version'] == initial_version + 1, f"Log version mismatch: {log_entry['version']} vs {initial_version + 1}"
+ assert log_entry['entity'] == 'Media', f"Expected 'Media' entity, got {log_entry['entity']}"
+
+ def test_soft_delete_media_cascade(self, db_instance):
+ media_id, media_uuid, _ = db_instance.add_media_with_keywords(
+ title="Cascade Test", content="Cascade content", keywords=["cascade1"], media_type="article"
+ )
+ media_version = get_entity_version(db_instance, "Media", media_uuid)
+
+ assert db_instance.soft_delete_media(media_id, cascade=True) is True
+
+ cursor = db_instance.execute_query("SELECT deleted, version FROM Media WHERE id = ?", (media_id,))
+ assert dict(cursor.fetchone()) == {'deleted': 1, 'version': media_version + 1}
+
+ cursor = db_instance.execute_query("SELECT COUNT(*) FROM MediaKeywords WHERE media_id = ?", (media_id,))
+ assert cursor.fetchone()[0] == 0
+
+ media_log = get_latest_log(db_instance, media_uuid)
+ assert media_log['operation'] == 'delete'
+
+ def test_optimistic_locking_prevents_update_with_stale_version(self, db_instance):
+ kw_id, kw_uuid = db_instance.add_keyword("conflict_test")
+ original_version = 1
+
+ db_instance.execute_query(
+ "UPDATE Keywords SET version = ?, client_id = ? WHERE id = ?",
+ (original_version + 1, "external_client", kw_id), commit=True
+ )
+
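+        # Simulate a stale writer: the row is now at version 2, but this UPDATE still targets
+        # the old version in its WHERE clause, so it should match zero rows.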
+ cursor = db_instance.execute_query(
+ "UPDATE Keywords SET keyword='stale_update', version=?, client_id=? WHERE id=? AND version=?",
+ (original_version + 1, db_instance.client_id, kw_id, original_version), commit=True
+ )
+ assert cursor.rowcount == 0
+
+ def test_version_validation_trigger(self, db_instance):
+ kw_id, kw_uuid = db_instance.add_keyword("validation_test")
+ current_version = get_entity_version(db_instance, "Keywords", kw_uuid)
+
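+        # The validation trigger only allows the version to advance by exactly 1, so jumping ahead by 2 must fail.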
+ with pytest.raises(sqlite3.IntegrityError, match="Version must increment by exactly 1"):
+ db_instance.execute_query(
+ "UPDATE Keywords SET version = ? WHERE id = ?",
+ (current_version + 2, kw_id), commit=True
+ )
+
+ def test_client_id_validation_trigger(self, db_instance):
+ kw_id, kw_uuid = db_instance.add_keyword("clientid_test")
+ current_version = get_entity_version(db_instance, "Keywords", kw_uuid)
+
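+        # The validation trigger likewise rejects updates that clear client_id.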
+ with pytest.raises(sqlite3.IntegrityError, match="Client ID cannot be NULL or empty"):
+ db_instance.execute_query(
+ "UPDATE Keywords SET version = ?, client_id = NULL WHERE id = ?",
+ (current_version + 1, kw_id), commit=True
+ )
+
+
+class TestSyncLogManagement:
+ @pytest.fixture(autouse=True)
+ def setup_db(self, db_instance):
+ """Use autouse to provide the db_instance to every test in this class."""
+ # Add some initial data to generate logs
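+        # (three keyword creates plus one soft delete -> four sync_log rows, change_ids 1-4)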
+ db_instance.add_keyword("log_kw_1")
+ time.sleep(0.01)
+ db_instance.add_keyword("log_kw_2")
+ time.sleep(0.01)
+ db_instance.add_keyword("log_kw_3")
+ db_instance.soft_delete_keyword("log_kw_2")
+ self.db = db_instance
+
+ def test_get_sync_log_entries_all(self):
+ logs = self.db.get_sync_log_entries()
+ assert len(logs) == 4
+ assert logs[0]['change_id'] == 1
+
+ def test_get_sync_log_entries_since(self):
+ logs = self.db.get_sync_log_entries(since_change_id=2)
+ assert len(logs) == 2
+ assert logs[0]['change_id'] == 3
+
+ def test_get_sync_log_entries_limit(self):
+ logs = self.db.get_sync_log_entries(limit=2)
+ assert len(logs) == 2
+ assert logs[0]['change_id'] == 1
+ assert logs[1]['change_id'] == 2
+
+ def test_delete_sync_log_entries_specific(self):
+ initial_logs = self.db.get_sync_log_entries()
+ ids_to_delete = [initial_logs[1]['change_id'], initial_logs[2]['change_id']]
+ deleted_count = self.db.delete_sync_log_entries(ids_to_delete)
+ assert deleted_count == 2
+ remaining_ids = {log['change_id'] for log in self.db.get_sync_log_entries()}
+ assert remaining_ids == {1, 4}
+
+ def test_delete_sync_log_entries_before(self):
+ deleted_count = self.db.delete_sync_log_entries_before(3)
+ assert deleted_count == 3
+ remaining_logs = self.db.get_sync_log_entries()
+ assert len(remaining_logs) == 1
+ assert remaining_logs[0]['change_id'] == 4
+
+ def test_delete_sync_log_entries_invalid_id(self):
+ with pytest.raises(ValueError):
+ self.db.delete_sync_log_entries([1, "two", 3])
+
+#
+# End of test_media_db_v2.py
+########################################################################################################################
+
diff --git a/Tests/DB/test_sync_client.py b/Tests/Media_DB/test_sync_client.py
similarity index 99%
rename from Tests/DB/test_sync_client.py
rename to Tests/Media_DB/test_sync_client.py
index 4a18a7e5..b0bf5c84 100644
--- a/Tests/DB/test_sync_client.py
+++ b/Tests/Media_DB/test_sync_client.py
@@ -12,8 +12,8 @@
import requests
#
# Local Imports
-from .test_sqlite_db import get_entity_version
-from tldw_cli.tldw_app.DB.Sync_Client import ClientSyncEngine
+from .test_media_db_v2 import get_entity_version
+from tldw_chatbook.DB.Sync_Client import ClientSyncEngine
#
#######################################################################################################################
#
diff --git a/Tests/Prompts_DB/__init__.py b/Tests/Prompts_DB/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/Tests/Prompts_DB/tests_prompts_db.py b/Tests/Prompts_DB/tests_prompts_db.py
new file mode 100644
index 00000000..d69f981d
--- /dev/null
+++ b/Tests/Prompts_DB/tests_prompts_db.py
@@ -0,0 +1,1102 @@
+# tests_prompts_db.py
+import sqlite3
+import unittest
+import uuid
+import os
+import shutil
+import tempfile
+import threading
+from pathlib import Path
+
+# The module to be tested
+from tldw_chatbook.DB.Prompts_DB import (
+ PromptsDatabase,
+ DatabaseError,
+ SchemaError,
+ InputError,
+ ConflictError,
+ add_or_update_prompt,
+ load_prompt_details_for_ui,
+ export_prompt_keywords_to_csv,
+ view_prompt_keywords_markdown,
+ export_prompts_formatted,
+)
+
+
+# --- Test Case Base ---
+class BaseTestCase(unittest.TestCase):
+ """Base class for tests, handles temporary DB setup and teardown."""
+
+ def setUp(self):
+ """Set up a new in-memory database for each test."""
+ self.client_id = "test_client_1"
+ self.db = PromptsDatabase(':memory:', client_id=self.client_id)
+ # For tests requiring a file-based DB
+ self.temp_dir = tempfile.mkdtemp()
+ self.db_path = Path(self.temp_dir) / "test_prompts.db"
+ # ---- Add a list to track all created file DBs ----
+ self.file_db_instances = []
+
+ def tearDown(self):
+ """Close connection and clean up resources."""
+ if hasattr(self, 'db') and self.db:
+ self.db.close_connection()
+ # ---- Close all tracked file DB instances ----
+ for instance in self.file_db_instances:
+ instance.close_connection()
+ # Now it's safe to remove the directory
+ shutil.rmtree(self.temp_dir)
+
+ def _get_file_db(self):
+ """Helper to get a file-based database instance, ensuring it's tracked for cleanup."""
+ # ---- Track the new instance ----
+ instance = PromptsDatabase(self.db_path, client_id=f"test_client_file_{len(self.file_db_instances)}")
+ self.file_db_instances.append(instance)
+ return instance
+
+ def _add_sample_data(self, db_instance):
+ """Helper to populate a given database instance with some initial data."""
+        # Note: all arguments are supplied explicitly to match the add_prompt signature
+ db_instance.add_prompt(
+ name="Recipe Generator",
+ author="ChefAI",
+ details="Creates recipes for cooking",
+ system_prompt="You are a chef.",
+ user_prompt="Give me a recipe for pasta.",
+ keywords=["food", "cooking"],
+ overwrite=True
+ )
+ db_instance.add_prompt(
+ name="Code Explainer",
+ author="DevHelper",
+ details="Explains python code snippets",
+ system_prompt="You are a senior dev.",
+ user_prompt="Explain this python code.",
+ keywords=["code", "python"],
+ overwrite=True
+ )
+ db_instance.add_prompt(
+ name="Poem Writer",
+ author="BardBot",
+ details="Writes poems about nature",
+ system_prompt="You are a poet.",
+ user_prompt="Write a poem about the sea.",
+ keywords=["writing", "poetry"],
+ overwrite=True
+ )
+ # Add a deleted prompt for testing filters
+ pid, _, _ = db_instance.add_prompt(
+ name="Old Prompt",
+ author="Old",
+ details="Old details",
+ overwrite=True
+ )
+ db_instance.soft_delete_prompt(pid)
+
+
+# --- Test Suites ---
+class TestDatabaseInitialization(BaseTestCase):
+
+ def test_init_success_in_memory(self):
+ self.assertIsNotNone(self.db)
+ self.assertIsInstance(self.db, PromptsDatabase)
+ self.assertTrue(self.db.is_memory_db)
+ self.assertEqual(self.db.client_id, self.client_id)
+
+ def test_init_success_file_based(self):
+ file_db = self._get_file_db()
+ self.assertTrue(self.db_path.exists())
+ self.assertFalse(file_db.is_memory_db)
+ conn = file_db.get_connection()
+ # WAL mode is set for file-based dbs
+ cursor = conn.execute("PRAGMA journal_mode;")
+ self.assertEqual(cursor.fetchone()[0].lower(), 'wal')
+
+ def test_init_failure_no_client_id(self):
+ with self.assertRaises(ValueError):
+ PromptsDatabase(':memory:', client_id=None)
+ with self.assertRaises(ValueError):
+ PromptsDatabase(':memory:', client_id="")
+
+ def test_schema_version_check(self):
+ conn = self.db.get_connection()
+ version = conn.execute("SELECT version FROM schema_version").fetchone()['version']
+ self.assertEqual(version, self.db._CURRENT_SCHEMA_VERSION)
+
+ def test_fts_tables_created(self):
+ conn = self.db.get_connection()
+ try:
+ conn.execute("SELECT * FROM prompts_fts LIMIT 1")
+ conn.execute("SELECT * FROM prompt_keywords_fts LIMIT 1")
+ except Exception as e:
+ self.fail(f"FTS tables not created or queryable: {e}")
+
+ def test_thread_safety_connections(self):
+ """Verify that different threads get different connection objects."""
+ connections = {}
+ db_instance = self._get_file_db()
+
+ def get_conn(thread_id):
+ conn = db_instance.get_connection()
+ connections[thread_id] = id(conn)
+ db_instance.close_connection()
+
+ thread1 = threading.Thread(target=get_conn, args=(1,))
+ thread2 = threading.Thread(target=get_conn, args=(2,))
+
+ thread1.start()
+ thread2.start()
+ thread1.join()
+ thread2.join()
+
+ self.assertIn(1, connections)
+ self.assertIn(2, connections)
+ self.assertNotEqual(connections[1], connections[2],
+ "Connections for different threads should be different objects")
+
+
+class TestCrudOperations(BaseTestCase):
+
+ def test_add_keyword(self):
+ kw_id, kw_uuid = self.db.add_keyword(" Test Keyword 1 ")
+ self.assertIsNotNone(kw_id)
+ self.assertIsNotNone(kw_uuid)
+
+ # Verify it was added correctly and normalized
+ kw_data = self.db.get_active_keyword_by_text("test keyword 1")
+ self.assertIsNotNone(kw_data)
+ self.assertEqual(kw_data['keyword'], "test keyword 1")
+ self.assertEqual(kw_data['id'], kw_id)
+
+ # Verify sync log
+ sync_logs = self.db.get_sync_log_entries()
+ self.assertEqual(len(sync_logs), 1)
+ log_entry = sync_logs[0]
+ self.assertEqual(log_entry['entity'], 'PromptKeywordsTable')
+ self.assertEqual(log_entry['entity_uuid'], kw_uuid)
+ self.assertEqual(log_entry['operation'], 'create')
+ self.assertEqual(log_entry['version'], 1)
+
+ # Verify FTS
+ res = self.db.execute_query(
+ "SELECT rowid FROM prompt_keywords_fts WHERE prompt_keywords_fts MATCH ?",
+ ("test",)
+ ).fetchone()
+ self.assertIsNotNone(res)
+ self.assertEqual(res['rowid'], kw_id)
+
+ def test_add_existing_keyword(self):
+ kw_id1, kw_uuid1 = self.db.add_keyword("duplicate")
+ kw_id2, kw_uuid2 = self.db.add_keyword("duplicate")
+ self.assertEqual(kw_id1, kw_id2)
+ self.assertEqual(kw_uuid1, kw_uuid2)
+ sync_logs = self.db.get_sync_log_entries()
+ self.assertEqual(len(sync_logs), 1) # Should only log the creation once
+
+ def test_add_prompt(self):
+ pid, puuid, msg = self.db.add_prompt("My Prompt", "Me", "Details here", keywords=["tag1", "tag2"])
+ self.assertIsNotNone(pid)
+ self.assertIsNotNone(puuid)
+ self.assertIn("added", msg)
+
+ prompt = self.db.get_prompt_by_id(pid)
+ self.assertEqual(prompt['name'], "My Prompt")
+ self.assertEqual(prompt['version'], 1)
+
+ keywords = self.db.fetch_keywords_for_prompt(pid)
+ self.assertIn("tag1", keywords)
+ self.assertIn("tag2", keywords)
+
+ # Check sync logs - 1 for prompt, 2 for keywords, 2 for links
+ sync_logs = self.db.get_sync_log_entries()
+ prompt_create_logs = [l for l in sync_logs if l['entity'] == 'Prompts' and l['operation'] == 'create']
+ self.assertEqual(len(prompt_create_logs), 1)
+ prompt_create_log = prompt_create_logs[0]
+ self.assertEqual(prompt_create_log['entity_uuid'], puuid)
+
+ # Verify FTS
+ res = self.db.execute_query(
+ "SELECT rowid FROM prompts_fts WHERE prompts_fts MATCH ?",
+ ("My Prompt",)
+ ).fetchone()
+ self.assertIsNotNone(res)
+ self.assertEqual(res['rowid'], pid)
+
+ def test_add_prompt_conflict(self):
+ self.db.add_prompt("Conflict Prompt", "Author", "Details")
+ with self.assertRaises(ConflictError):
+ self.db.add_prompt("Conflict Prompt", "Author", "Details", overwrite=False)
+
+ def test_add_prompt_overwrite(self):
+ pid1, _, _ = self.db.add_prompt("Overwrite Me", "Author1", "Details1")
+ pid2, _, msg = self.db.add_prompt("Overwrite Me", "Author2", "Details2", overwrite=True)
+ self.assertEqual(pid1, pid2)
+ self.assertIn("updated", msg)
+
+ prompt = self.db.get_prompt_by_id(pid1)
+ self.assertEqual(prompt['author'], "Author2")
+ self.assertEqual(prompt['version'], 2)
+
+ def test_soft_delete_prompt(self):
+ pid, puuid, _ = self.db.add_prompt("To Be Deleted", "Author", "Details", keywords=["temp"])
+
+ self.assertIsNotNone(self.db.get_prompt_by_id(pid))
+
+ success = self.db.soft_delete_prompt(pid)
+ self.assertTrue(success)
+
+ self.assertIsNone(self.db.get_prompt_by_id(pid))
+
+ deleted_prompt = self.db.get_prompt_by_id(pid, include_deleted=True)
+ self.assertIsNotNone(deleted_prompt)
+ self.assertEqual(deleted_prompt['deleted'], 1)
+ self.assertEqual(deleted_prompt['version'], 2)
+
+ res = self.db.execute_query(
+ "SELECT rowid FROM prompts_fts WHERE prompts_fts MATCH ?",
+ ("To Be Deleted",)
+ ).fetchone()
+ self.assertIsNone(res)
+
+ link_exists = self.db.execute_query("SELECT 1 FROM PromptKeywordLinks WHERE prompt_id=?", (pid,)).fetchone()
+ self.assertIsNone(link_exists)
+
+ sync_logs = self.db.get_sync_log_entries() # Check all logs
+ delete_log = next(l for l in sync_logs if l['entity'] == 'Prompts' and l['operation'] == 'delete')
+ unlink_log = next(l for l in sync_logs if l['entity'] == 'PromptKeywordLinks' and l['operation'] == 'unlink')
+ self.assertEqual(delete_log['entity_uuid'], puuid)
+ self.assertIn(puuid, unlink_log['entity_uuid'])
+
+ def test_soft_delete_keyword(self):
+ kw_id, kw_uuid = self.db.add_keyword("ephemeral")
+ self.db.add_prompt("Test Prompt", "Author", "Some details", keywords=["ephemeral"])
+
+ success = self.db.soft_delete_keyword("ephemeral")
+ self.assertTrue(success)
+
+ self.assertIsNone(self.db.get_active_keyword_by_text("ephemeral"))
+
+ res = self.db.execute_query(
+ "SELECT rowid FROM prompt_keywords_fts WHERE prompt_keywords_fts MATCH ?",
+ ("ephemeral",)
+ ).fetchone()
+ self.assertIsNone(res)
+
+ prompt = self.db.get_prompt_by_name("Test Prompt")
+ keywords = self.db.fetch_keywords_for_prompt(prompt['id'])
+ self.assertNotIn("ephemeral", keywords)
+
+ def test_update_prompt_by_id(self):
+ pid, puuid, _ = self.db.add_prompt("Initial Name", "Author", "Details", keywords=["old_kw"])
+
+ update_data = {"name": "Updated Name", "details": "New details", "keywords": ["new_kw", "another_kw"]}
+
+ updated_uuid, msg = self.db.update_prompt_by_id(pid, update_data)
+ self.assertEqual(updated_uuid, puuid)
+ self.assertIn("updated successfully", msg)
+
+ prompt = self.db.get_prompt_by_id(pid)
+ self.assertEqual(prompt['name'], "Updated Name")
+ self.assertEqual(prompt['details'], "New details")
+ self.assertEqual(prompt['version'], 2)
+
+ keywords = self.db.fetch_keywords_for_prompt(pid)
+ self.assertIn("new_kw", keywords)
+ self.assertIn("another_kw", keywords)
+ self.assertNotIn("old_kw", keywords)
+
+ res = self.db.execute_query(
+ "SELECT rowid FROM prompts_fts WHERE prompts_fts MATCH ?",
+ ("Updated Name",)
+ ).fetchone()
+ self.assertIsNotNone(res)
+ self.assertEqual(res['rowid'], pid)
+
+
+class TestQueryOperations(BaseTestCase):
+
+ def setUp(self):
+ super().setUp()
+ self._add_sample_data(self.db)
+
+ def test_list_prompts(self):
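+        # _add_sample_data creates 3 active prompts plus one soft-deleted prompt, which list_prompts excludes by default.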
+ prompts, total_pages, page, total_items = self.db.list_prompts(page=1, per_page=2)
+ self.assertEqual(len(prompts), 2)
+ self.assertEqual(total_items, 3) # 3 active prompts
+ self.assertEqual(total_pages, 2)
+ self.assertEqual(page, 1)
+
+ def test_list_prompts_include_deleted(self):
+ _, _, _, total_items = self.db.list_prompts(include_deleted=True)
+ self.assertEqual(total_items, 4)
+
+ def test_fetch_prompt_details(self):
+ details = self.db.fetch_prompt_details("Recipe Generator")
+ self.assertIsNotNone(details)
+ self.assertEqual(details['name'], "Recipe Generator")
+ self.assertIn("food", details['keywords'])
+ self.assertIn("cooking", details['keywords'])
+
+ def test_fetch_all_keywords(self):
+ keywords = self.db.fetch_all_keywords()
+ expected = sorted(["food", "cooking", "code", "python", "writing", "poetry"])
+ self.assertEqual(sorted(keywords), expected)
+
+ def test_search_prompts_by_name(self):
+ results, total = self.db.search_prompts("Recipe")
+ self.assertEqual(total, 1)
+ self.assertEqual(results[0]['name'], "Recipe Generator")
+
+ def test_search_prompts_by_details(self):
+ results, total = self.db.search_prompts("python code", search_fields=['details', 'user_prompt'])
+ self.assertEqual(total, 1)
+ self.assertEqual(results[0]['name'], "Code Explainer")
+
+ def test_search_prompts_by_keyword(self):
+ results, total = self.db.search_prompts("poetry", search_fields=['keywords'])
+ self.assertEqual(total, 1)
+ self.assertEqual(results[0]['name'], "Poem Writer")
+
+
+class TestUtilitiesAndAdvancedFeatures(BaseTestCase):
+
+ def test_backup_database(self):
+ file_db = self._get_file_db()
+ file_db.add_prompt("Backup Test", "Tester", "Details")
+
+ backup_path = self.db_path.with_suffix('.backup.db')
+
+ success = file_db.backup_database(str(backup_path))
+ self.assertTrue(success)
+ self.assertTrue(backup_path.exists())
+
+ backup_db = PromptsDatabase(backup_path, client_id="backup_verifier")
+ prompt = backup_db.get_prompt_by_name("Backup Test")
+ self.assertIsNotNone(prompt)
+ self.assertEqual(prompt['author'], "Tester")
+
+ backup_db.close_connection()
+
+ def test_backup_database_same_file_fails(self):
+ file_db = self._get_file_db()
+ # The function catches this ValueError and returns False
+ success = file_db.backup_database(str(self.db_path))
+ self.assertFalse(success)
+
+ def test_transaction_rollback(self):
+ pid, _, _ = self.db.add_prompt("Initial", "Auth", "Det")
+
+ try:
+ with self.db.transaction():
+ self.db.execute_query("UPDATE Prompts SET name = ?, version = version + 1, client_id='t' WHERE id = ?",
+ ("Updated", pid), commit=False)
+ prompt_inside = self.db.get_prompt_by_id(pid)
+ self.assertEqual(prompt_inside['name'], "Updated")
+ raise ValueError("Intentional failure to trigger rollback")
+ except ValueError:
+ pass # Expected exception
+
+ prompt_outside = self.db.get_prompt_by_id(pid)
+ self.assertEqual(prompt_outside['name'], "Initial")
+
+ def test_delete_sync_log_entries(self):
+ self.db.add_keyword("kw1")
+ self.db.add_keyword("kw2")
+
+ logs = self.db.get_sync_log_entries()
+ self.assertEqual(len(logs), 2)
+ log_ids_to_delete = [log['change_id'] for log in logs]
+
+ deleted_count = self.db.delete_sync_log_entries(log_ids_to_delete)
+ self.assertEqual(deleted_count, 2)
+
+ remaining_logs = self.db.get_sync_log_entries()
+ self.assertEqual(len(remaining_logs), 0)
+
+ def test_update_prompt_version_conflict(self):
+ """
+ Tests that the database trigger prevents updates that violate the versioning rule.
+ This is a direct test of the database integrity layer.
+ """
+ db = self.db
+ pid, _, _ = db.add_prompt("Trigger Test", "Author", "Details") # Prompt is now at version 1
+
+ db.update_prompt_by_id(pid, {'details': 'New Details'})
+ prompt = db.get_prompt_by_id(pid)
+ self.assertEqual(prompt['version'], 2, "Version should be 2 after the first update.")
+
+ conn = db.get_connection()
+
+ with self.assertRaises(sqlite3.IntegrityError) as cm:
+ with conn:
+ conn.execute("UPDATE Prompts SET version = 2, client_id='raw' WHERE id = ?", (pid,))
+
+ self.assertIn("Version must increment by exactly 1", str(cm.exception))
+
+ final_prompt = db.get_prompt_by_id(pid)
+ self.assertEqual(final_prompt['version'], 2, "Version should remain 2 after the failed update.")
+
+
+class TestStandaloneFunctions(BaseTestCase):
+ def setUp(self):
+ super().setUp()
+ self._add_sample_data(self.db)
+
+ def test_add_or_update_prompt(self):
+ # Test update
+ pid, _, msg = add_or_update_prompt(self.db, "Recipe Generator", "New Chef", "New details", keywords=["italian"])
+ self.assertIn("updated", msg)
+ prompt = self.db.get_prompt_by_id(pid)
+ self.assertEqual(prompt['author'], "New Chef")
+ keywords = self.db.fetch_keywords_for_prompt(pid)
+ self.assertIn("italian", keywords)
+
+ # Test add
+ pid_new, _, msg_new = add_or_update_prompt(self.db, "New Standalone Prompt", "Tester", "Details")
+ self.assertIn("added", msg_new)
+ self.assertIsNotNone(self.db.get_prompt_by_id(pid_new))
+
+ def test_load_prompt_details_for_ui(self):
+ name, author, details, sys_p, user_p, kws = load_prompt_details_for_ui(self.db, "Code Explainer")
+ self.assertEqual(name, "Code Explainer")
+ self.assertEqual(author, "DevHelper")
+ self.assertEqual(sys_p, "You are a senior dev.")
+ self.assertEqual(kws, "code, python")
+
+ def test_load_prompt_details_not_found(self):
+ result = load_prompt_details_for_ui(self.db, "Non Existent")
+ self.assertEqual(result, ("", "", "", "", "", ""))
+
+ def test_view_prompt_keywords_markdown(self):
+ md_output = view_prompt_keywords_markdown(self.db)
+ self.assertIn("### Current Active Prompt Keywords:", md_output)
+ self.assertIn("- code (1 active prompts)", md_output)
+ self.assertIn("- cooking (1 active prompts)", md_output)
+
+ def test_export_keywords_to_csv(self):
+ file_db = self._get_file_db()
+ add_or_update_prompt(file_db, "Prompt 1", "Auth", "Det", keywords=["a", "b"])
+ add_or_update_prompt(file_db, "Prompt 2", "Auth", "Det", keywords=["b", "c"])
+
+ status, file_path = export_prompt_keywords_to_csv(file_db)
+ self.assertIn("Successfully exported", status)
+ self.assertTrue(os.path.exists(file_path))
+
+ with open(file_path, 'r', encoding='utf-8') as f:
+ content = f.read()
+ self.assertIn("Keyword,Associated Prompts", content)
+ self.assertIn("a,Prompt 1,1", content)
+ self.assertIn("c,Prompt 2,1", content)
+ # Handle potential ordering difference in GROUP_CONCAT
+ self.assertTrue("b,\"Prompt 1,Prompt 2\",2" in content or "b,\"Prompt 2,Prompt 1\",2" in content)
+
+ os.remove(file_path)
+
+ def test_export_prompts_formatted_csv(self):
+ file_db = self._get_file_db()
+ self._add_sample_data(file_db)
+
+ status, file_path = export_prompts_formatted(file_db, export_format='csv')
+ self.assertIn("Successfully exported 3 prompts", status)
+ self.assertTrue(os.path.exists(file_path))
+
+ with open(file_path, 'r', encoding='utf-8') as f:
+ content = f.read()
+ self.assertIn("Name,UUID,Author,Details,System Prompt,User Prompt,Keywords", content)
+ self.assertIn("Recipe Generator", content)
+ # FIX: Keywords are sorted by the fetch method, so 'cooking' comes before 'food'.
+ self.assertIn('"cooking, food"', content)
+
+ os.remove(file_path)
+
+ def test_export_prompts_formatted_markdown(self):
+ file_db = self._get_file_db()
+ self._add_sample_data(file_db)
+
+ status, file_path = export_prompts_formatted(file_db, export_format='markdown')
+ self.assertIn("Successfully exported 3 prompts to Markdown", status)
+ self.assertTrue(os.path.exists(file_path))
+ # A more thorough test could unzip and verify content, but file creation is a good start.
+ os.remove(file_path)
+
+
+class TestDatabaseIntegrityAndSchema(BaseTestCase):
+ """
+ Tests focused on low-level database integrity, schema rules, and triggers.
+ These tests may involve direct SQL execution to bypass library methods.
+ """
+
+ def test_database_version_too_high_raises_error(self):
+ """Ensure initializing a DB with a future schema version fails gracefully."""
+ file_db = self._get_file_db()
+ conn = file_db.get_connection()
+ # Manually set the schema version to a higher number
+ conn.execute("UPDATE schema_version SET version = 99")
+ conn.commit()
+ # FIX: We must close the connection so the file handle is released before the next open attempt
+ file_db.close_connection()
+ # The instance is now closed, remove it from tracking to avoid double-closing
+ self.file_db_instances.remove(file_db)
+
+ # Now, trying to create a new instance pointing to this DB should fail
+ # FIX: The __init__ method wraps SchemaError in a DatabaseError.
+ with self.assertRaisesRegex(DatabaseError, "newer than supported"):
+ PromptsDatabase(self.db_path, client_id="test_client_fail")
+
+ def test_trigger_prevents_bad_version_update(self):
+ """Verify the SQL trigger prevents updates that don't increment version by 1."""
+ pid, _, _ = self.db.add_prompt("Trigger Test", "T", "D") # Version is 1
+ conn = self.db.get_connection()
+
+ with self.assertRaises(sqlite3.IntegrityError) as cm:
+ conn.execute("UPDATE Prompts SET version = 3 WHERE id = ?", (pid,))
+ self.assertIn("Version must increment by exactly 1", str(cm.exception))
+
+ # This should succeed
+ conn.execute("UPDATE Prompts SET version = 2, client_id='raw_sql' WHERE id = ?", (pid,))
+ conn.commit()
+ prompt = self.db.get_prompt_by_id(pid)
+ self.assertEqual(prompt['version'], 2)
+
+ def test_trigger_prevents_uuid_change(self):
+ """Verify the SQL trigger prevents changing a UUID on update."""
+ pid, original_uuid, _ = self.db.add_prompt("UUID Lock", "T", "D")
+ conn = self.db.get_connection()
+ new_uuid = str(uuid.uuid4())
+
+ with self.assertRaises(sqlite3.IntegrityError) as cm:
+ # Try to update the UUID (and correctly increment version)
+ conn.execute("UPDATE Prompts SET uuid = ?, version = 2, client_id='raw' WHERE id = ?", (new_uuid, pid))
+ self.assertIn("UUID cannot be changed", str(cm.exception))
+
+ # Verify UUID is unchanged
+ prompt = self.db.get_prompt_by_id(pid)
+ self.assertEqual(prompt['uuid'], original_uuid)
+
+
+class TestAdvancedBehaviorsAndEdgeCases(BaseTestCase):
+ """
+ Tests for more nuanced behaviors and specific edge cases not covered
+ in standard CRUD operations.
+ """
+
+ def test_reopen_closed_connection(self):
+ """Test that the library can reopen a connection that was explicitly closed."""
+ file_db = self._get_file_db()
+
+ conn1 = file_db.get_connection()
+ self.assertIsNotNone(conn1)
+ conn1.close()
+
+ conn2 = file_db.get_connection()
+ self.assertIsNotNone(conn2)
+ self.assertNotEqual(id(conn1), id(conn2))
+
+ file_db.add_keyword("test_after_reopen")
+ self.assertIsNotNone(file_db.get_active_keyword_by_text("test_after_reopen"))
+
+ def test_undelete_keyword(self):
+ """Test that adding an already soft-deleted keyword undeletes it."""
+ self.db.add_keyword("to be deleted and restored")
+ self.db.soft_delete_keyword("to be deleted and restored")
+ self.assertIsNone(self.db.get_active_keyword_by_text("to be deleted and restored"))
+
+ # Now, add it again
+ kw_id, kw_uuid = self.db.add_keyword("to be deleted and restored")
+
+ # Verify it's active again
+ restored_kw = self.db.get_active_keyword_by_text("to be deleted and restored")
+ self.assertIsNotNone(restored_kw)
+ self.assertEqual(restored_kw['id'], kw_id)
+ self.assertEqual(restored_kw['version'], 3) # 1: create, 2: delete, 3: undelete (update)
+
+ # Check sync log for the 'update' operation
+ sync_logs = self.db.get_sync_log_entries()
+ undelete_log = next(log for log in sync_logs if log['version'] == 3)
+ self.assertEqual(undelete_log['operation'], 'update')
+ self.assertEqual(undelete_log['payload']['deleted'], 0)
+
+ def test_update_prompt_name_to_existing_name_conflict(self):
+ """Ensure updating a prompt's name to another existing prompt's name fails."""
+ self.db.add_prompt("Prompt A", "Author", "Details")
+ pid_b, _, _ = self.db.add_prompt("Prompt B", "Author", "Details")
+
+ with self.assertRaises(ConflictError):
+ self.db.update_prompt_by_id(pid_b, {"name": "Prompt A"})
+
+ def test_nested_transactions(self):
+ pid, _, _ = self.db.add_prompt("Transaction Test", "T", "D")
+
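+        # Raising inside the outer block is expected to roll back both the outer and inner
+        # updates, as the assertions after the except clause verify.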
+ try:
+ with self.db.transaction(): # Outer transaction
+ self.db.execute_query("UPDATE Prompts SET name = 'Outer Update', version=2, client_id='t' WHERE id = ?",
+ (pid,))
+
+ with self.db.transaction(): # Inner transaction
+ self.db.execute_query(
+ "UPDATE Prompts SET author = 'Inner Update', version=3, client_id='t' WHERE id = ?", (pid,))
+
+ prompt_inside = self.db.get_prompt_by_id(pid)
+ self.assertEqual(prompt_inside['author'], 'Inner Update')
+
+ raise ValueError("Force rollback of outer transaction")
+
+ except ValueError:
+ pass
+
+ prompt_outside = self.db.get_prompt_by_id(pid)
+ self.assertEqual(prompt_outside['name'], "Transaction Test")
+ self.assertEqual(prompt_outside['author'], "T")
+ self.assertEqual(prompt_outside['version'], 1)
+
+ def test_soft_delete_prompt_by_name_and_uuid(self):
+ """Test soft deletion using name and UUID identifiers."""
+ _, p1_uuid, _ = self.db.add_prompt("Deletable By Name", "A", "D")
+ p2_id, p2_uuid, _ = self.db.add_prompt("Deletable By UUID", "B", "E")
+
+ # Delete by name
+ self.assertTrue(self.db.soft_delete_prompt("Deletable By Name"))
+ self.assertIsNone(self.db.get_prompt_by_name("Deletable By Name"))
+
+ # Delete by UUID
+ self.assertTrue(self.db.soft_delete_prompt(p2_uuid))
+ self.assertIsNone(self.db.get_prompt_by_id(p2_id))
+
+
+class TestSearchFunctionality(BaseTestCase):
+ """More detailed tests for the search_prompts function."""
+
+ def setUp(self):
+ super().setUp()
+ self._add_sample_data(self.db)
+ self.db.add_prompt("Shared Term Prompt", "Author", "This prompt contains python.", keywords=["generic"])
+ self.db.add_prompt("Another Code Prompt", "DevHelper", "More code things.", keywords=["python"])
+
+ def test_search_with_no_query_returns_all_active(self):
+ """Searching with no query should act like listing all active prompts."""
+ results, total = self.db.search_prompts(None)
+ # 3 from _add_sample_data + 2 added in this setUp = 5 active prompts
+ self.assertEqual(total, 5)
+ self.assertEqual(len(results), 5)
+
+ def test_search_with_no_results(self):
+ """Ensure a search with no matches returns an empty list and zero total."""
+ results, total = self.db.search_prompts("nonexistentxyz")
+ self.assertEqual(total, 0)
+ self.assertEqual(len(results), 0)
+
+ def test_search_pagination(self):
+ """Test if pagination works correctly on search results."""
+ results, total = self.db.search_prompts("python", search_fields=['details', 'keywords'], page=1,
+ results_per_page=2)
+ self.assertEqual(total, 3) # Code Explainer, Shared Term, Another Code
+ self.assertEqual(len(results), 2)
+
+ results_p2, total_p2 = self.db.search_prompts("python", search_fields=['details', 'keywords'], page=2,
+ results_per_page=2)
+ self.assertEqual(total_p2, 3)
+ self.assertEqual(len(results_p2), 1)
+
+ def test_search_across_multiple_fields(self):
+ """Test searching in both details and keywords simultaneously."""
+ results, total = self.db.search_prompts("python", search_fields=['details', 'keywords'])
+ self.assertEqual(total, 3)
+ names = {r['name'] for r in results}
+ self.assertIn("Code Explainer", names)
+ self.assertIn("Shared Term Prompt", names)
+ self.assertIn("Another Code Prompt", names)
+
+ def test_search_with_invalid_fts_syntax_raises_error(self):
+ """Verify that malformed FTS queries raise a DatabaseError."""
+ # An unclosed quote is invalid syntax
+ with self.assertRaises(DatabaseError):
+ self.db.search_prompts('invalid "syntax', search_fields=['name'])
+
+
+class TestStandaloneFunctionExports(BaseTestCase):
+ """Tests for variations in the standalone export functions."""
+
+ def setUp(self):
+ super().setUp()
+ self.file_db = self._get_file_db()
+ self._add_sample_data(self.file_db)
+
+ def test_export_with_no_matching_prompts(self):
+ """Test export when the filter criteria yields no results."""
+ status, file_path = export_prompts_formatted(
+ self.file_db,
+ filter_keywords=["nonexistent_keyword"]
+ )
+ self.assertIn("No prompts found", status)
+ self.assertEqual(file_path, "None")
+
+ def test_export_csv_minimal_columns(self):
+ """Test CSV export with most boolean flags turned off."""
+ status, file_path = export_prompts_formatted(
+ self.file_db,
+ export_format='csv',
+ include_details=False,
+ include_system=False,
+ include_user=False,
+ include_associated_keywords=False
+ )
+ self.assertIn("Successfully exported", status)
+ self.assertTrue(os.path.exists(file_path))
+
+ with open(file_path, 'r', encoding='utf-8') as f:
+ header = f.readline().strip()
+ self.assertEqual(header, "Name,UUID,Author")
+
+ os.remove(file_path)
+
+ def test_export_markdown_with_different_template(self):
+ """Test Markdown export using a non-default template."""
+ status, file_path = export_prompts_formatted(
+ self.file_db,
+ export_format='markdown',
+ markdown_template_name="Detailed Template"
+ )
+ self.assertIn("Successfully exported", status)
+ self.assertTrue(os.path.exists(file_path))
+ os.remove(file_path)
+
+
+class TestConcurrencyAndDataIntegrity(BaseTestCase):
+ """
+ Tests for race conditions, data encoding, and integrity under stress.
+ """
+
+ def test_concurrent_updates_to_same_prompt(self):
+ """
+ Simulate a race condition where two threads update the same prompt.
+ One should succeed, the other should fail with a ConflictError.
+ """
+ file_db = self._get_file_db()
+ pid, _, _ = file_db.add_prompt("Race Condition", "Initial", "Data")
+
+ results = {}
+ barrier = threading.Barrier(2, timeout=5)
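+        # Both workers rendezvous at the barrier so their update_prompt_by_id calls race on the
+        # same version-1 row; exactly one should succeed and the other should report a conflict.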
+
+ def worker(user_id):
+ try:
+ db_instance = PromptsDatabase(self.db_path, client_id=f"worker_{user_id}")
+ barrier.wait()
+ db_instance.update_prompt_by_id(pid, {'details': f'Updated by {user_id}'})
+ results[user_id] = "success"
+ except ConflictError:
+ results[user_id] = "conflict"
+ except DatabaseError as e:
+ if "locked" in str(e).lower() or "conflict" in str(e).lower():
+ results[user_id] = "conflict"
+ else:
+ results[user_id] = e
+ except Exception as e:
+ results[user_id] = e
+ finally:
+ if 'db_instance' in locals():
+ db_instance.close_connection()
+
+ thread1 = threading.Thread(target=worker, args=(1,))
+ thread2 = threading.Thread(target=worker, args=(2,))
+
+ thread1.start()
+ thread2.start()
+ thread1.join()
+ thread2.join()
+
+ self.assertIn("success", results.values())
+ self.assertIn("conflict", results.values())
+
+ final_prompt = file_db.get_prompt_by_id(pid)
+ self.assertEqual(final_prompt['version'], 2)
+ self.assertTrue(final_prompt['details'].startswith("Updated by"))
+
+ def test_unicode_character_support(self):
+ """Ensure all text fields correctly handle Unicode characters."""
+ unicode_name = "こんにちは世界"
+ unicode_author = "Александр"
+ unicode_details = "Testing emoji support 👍 and special characters ç, é, à."
+ unicode_keywords = ["你好", "世界", "prüfung"]
+
+ pid, _, _ = self.db.add_prompt(
+ name=unicode_name,
+ author=unicode_author,
+ details=unicode_details,
+ keywords=unicode_keywords
+ )
+
+ prompt = self.db.fetch_prompt_details(pid)
+ self.assertEqual(prompt['name'], unicode_name)
+ self.assertEqual(prompt['author'], unicode_author)
+ self.assertEqual(prompt['details'], unicode_details)
+ self.assertEqual(sorted(prompt['keywords']), sorted(unicode_keywords))
+
+ results, total = self.db.search_prompts("你好", search_fields=['keywords'])
+ self.assertEqual(total, 1)
+ self.assertEqual(results[0]['name'], unicode_name)
+
+ status, file_path = export_prompts_formatted(self.db, 'csv')
+ self.assertTrue(os.path.exists(file_path))
+ with open(file_path, 'r', encoding='utf-8') as f:
+ content = f.read()
+ self.assertIn(unicode_name, content)
+ self.assertIn(unicode_author, content)
+ self.assertIn("prüfung", content)
+ os.remove(file_path)
+
+ def test_fts_desynchronization_by_direct_sql(self):
+ """
+ Demonstrates that direct SQL updates will de-sync the FTS table.
+ """
+ unique_term = "zzyzx"
+ pid, _, _ = self.db.add_prompt("FTS Sync Test", "Author", f"Details with {unique_term}")
+
+ results, total = self.db.search_prompts(unique_term)
+ self.assertEqual(total, 1)
+
+ conn = self.db.get_connection()
+ conn.execute(
+ "UPDATE Prompts SET details='Details are now different', version=2, client_id='raw' WHERE id=?",
+ (pid,)
+ )
+ conn.commit()
+
+ prompt = self.db.get_prompt_by_id(pid)
+ self.assertEqual(prompt['details'], "Details are now different")
+
+ # FTS search for the *new* term will FAIL because FTS was not updated
+ results_new, total_new = self.db.search_prompts("different")
+ self.assertEqual(total_new, 0)
+
+ # FTS search for the *old* term will SUCCEED because FTS is stale
+ results_old, total_old = self.db.search_prompts(unique_term)
+ self.assertEqual(total_old, 1)
+
+
+# Additional state-transition and edge-case tests
+class TestAdvancedStateTransitions(BaseTestCase):
+ def test_update_soft_deleted_prompt_restores_it(self):
+ pid, _, _ = self.db.add_prompt("To be deleted", "Author", "Old Details")
+ self.db.soft_delete_prompt(pid)
+ self.assertIsNone(self.db.get_prompt_by_id(pid))
+
+ update_data = {"details": "New, Restored Details"}
+ self.db.update_prompt_by_id(pid, update_data)
+
+ restored_prompt = self.db.get_prompt_by_id(pid)
+ self.assertIsNotNone(restored_prompt)
+ self.assertEqual(restored_prompt['deleted'], 0)
+ self.assertEqual(restored_prompt['details'], "New, Restored Details")
+ self.assertEqual(restored_prompt['version'], 3)
+
+ def test_handling_of_corrupted_sync_log_payload(self):
+ self.db.add_keyword("good_payload")
+ log_entry = self.db.get_sync_log_entries()[0]
+ change_id = log_entry['change_id']
+
+ conn = self.db.get_connection()
+ conn.execute(
+ "UPDATE sync_log SET payload = ? WHERE change_id = ?",
+ ("{'bad_json': this_is_not_valid}", change_id)
+ )
+ conn.commit()
+
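+        # The library should surface the unparseable payload as None rather than raising, as asserted below.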
+ all_logs = self.db.get_sync_log_entries()
+ corrupted_log = next(log for log in all_logs if log['change_id'] == change_id)
+ self.assertIsNone(corrupted_log['payload'])
+
+ def test_add_keyword_with_only_whitespace_fails(self):
+ """Test that adding a keyword that is only whitespace raises an InputError."""
+ with self.assertRaises(InputError):
+ self.db.add_keyword(" ")
+
+ # Verify no sync log was created
+ self.assertEqual(len(self.db.get_sync_log_entries()), 0)
+
+ def test_add_prompt_with_only_whitespace_name_fails(self):
+ """Test that adding a prompt with a whitespace-only name raises an InputError."""
+ with self.assertRaises(InputError):
+ self.db.add_prompt(" ", "Author", "Details")
+
+ # Verify no sync log was created
+ self.assertEqual(len(self.db.get_sync_log_entries()), 0)
+
+
+class TestSyncLogManagement(BaseTestCase):
+ """
+ In-depth tests for the sync_log table management and access methods.
+ """
+
+ def test_get_sync_log_with_since_change_id_and_limit(self):
+ """Verify that fetching logs with a starting ID and a limit works correctly."""
+ # Add 5 items, which should generate at least 5 logs
+ self.db.add_keyword("kw1")
+ self.db.add_keyword("kw2")
+ self.db.add_keyword("kw3")
+ self.db.add_keyword("kw4")
+ self.db.add_keyword("kw5")
+
+ all_logs = self.db.get_sync_log_entries()
+ self.assertEqual(len(all_logs), 5)
+
+ # Get logs since change_id 2 (should get 3, 4, 5)
+ logs_since_2 = self.db.get_sync_log_entries(since_change_id=2)
+ self.assertEqual(len(logs_since_2), 3)
+ self.assertEqual(logs_since_2[0]['change_id'], 3)
+
+ # Get logs since change_id 2 with a limit of 1 (should only get 3)
+ logs_limited = self.db.get_sync_log_entries(since_change_id=2, limit=1)
+ self.assertEqual(len(logs_limited), 1)
+ self.assertEqual(logs_limited[0]['change_id'], 3)
+
+ def test_delete_sync_log_with_nonexistent_ids(self):
+ """Ensure deleting a mix of existing and non-existent log IDs works as expected."""
+ self.db.add_keyword("kw1") # change_id 1
+ self.db.add_keyword("kw2") # change_id 2
+
+ # Attempt to delete IDs 1 and 9999 (which doesn't exist)
+ deleted_count = self.db.delete_sync_log_entries([1, 9999])
+
+ self.assertEqual(deleted_count, 1)
+
+ remaining_logs = self.db.get_sync_log_entries()
+ self.assertEqual(len(remaining_logs), 1)
+ self.assertEqual(remaining_logs[0]['change_id'], 2)
+
+
+class TestComplexStateAndInputInteractions(BaseTestCase):
+ """
+ Tests for nuanced interactions between methods and states.
+ """
+
+ def test_add_prompt_with_overwrite_false_on_deleted_prompt(self):
+ """
+ Verify the specific behavior of add_prompt(overwrite=False) when a prompt
+ is soft-deleted. It should not restore it and should return a specific message.
+ """
+ self.db.add_prompt("Deleted but Exists", "Author", "Details")
+ self.db.soft_delete_prompt("Deleted but Exists")
+
+ # Attempt to add it again without overwrite flag
+ pid, puuid, msg = self.db.add_prompt("Deleted but Exists", "New Author", "New Details", overwrite=False)
+
+ self.assertIn("exists but is soft-deleted", msg)
+
+ # Verify it was not restored or updated
+ prompt = self.db.get_prompt_by_name("Deleted but Exists", include_deleted=True)
+ self.assertEqual(prompt['author'], "Author") # Should be the original author
+ self.assertEqual(prompt['deleted'], 1) # Should remain deleted
+
+ def test_soft_delete_nonexistent_item_returns_false(self):
+ """Ensure attempting to delete non-existent items returns False and doesn't error."""
+ result_prompt = self.db.soft_delete_prompt("non-existent prompt")
+ self.assertFalse(result_prompt)
+
+ result_keyword = self.db.soft_delete_keyword("non-existent keyword")
+ self.assertFalse(result_keyword)
+
+ def test_update_keywords_for_prompt_with_empty_list_removes_all(self):
+ """Updating keywords with an empty list should remove all existing keywords."""
+ pid, _, _ = self.db.add_prompt("Keyword Test", "A", "D", keywords=["kw1", "kw2"])
+ self.assertEqual(len(self.db.fetch_keywords_for_prompt(pid)), 2)
+
+ # Update with an empty list
+ self.db.update_keywords_for_prompt(pid, [])
+
+ self.assertEqual(len(self.db.fetch_keywords_for_prompt(pid)), 0)
+
+ # Verify unlink events were logged
+ unlink_logs = [l for l in self.db.get_sync_log_entries() if l['operation'] == 'unlink']
+ self.assertEqual(len(unlink_logs), 2)
+
+ def test_update_keywords_for_prompt_is_idempotent(self):
+ """Running update_keywords with the same list should result in no changes or new logs."""
+ pid, _, _ = self.db.add_prompt("Idempotent Test", "A", "D", keywords=["kw1", "kw2"])
+
+ initial_log_count = len(self.db.get_sync_log_entries())
+
+ # Rerun with the same keywords
+ self.db.update_keywords_for_prompt(pid, ["kw1", "kw2"])
+
+ final_log_count = len(self.db.get_sync_log_entries())
+ self.assertEqual(initial_log_count, final_log_count,
+ "No new sync logs should be created for an idempotent update")
+
+ def test_update_keywords_for_prompt_handles_duplicates_and_whitespace(self):
+ """Ensure keyword lists are properly normalized before processing."""
+ pid, _, _ = self.db.add_prompt("Normalization Test", "A", "D")
+
+ messy_keywords = [" Tag A ", "tag b", "Tag A", " ", "tag c "]
+ self.db.update_keywords_for_prompt(pid, messy_keywords)
+
+ final_keywords = self.db.fetch_keywords_for_prompt(pid)
+ self.assertEqual(sorted(final_keywords), ["tag a", "tag b", "tag c"])
+
+
+class TestBulkOperationsAndScale(BaseTestCase):
+ """
+ Tests focusing on bulk methods and behavior with a larger number of records.
+ """
+
+ def test_execute_many_success(self):
+ """Test successful bulk insertion with execute_many."""
+ keywords_to_add = [
+ (f"bulk_keyword_{i}", str(uuid.uuid4()), self.db._get_current_utc_timestamp_str(), 1, self.client_id, 0)
+ for i in range(50)
+ ]
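+        # Tuple field order matches the column list in the INSERT statement below.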
+ sql = "INSERT INTO PromptKeywordsTable (keyword, uuid, last_modified, version, client_id, deleted) VALUES (?, ?, ?, ?, ?, ?)"
+
+ with self.db.transaction():
+ self.db.execute_many(sql, keywords_to_add)
+
+ count = self.db.execute_query("SELECT COUNT(*) FROM PromptKeywordsTable").fetchone()[0]
+ self.assertEqual(count, 50)
+
+ def test_execute_many_failure_with_integrity_error_rolls_back(self):
+ """Ensure a failing execute_many call within a transaction rolls back the entire batch."""
+ keywords_to_add = [
+ ("unique_1", str(uuid.uuid4()), self.db._get_current_utc_timestamp_str(), 1, self.client_id, 0),
+ ("not_unique", str(uuid.uuid4()), self.db._get_current_utc_timestamp_str(), 1, self.client_id, 0),
+ ("not_unique", str(uuid.uuid4()), self.db._get_current_utc_timestamp_str(), 1, self.client_id, 0),
+ # Fails here
+ ("unique_2", str(uuid.uuid4()), self.db._get_current_utc_timestamp_str(), 1, self.client_id, 0),
+ ]
+ sql = "INSERT INTO PromptKeywordsTable (keyword, uuid, last_modified, version, client_id, deleted) VALUES (?, ?, ?, ?, ?, ?)"
+
+ with self.assertRaises(DatabaseError) as cm:
+ with self.db.transaction():
+ self.db.execute_many(sql, keywords_to_add)
+
+ self.assertIn("UNIQUE constraint failed", str(cm.exception))
+
+ # Verify rollback
+ count = self.db.execute_query("SELECT COUNT(*) FROM PromptKeywordsTable").fetchone()[0]
+ self.assertEqual(count, 0)
+
+ def test_dependency_integrity_on_delete(self):
+ """
+ Verify that deleting a prompt doesn't delete a keyword used by other prompts.
+ """
+ # "common_kw" is used by both prompts
+ p1_id, _, _ = self.db.add_prompt("Prompt 1", "A", "D", keywords=["p1_kw", "common_kw"])
+ self.db.add_prompt("Prompt 2", "B", "E", keywords=["p2_kw", "common_kw"])
+
+ # Soft delete Prompt 1
+ self.db.soft_delete_prompt(p1_id)
+
+ # Check that "common_kw" still exists and is active
+ kw_data = self.db.get_active_keyword_by_text("common_kw")
+ self.assertIsNotNone(kw_data)
+
+ # Check that Prompt 2 still has its link to "common_kw"
+ p2_details = self.db.fetch_prompt_details("Prompt 2")
+ self.assertIn("common_kw", p2_details['keywords'])
+
+
+if __name__ == '__main__':
+ unittest.main(argv=['first-arg-is-ignored'], exit=False)
+
+
+#
+# End of tests_prompts_db.py
+#######################################################################################################################
diff --git a/Tests/Prompts_DB/tests_prompts_db_properties.py b/Tests/Prompts_DB/tests_prompts_db_properties.py
new file mode 100644
index 00000000..ae6ff3b7
--- /dev/null
+++ b/Tests/Prompts_DB/tests_prompts_db_properties.py
@@ -0,0 +1,574 @@
+# tests_prompts_db_properties.py
+#
+# Property-based tests for the Prompts_DB library using Hypothesis.
+
+# Imports
+import uuid
+import pytest
+import json
+from pathlib import Path
+import sqlite3
+import threading
+import time
+
+# Third-Party Imports
+from hypothesis import given, strategies as st, settings, HealthCheck
+from hypothesis.stateful import RuleBasedStateMachine, rule, precondition, Bundle
+
+# Local Imports
+# Prompts_DB lives in the tldw_chatbook package, so it must be importable
+# (e.g., the package installed or available on sys.path).
+from tldw_chatbook.DB.Prompts_DB import (
+ PromptsDatabase,
+ InputError,
+ DatabaseError,
+ ConflictError
+)
+
+########################################################################################################################
+#
+# Hypothesis Setup:
+
+# A custom profile for DB tests to avoid timeouts on complex operations.
+settings.register_profile(
+ "db_friendly",
+ deadline=1500, # Increased deadline for potentially slow DB I/O
+ suppress_health_check=[
+ HealthCheck.too_slow,
+ HealthCheck.function_scoped_fixture
+ ]
+)
+settings.load_profile("db_friendly")
+
+
+# --- Fixtures ---
+
+@pytest.fixture
+def client_id():
+ """Provides a consistent client ID for tests."""
+ return "hypothesis_client"
+
+
+@pytest.fixture
+def db_path(tmp_path):
+ """Provides a temporary path for the database file for each test."""
+ return tmp_path / "prop_test_prompts_db.sqlite"
+
+
+@pytest.fixture(scope="function")
+def db_instance(db_path, client_id):
+ """Creates a fresh PromptsDatabase instance for each test function."""
+ current_db_path = Path(db_path)
+ # Ensure no leftover files from a failed previous run (important for WAL mode)
+ for suffix in ["", "-wal", "-shm"]:
+ p = Path(str(current_db_path) + suffix)
+ if p.exists():
+ p.unlink(missing_ok=True)
+
+ db = PromptsDatabase(current_db_path, client_id)
+ yield db
+ db.close_connection()
+
+
+# --- Hypothesis Strategies ---
+
+# Strategy for text fields that cannot be empty or just whitespace.
+st_required_text = st.text(min_size=1, max_size=100).filter(lambda s: s.strip())
+
+# Strategy for optional text fields.
+st_optional_text = st.one_of(st.none(), st.text(max_size=500))
+
+
+@st.composite
+def st_prompt_data(draw):
+ """A composite strategy to generate a dictionary of prompt data."""
+ # Generate keywords that are unique after normalization
+ keywords = draw(st.lists(st_required_text, max_size=5, unique_by=lambda s: s.strip().lower()))
+
+ return {
+ "name": draw(st_required_text),
+ "author": draw(st_optional_text),
+ "details": draw(st_optional_text),
+ "system_prompt": draw(st_optional_text),
+ "user_prompt": draw(st_optional_text),
+ "keywords": keywords,
+ }
+
+
+# A strategy for a non-one integer to test version validation triggers.
+st_bad_version_offset = st.integers().filter(lambda x: x != 1)
+
+
+# --- Test Classes ---
+
+class TestPromptProperties:
+ """Property-based tests for core Prompt operations."""
+
+ @given(prompt_data=st_prompt_data())
+ def test_prompt_roundtrip(self, db_instance: PromptsDatabase, prompt_data: dict):
+ """
+ Property: If we add a prompt, retrieving it should return the same data,
+ accounting for any normalization (e.g., stripping name, normalizing keywords).
+ """
+ try:
+ # Use overwrite=False to test the create path; ConflictError is a valid outcome.
+ prompt_id, prompt_uuid, _ = db_instance.add_prompt(**prompt_data, overwrite=False)
+ except ConflictError:
+ return # Hypothesis generated a name collision, which is not a failure.
+
+ assert prompt_id is not None
+ assert prompt_uuid is not None
+
+ retrieved_prompt = db_instance.fetch_prompt_details(prompt_id)
+ assert retrieved_prompt is not None
+
+ # Compare basic fields
+ assert retrieved_prompt["name"] == prompt_data["name"].strip()
+ assert retrieved_prompt["author"] == prompt_data["author"]
+ assert retrieved_prompt["details"] == prompt_data["details"]
+ assert retrieved_prompt["system_prompt"] == prompt_data["system_prompt"]
+ assert retrieved_prompt["user_prompt"] == prompt_data["user_prompt"]
+ assert retrieved_prompt["version"] == 1
+ assert not retrieved_prompt["deleted"]
+
+ # Compare keywords, which are normalized by the database
+ expected_keywords = sorted([db_instance._normalize_keyword(k) for k in prompt_data["keywords"]])
+ retrieved_keywords = sorted(retrieved_prompt["keywords"])
+ assert retrieved_keywords == expected_keywords
+
+ @given(initial_prompt=st_prompt_data(), update_payload=st_prompt_data())
+ def test_update_increments_version_and_changes_data(self, db_instance: PromptsDatabase, initial_prompt: dict,
+ update_payload: dict):
+ """
+ Property: A successful update must increment the version number by exactly 1
+ and correctly apply the new data, including keywords.
+ """
+ try:
+ prompt_id, _, _ = db_instance.add_prompt(**initial_prompt)
+ except ConflictError:
+ return # Skip if initial name conflicts
+
+ original_prompt = db_instance.get_prompt_by_id(prompt_id)
+
+ try:
+ # update_prompt_by_id handles fetching current version and incrementing it.
+            updated_uuid, msg = db_instance.update_prompt_by_id(prompt_id, update_payload)
+            assert updated_uuid is not None
+ except ConflictError as e:
+ # A legitimate failure if the new name is already taken by another prompt.
+ assert "already exists" in str(e)
+ return
+
+ updated_prompt = db_instance.fetch_prompt_details(prompt_id)
+ assert updated_prompt is not None
+ assert updated_prompt['version'] == original_prompt['version'] + 1
+
+ # Verify the payload was applied
+ assert updated_prompt['name'] == update_payload['name'].strip()
+ assert updated_prompt['author'] == update_payload['author']
+ expected_keywords = sorted([db_instance._normalize_keyword(k) for k in update_payload["keywords"]])
+ assert sorted(updated_prompt['keywords']) == expected_keywords
+
+ @given(prompt_data=st_prompt_data())
+ def test_soft_delete_makes_item_unfindable(self, db_instance: PromptsDatabase, prompt_data: dict):
+ """
+ Property: After soft-deleting a prompt, it should not be retrievable by
+ default methods, but should exist in the DB with deleted=1.
+ """
+ try:
+ prompt_id, _, _ = db_instance.add_prompt(**prompt_data)
+ except ConflictError:
+ return
+
+ # Perform the soft delete
+ success = db_instance.soft_delete_prompt(prompt_id)
+ assert success is True
+
+ # Assert it's no longer findable via public methods by default
+ assert db_instance.get_prompt_by_id(prompt_id) is None
+ assert db_instance.fetch_prompt_details(prompt_id) is None
+
+ all_prompts, _, _, _ = db_instance.list_prompts()
+ assert prompt_id not in [p['id'] for p in all_prompts]
+
+ # Assert it CAN be found when explicitly requested
+ deleted_record = db_instance.get_prompt_by_id(prompt_id, include_deleted=True)
+ assert deleted_record is not None
+ assert deleted_record['deleted'] == 1
+ assert deleted_record['version'] == 2 # 1=create, 2=delete
+
+ @given(initial_prompt=st_prompt_data(), update_name=st_required_text, version_offset=st_bad_version_offset)
+ def test_update_with_stale_version_fails_via_trigger(self, db_instance: PromptsDatabase, initial_prompt: dict,
+ update_name: str, version_offset: int):
+ """
+ Property: Attempting a direct DB update with a version that does not increment
+ by exactly 1 must be rejected by the database trigger.
+ """
+ try:
+ prompt_id, _, _ = db_instance.add_prompt(**initial_prompt)
+ except ConflictError:
+ return
+
+ original_prompt = db_instance.get_prompt_by_id(prompt_id)
+
+ # Attempt a direct DB update with a bad version number.
+ # This tests the 'prompts_validate_sync_update' trigger.
+ with pytest.raises(DatabaseError) as excinfo:
+ db_instance.execute_query(
+ "UPDATE Prompts SET name = ?, version = ? WHERE id = ?",
+ (update_name, original_prompt['version'] + version_offset, prompt_id),
+ commit=True
+ )
+ assert "version must increment by exactly 1" in str(excinfo.value).lower()
+
+
+class TestKeywordAndLinkingProperties:
+ """Property-based tests for Keywords and their linking to Prompts."""
+
+ @given(keyword_text=st_required_text)
+ def test_keyword_normalization_and_roundtrip(self, db_instance: PromptsDatabase, keyword_text: str):
+ """
+ Property: Adding a keyword normalizes it (lowercase, stripped).
+ Retrieving it returns the normalized version.
+ """
+ kw_id, kw_uuid = db_instance.add_keyword(keyword_text)
+ assert kw_id is not None
+ assert kw_uuid is not None
+
+ retrieved_kw = db_instance.get_active_keyword_by_text(keyword_text)
+ assert retrieved_kw is not None
+ assert retrieved_kw['keyword'] == db_instance._normalize_keyword(keyword_text)
+
+ @given(keyword=st_required_text)
+ def test_add_keyword_is_idempotent_on_undelete(self, db_instance: PromptsDatabase, keyword: str):
+ """
+ Property: Adding a keyword that was previously soft-deleted should reactivate
+ it (not create a new one), and its version should be correctly incremented.
+ """
+ # 1. Add for the first time
+ kw_id_v1, _ = db_instance.add_keyword(keyword)
+        assert kw_id_v1 is not None
+ kw_v1 = db_instance.get_active_keyword_by_text(keyword)
+ assert kw_v1['version'] == 1
+
+ # 2. Soft delete it
+ success = db_instance.soft_delete_keyword(keyword)
+ assert success is True
+
+ # Check raw state
+ raw_kw = db_instance.execute_query("SELECT * FROM PromptKeywordsTable WHERE id=?", (kw_id_v1,)).fetchone()
+ assert raw_kw['deleted'] == 1
+ assert raw_kw['version'] == 2
+
+ # 3. Add it again (should trigger undelete)
+ kw_id_v3, _ = db_instance.add_keyword(keyword)
+
+ # Assert it's the same record
+ assert kw_id_v3 == kw_id_v1
+
+ kw_v3 = db_instance.get_active_keyword_by_text(keyword)
+ assert kw_v3 is not None
+        raw_kw_after = db_instance.execute_query("SELECT * FROM PromptKeywordsTable WHERE id=?", (kw_v3['id'],)).fetchone()
+        assert raw_kw_after['deleted'] == 0
+ # The version should be 3 (1=create, 2=delete, 3=undelete/update)
+ assert kw_v3['version'] == 3
+
+ @given(
+ prompt_data=st_prompt_data(),
+ new_keywords=st.lists(st_required_text, max_size=5, unique_by=lambda s: s.strip().lower())
+ )
+ def test_update_keywords_for_prompt_links_and_unlinks(self, db_instance: PromptsDatabase, prompt_data: dict,
+ new_keywords: list):
+ """
+ Property: Updating keywords for a prompt correctly adds new links,
+ removes old ones, and leaves unchanged ones alone.
+ """
+ try:
+ prompt_id, _, _ = db_instance.add_prompt(**prompt_data)
+ except ConflictError:
+ return
+
+ # Initial state check
+ initial_expected_kws = sorted([db_instance._normalize_keyword(k) for k in prompt_data['keywords']])
+ assert sorted(db_instance.fetch_keywords_for_prompt(prompt_id)) == initial_expected_kws
+
+ # Update the keywords
+ db_instance.update_keywords_for_prompt(prompt_id, new_keywords)
+
+ # Final state check
+ final_expected_kws = sorted([db_instance._normalize_keyword(k) for k in new_keywords])
+ assert sorted(db_instance.fetch_keywords_for_prompt(prompt_id)) == final_expected_kws
+
+
+class TestAdvancedProperties:
+ """Tests for FTS, Sync Log, and other complex interactions."""
+
+ @given(prompt_data=st_prompt_data())
+ def test_soft_deleted_item_is_not_in_fts(self, db_instance: PromptsDatabase, prompt_data: dict):
+ """
+ Property: Once a prompt is soft-deleted, it must not appear in FTS search results.
+ """
+ # Ensure the name has a unique, searchable term.
+        # Use the hex form (no hyphens): an unquoted "-" in an FTS5 MATCH query is a syntax error.
+        unique_term = uuid.uuid4().hex
+ prompt_data['name'] = f"{prompt_data['name']} {unique_term}"
+
+ try:
+ prompt_id, _, _ = db_instance.add_prompt(**prompt_data)
+ except ConflictError:
+ return
+
+ # 1. Verify it IS searchable before deletion
+ results_before, total_before = db_instance.search_prompts(unique_term)
+ assert total_before == 1
+ assert results_before[0]['id'] == prompt_id
+
+ # 2. Soft-delete the prompt
+ db_instance.soft_delete_prompt(prompt_id)
+
+ # 3. Verify it is NOT searchable after deletion
+ results_after, total_after = db_instance.search_prompts(unique_term)
+ assert total_after == 0
+
+ @given(prompt_data=st_prompt_data())
+ def test_add_creates_correct_sync_log_entries(self, db_instance: PromptsDatabase, prompt_data: dict):
+ """
+ Property: Adding a new prompt must create the correct 'create' and 'link'
+ operations in the sync_log.
+ """
+        existing_logs = db_instance.get_sync_log_entries()
+        latest_change_id_before = existing_logs[-1]['change_id'] if existing_logs else 0
+
+ try:
+ prompt_id, prompt_uuid, _ = db_instance.add_prompt(**prompt_data)
+ except ConflictError:
+ return
+
+ new_logs = db_instance.get_sync_log_entries(since_change_id=latest_change_id_before)
+
+ # Verify 'create' log for the prompt itself
+ prompt_create_logs = [log for log in new_logs if log['entity'] == 'Prompts' and log['operation'] == 'create']
+ assert len(prompt_create_logs) == 1
+ assert prompt_create_logs[0]['entity_uuid'] == prompt_uuid
+ assert prompt_create_logs[0]['version'] == 1
+
+ # Verify 'link' logs for the keywords
+ normalized_keywords = {db_instance._normalize_keyword(k) for k in prompt_data['keywords']}
+ link_logs = [log for log in new_logs if log['entity'] == 'PromptKeywordLinks' and log['operation'] == 'link']
+ assert len(link_logs) == len(normalized_keywords)
+
+ # The payload should contain the composite UUID of prompt_uuid + keyword_uuid
+ for log in link_logs:
+ assert log['payload']['prompt_uuid'] == prompt_uuid
+ assert log['entity_uuid'].startswith(prompt_uuid)
+
+ @given(prompt_data=st_prompt_data())
+ def test_delete_creates_correct_sync_log_entries(self, db_instance: PromptsDatabase, prompt_data: dict):
+ """
+ Property: Soft-deleting a prompt must create a 'delete' log for the prompt
+ and 'unlink' logs for all its keyword connections.
+ """
+ try:
+ prompt_id, prompt_uuid, _ = db_instance.add_prompt(**prompt_data)
+ except ConflictError:
+ return
+
+ num_keywords = len(prompt_data['keywords'])
+        latest_change_id_before = db_instance.get_sync_log_entries()[-1]['change_id']
+
+ # Action: Soft delete
+ db_instance.soft_delete_prompt(prompt_id)
+
+ new_logs = db_instance.get_sync_log_entries(since_change_id=latest_change_id_before)
+
+ # Verify 'delete' log for the prompt
+ prompt_delete_logs = [log for log in new_logs if log['entity'] == 'Prompts' and log['operation'] == 'delete']
+ assert len(prompt_delete_logs) == 1
+ assert prompt_delete_logs[0]['entity_uuid'] == prompt_uuid
+ assert prompt_delete_logs[0]['version'] == 2
+
+ # Verify 'unlink' logs for the keywords
+ unlink_logs = [log for log in new_logs if
+ log['entity'] == 'PromptKeywordLinks' and log['operation'] == 'unlink']
+ assert len(unlink_logs) == num_keywords
+ for log in unlink_logs:
+ assert log['payload']['prompt_uuid'] == prompt_uuid
+
+
+class TestDataIntegrityAndConcurrency:
+ """Tests for database constraints and thread safety."""
+
+ def test_add_prompt_with_conflicting_name_fails(self, db_instance: PromptsDatabase):
+ """
+ Property: Adding a prompt with a name that already exists (and overwrite=False)
+ must raise a ConflictError.
+ """
+ prompt_data = {"name": "Unique Prompt Name", "author": "Tester"}
+ db_instance.add_prompt(**prompt_data)
+
+ # Attempt to add again with the same name
+ with pytest.raises(ConflictError):
+ db_instance.add_prompt(**prompt_data, overwrite=False)
+
+ def test_update_prompt_to_conflicting_name_fails(self, db_instance: PromptsDatabase):
+ """
+ Property: Updating a prompt's name to a name that is already used by
+ another active prompt must raise a ConflictError.
+ """
+ p1_id, _, _ = db_instance.add_prompt(name="Prompt One", author="A")
+ db_instance.add_prompt(name="Prompt Two", author="B") # The conflicting name
+
+ update_payload = {"name": "Prompt Two"}
+ with pytest.raises(ConflictError):
+ db_instance.update_prompt_by_id(p1_id, update_payload)
+
+ def test_each_thread_gets_a_separate_connection(self, db_instance: PromptsDatabase):
+ """
+ Property: The `get_connection` method must provide a unique
+ connection object for each thread, via threading.local.
+ """
+ connection_ids = set()
+ lock = threading.Lock()
+
+ def get_and_store_conn_id():
+ conn = db_instance.get_connection()
+ with lock:
+ connection_ids.add(id(conn))
+
+ threads = [threading.Thread(target=get_and_store_conn_id) for _ in range(5)]
+ for t in threads:
+ t.start()
+ for t in threads:
+ t.join()
+
+ # If threading.local is working, there should be 5 unique connection IDs.
+ assert len(connection_ids) == 5
+
+ def test_wal_mode_allows_concurrent_reads_during_write_transaction(self, db_instance: PromptsDatabase):
+ """
+ Property: In WAL mode, one thread can read from the DB while another
+ thread has an open write transaction.
+ """
+ prompt_id, _, _ = db_instance.add_prompt(name="Concurrent Read Test", details="Original")
+
+ write_transaction_started = threading.Event()
+ read_result = []
+
+ def writer_thread():
+ # The update method opens its own transaction
+ with db_instance.transaction():
+ db_instance.execute_query("UPDATE Prompts SET details = 'Updated' WHERE id = ?", (prompt_id,))
+ write_transaction_started.set() # Signal that the transaction is open
+ time.sleep(0.2) # Hold the transaction open
+ # Transaction commits here
+
+ def reader_thread():
+ write_transaction_started.wait() # Wait until the writer is in its transaction
+ # This read should succeed immediately and read the state BEFORE the commit.
+ prompt = db_instance.get_prompt_by_id(prompt_id)
+ read_result.append(prompt)
+
+ w = threading.Thread(target=writer_thread)
+ r = threading.Thread(target=reader_thread)
+
+ w.start()
+ r.start()
+ w.join()
+ r.join()
+
+ # The reader thread should have completed successfully and read the *original* state.
+ assert len(read_result) == 1
+ assert read_result[0] is not None
+ assert read_result[0]['details'] == "Original" # It read the state before the writer committed.
+
+
+# --- State Machine Tests ---
+
+class PromptLifecycleMachine(RuleBasedStateMachine):
+ """
+ Models the lifecycle of a single Prompt: create, update, delete.
+ Hypothesis will try to find sequences of these actions that break invariants.
+ """
+
+ def __init__(self):
+ super().__init__()
+ self.db = None # Injected by the test class fixture
+ # State for a single prompt's lifecycle
+ self.prompt_id = None
+ self.prompt_name = None
+ self.is_deleted = True
+
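+    # Bundle of prompt ids produced by create_prompt() and consumed by the update/delete rules.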
+ prompts = Bundle('prompts')
+
+ @rule(target=prompts, data=st_prompt_data())
+ def create_prompt(self, data):
+ # We only want to test the lifecycle of one prompt per machine run for simplicity.
+ if self.prompt_id is not None:
+ return
+
+ try:
+ new_id, _, _ = self.db.add_prompt(**data, overwrite=False)
+ except ConflictError:
+ # Hypothesis might generate a duplicate name. We treat this as "no action taken".
+ return
+
+ self.prompt_id = new_id
+ self.prompt_name = data['name'].strip()
+ self.is_deleted = False
+
+ retrieved = self.db.get_prompt_by_id(self.prompt_id)
+ assert retrieved is not None
+ assert retrieved['name'] == self.prompt_name
+ assert retrieved['version'] == 1
+ return self.prompt_id
+
+ @rule(prompt_id=prompts, update_data=st_prompt_data())
+ def update_prompt(self, prompt_id, update_data):
+        # The bundle can contain None when create_prompt() bailed out early (e.g. name conflict).
+        if prompt_id is None or self.prompt_id is None or self.is_deleted:
+ return
+
+ original_version = self.db.get_prompt_by_id(prompt_id, include_deleted=True)['version']
+
+ try:
+ self.db.update_prompt_by_id(prompt_id, update_data)
+ # If successful, update our internal state
+ self.prompt_name = update_data['name'].strip()
+ except ConflictError as e:
+ # This is a valid outcome if the new name is taken.
+ assert "already exists" in str(e)
+ # The state of our prompt hasn't changed.
+ return
+
+ retrieved = self.db.get_prompt_by_id(self.prompt_id)
+ assert retrieved is not None
+ assert retrieved['version'] == original_version + 1
+ assert retrieved['name'] == self.prompt_name
+
+ @rule(prompt_id=prompts)
+ def soft_delete_prompt(self, prompt_id):
+        if prompt_id is None or self.prompt_id is None or self.is_deleted:
+ return
+
+ original_version = self.db.get_prompt_by_id(prompt_id, include_deleted=True)['version']
+
+ success = self.db.soft_delete_prompt(prompt_id)
+ assert success
+ self.is_deleted = True
+
+ # Verify it's gone from standard lookups
+ assert self.db.get_prompt_by_id(self.prompt_id) is None
+ assert self.db.get_prompt_by_name(self.prompt_name) is None
+
+ # Verify its deleted state
+ raw_record = self.db.get_prompt_by_id(self.prompt_id, include_deleted=True)
+ assert raw_record['deleted'] == 1
+ assert raw_record['version'] == original_version + 1
+
+
+# This is the actual test that pytest discovers and runs. Pytest skips test classes that
+# define __init__ (as RuleBasedStateMachine subclasses do), so the machine is driven through
+# run_state_machine_as_test (from hypothesis.stateful), with the clean `db_instance`
+# fixture injected by constructing the machine inside the test function.
+class TestPromptLifecycleAsTest:
+
+    def test_prompt_lifecycle(self, db_instance: PromptsDatabase):
+        class _MachineWithDB(PromptLifecycleMachine):
+            def __init__(self):
+                super().__init__()
+                self.db = db_instance
+
+        run_state_machine_as_test(
+            _MachineWithDB,
+            settings=settings(max_examples=50, stateful_step_count=20),
+        )
\ No newline at end of file
diff --git a/pyproject.toml b/pyproject.toml
index 6aa9b042..6f6152aa 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -54,6 +54,8 @@ transformers = ["transformers"]
dev = [ # Example for development dependencies
"pytest",
"textual-dev", # For Textual development tools
+ "hypothesis",
+ "pytest_asyncio",
]
diff --git a/tldw_chatbook/DB/ChaChaNotes_DB.py b/tldw_chatbook/DB/ChaChaNotes_DB.py
index 6ba48738..f691864e 100644
--- a/tldw_chatbook/DB/ChaChaNotes_DB.py
+++ b/tldw_chatbook/DB/ChaChaNotes_DB.py
@@ -1901,6 +1901,7 @@ def search_character_cards(self, search_term: str, limit: int = 10) -> List[Dict
Raises:
CharactersRAGDBError: For database errors during the search.
"""
+ safe_search_term = f'"{search_term}"'
query = """
SELECT cc.*
FROM character_cards_fts fts
@@ -1914,7 +1915,7 @@ def search_character_cards(self, search_term: str, limit: int = 10) -> List[Dict
rows = cursor.fetchall()
return [self._deserialize_row_fields(row, self._CHARACTER_CARD_JSON_FIELDS) for row in rows if row]
except CharactersRAGDBError as e:
- logger.error(f"Error searching character cards for '{search_term}': {e}")
+ logger.error(f"Error searching character cards for '{safe_search_term}': {e}")
raise
# --- Conversation Methods ---
@@ -2296,6 +2297,7 @@ def search_conversations_by_title(self, title_query: str, character_id: Optional
if not title_query.strip():
logger.warning("Empty title_query provided for conversation search. Returning empty list.")
return []
+ safe_search_term = f'"{title_query}"'
base_query = """
SELECT c.*
FROM conversations_fts fts
@@ -2315,7 +2317,7 @@ def search_conversations_by_title(self, title_query: str, character_id: Optional
cursor = self.execute_query(base_query, tuple(params_list))
return [dict(row) for row in cursor.fetchall()]
except CharactersRAGDBError as e:
- logger.error(f"Error searching conversations for title '{title_query}': {e}")
+ logger.error(f"Error searching conversations for title '{safe_search_term}': {e}")
raise
# --- Message Methods ---
@@ -2429,27 +2431,25 @@ def get_messages_for_conversation(self, conversation_id: str, limit: int = 100,
order_by_timestamp: str = "ASC") -> List[Dict[str, Any]]:
"""
Lists messages for a specific conversation.
-
Returns non-deleted messages, ordered by `timestamp` according to `order_by_timestamp`.
- Includes all fields, including `image_data` and `image_mime_type`.
-
- Args:
- conversation_id: The UUID of the conversation.
- limit: Maximum number of messages to return. Defaults to 100.
- offset: Number of messages to skip. Defaults to 0.
- order_by_timestamp: Sort order for 'timestamp' field ("ASC" or "DESC").
- Defaults to "ASC".
-
- Returns:
- A list of message dictionaries. Can be empty.
-
- Raises:
- InputError: If `order_by_timestamp` has an invalid value.
- CharactersRAGDBError: For database errors.
+ Crucially, it also ensures the parent conversation is not soft-deleted.
"""
if order_by_timestamp.upper() not in ["ASC", "DESC"]:
raise InputError("order_by_timestamp must be 'ASC' or 'DESC'.")
- query = f"SELECT id, conversation_id, parent_message_id, sender, content, image_data, image_mime_type, timestamp, ranking, last_modified, version, client_id, deleted FROM messages WHERE conversation_id = ? AND deleted = 0 ORDER BY timestamp {order_by_timestamp} LIMIT ? OFFSET ?" # Explicitly list columns
+
+ # The new query joins with conversations to check its 'deleted' status.
+ query = f"""
+ SELECT m.id, m.conversation_id, m.parent_message_id, m.sender, m.content,
+ m.image_data, m.image_mime_type, m.timestamp, m.ranking,
+ m.last_modified, m.version, m.client_id, m.deleted
+ FROM messages m
+ JOIN conversations c ON m.conversation_id = c.id
+ WHERE m.conversation_id = ?
+ AND m.deleted = 0
+ AND c.deleted = 0
+ ORDER BY m.timestamp {order_by_timestamp}
+ LIMIT ? OFFSET ?
+ """
try:
cursor = self.execute_query(query, (conversation_id, limit, offset))
return [dict(row) for row in cursor.fetchall()]
@@ -2667,6 +2667,7 @@ def search_messages_by_content(self, content_query: str, conversation_id: Option
Raises:
CharactersRAGDBError: For database search errors.
"""
+ safe_search_term = f'"{content_query}"'
base_query = """
SELECT m.*
FROM messages_fts fts
@@ -2686,7 +2687,7 @@ def search_messages_by_content(self, content_query: str, conversation_id: Option
cursor = self.execute_query(base_query, tuple(params_list))
return [dict(row) for row in cursor.fetchall()]
except CharactersRAGDBError as e:
- logger.error(f"Error searching messages for content '{content_query}': {e}")
+ logger.error(f"Error searching messages for content '{safe_search_term}': {e}")
raise
# --- Keyword, KeywordCollection, Note Methods (CRUD + Search) ---
@@ -3252,7 +3253,8 @@ def search_keywords(self, search_term: str, limit: int = 10) -> List[Dict[str, A
Returns:
A list of matching keyword dictionaries.
"""
- return self._search_generic_items_fts("keywords_fts", "keywords", "keyword", search_term, limit)
+ safe_search_term = f'"{search_term}"'
+ return self._search_generic_items_fts("keywords_fts", "keywords", "keyword", safe_search_term, limit)
# Keyword Collections
def add_keyword_collection(self, name: str, parent_id: Optional[int] = None) -> Optional[int]:
@@ -3372,7 +3374,8 @@ def soft_delete_keyword_collection(self, collection_id: int, expected_version: i
)
def search_keyword_collections(self, search_term: str, limit: int = 10) -> List[Dict[str, Any]]:
- return self._search_generic_items_fts("keyword_collections_fts", "keyword_collections", "name", search_term,
+ safe_search_term = f'"{search_term}"'
+ return self._search_generic_items_fts("keyword_collections_fts", "keyword_collections", "name", safe_search_term,
limit)
# Notes (Now with UUID and specific methods)
@@ -3536,19 +3539,21 @@ def soft_delete_note(self, note_id: str, expected_version: int) -> bool | None:
def search_notes(self, search_term: str, limit: int = 10) -> List[Dict[str, Any]]:
"""Searches notes_fts (title and content). Corrected JOIN condition."""
- # notes_fts matches against title and content
- # FTS table column group: notes_fts
- # Content table: notes, content_rowid: rowid (maps to notes.rowid)
+        # FTS5 treats a double-quoted string as a literal phrase, which neutralises query
+        # operators and special characters; embedded double quotes are escaped by doubling.
+        safe_search_term = '"' + search_term.replace('"', '""') + '"'
+
query = """
SELECT main.*
FROM notes_fts fts
- JOIN notes main ON fts.rowid = main.rowid -- Corrected Join condition
- WHERE fts.notes_fts MATCH ? \
+ JOIN notes main ON fts.rowid = main.rowid
+ WHERE fts.notes_fts MATCH ?
AND main.deleted = 0
- ORDER BY rank LIMIT ? \
+ ORDER BY rank LIMIT ?
"""
try:
- cursor = self.execute_query(query, (search_term, limit))
+ # Pass the quoted string as the parameter
+ cursor = self.execute_query(query, (safe_search_term, limit))
return [dict(row) for row in cursor.fetchall()]
except CharactersRAGDBError as e:
logger.error(f"Error searching notes for '{search_term}': {e}")
@@ -3824,6 +3829,7 @@ def __exit__(self, exc_type, exc_val, exc_tb):
logger.debug(
f"Exception in nested transaction block on thread {threading.get_ident()}: {exc_type.__name__}. Outermost transaction will handle rollback if this exception propagates.")
+
# Return False to re-raise any exceptions that occurred within the `with` block,
# allowing them to be handled by the caller or to propagate further up.
# This is standard behavior for context managers.
diff --git a/tldw_chatbook/DB/Client_Media_DB_v2.py b/tldw_chatbook/DB/Client_Media_DB_v2.py
index e6fccf54..a93785ba 100644
--- a/tldw_chatbook/DB/Client_Media_DB_v2.py
+++ b/tldw_chatbook/DB/Client_Media_DB_v2.py
@@ -1703,320 +1703,257 @@ def soft_delete_media(self, media_id: int, cascade: bool = True) -> bool:
logger.error(f"Unexpected error soft deleting media ID {media_id}: {e}", exc_info=True)
raise DatabaseError(f"Unexpected error during soft delete: {e}") from e
- def add_media_with_keywords(self,
- *,
- url: Optional[str] = None,
- title: Optional[str],
- media_type: Optional[str],
- content: Optional[str],
- keywords: Optional[List[str]] = None,
- prompt: Optional[str] = None,
- analysis_content: Optional[str] = None,
- transcription_model: Optional[str] = None,
- author: Optional[str] = None,
- ingestion_date: Optional[str] = None,
- overwrite: bool = False,
- chunk_options: Optional[Dict] = None,
- chunks: Optional[List[Dict[str, Any]]] = None
- ) -> Tuple[Optional[int], Optional[str], str]:
- """
- Adds a new media item or updates an existing one based on URL or content hash.
-
- Handles creation or update of the Media record, generates a content hash,
- associates keywords (adding them if necessary), creates an initial
- DocumentVersion, logs appropriate sync events ('create' or 'update' for
- Media, plus events from keyword and document version handling), and
- updates the `media_fts` table.
-
- If `chunks` are provided, they are saved as UnvectorizedMediaChunks.
- If updating an existing media item and new chunks are provided, old chunks
- for that media item are hard-deleted before inserting the new ones.
-
- If an existing item is found (by URL or content hash) and `overwrite` is False,
- the operation is skipped. If `overwrite` is True, the existing item is updated.
-
- Args:
- url (Optional[str]): The URL of the media (unique). Generated if not provided.
- title (Optional[str]): Title of the media. Defaults to 'Untitled'.
- media_type (Optional[str]): Type of media (e.g., 'article', 'video'). Defaults to 'unknown'.
- content (Optional[str]): The main text content. Required.
- keywords (Optional[List[str]]): List of keyword strings to associate.
- prompt (Optional[str]): Optional prompt associated with this version.
- analysis_content (Optional[str]): Optional analysis/summary content.
- transcription_model (Optional[str]): Model used for transcription, if applicable.
- author (Optional[str]): Author of the media.
- ingestion_date (Optional[str]): ISO 8601 formatted UTC timestamp for ingestion.
- Defaults to current time if None.
- overwrite (bool): If True, update the media item if it already exists.
- Defaults to False (skip if exists).
- chunk_options (Optional[Dict]): Placeholder for chunking parameters.
- chunks (Optional[List[Dict[str, Any]]]): A list of dictionaries, where each dictionary
- represents a chunk of the media content. Expected keys
- per dictionary: 'text' (str, required), and optional
- 'start_char' (int), 'end_char' (int),
- 'chunk_type' (str), 'metadata' (dict).
-
- Returns:
- Tuple[Optional[int], Optional[str], str]: A tuple containing:
- - media_id (Optional[int]): The ID of the added/updated media item.
- - media_uuid (Optional[str]): The UUID of the added/updated media item.
- - message (str): A status message indicating the action taken
- ("added", "updated", "already_exists_skipped").
-
- Raises:
- InputError: If `content` is None or required chunk data is malformed.
- ConflictError: If an update fails due to a version mismatch.
- DatabaseError: For underlying database issues or errors during sync/FTS logging.
- """
+ def add_media_with_keywords(
+ self,
+ *,
+ url: Optional[str] = None,
+ title: Optional[str] = None,
+ media_type: Optional[str] = None,
+ content: Optional[str] = None,
+ keywords: Optional[List[str]] = None,
+ prompt: Optional[str] = None,
+ analysis_content: Optional[str] = None,
+ transcription_model: Optional[str] = None,
+ author: Optional[str] = None,
+ ingestion_date: Optional[str] = None,
+ overwrite: bool = False,
+ chunk_options: Optional[Dict] = None,
+ chunks: Optional[List[Dict[str, Any]]] = None,
+ ) -> Tuple[Optional[int], Optional[str], str]:
+ """Add or update a media record, handle keyword links, optional chunks and full-text sync."""
+
+ # ---------------------------------------------------------------------
+ # 1. Fast‑fail validation & normalisation
+ # ---------------------------------------------------------------------
if content is None:
raise InputError("Content cannot be None.")
- title = title or 'Untitled'
- media_type = media_type or 'unknown'
- keywords_list = [k.strip().lower() for k in keywords if k and k.strip()] if keywords else []
- # Get current time and client ID
- current_time = self._get_current_utc_timestamp_str()
- client_id = self.client_id
+ title = title or "Untitled"
+ media_type = media_type or "unknown"
+ keywords_norm = [k.strip().lower() for k in keywords or [] if k and k.strip()]
- # Handle ingestion_date: Use provided, else generate now. Use full timestamp.
- ingestion_date_str = ingestion_date or current_time
+ now = self._get_current_utc_timestamp_str()
+ ingestion_date = ingestion_date or now
+ client_id = self.client_id
content_hash = hashlib.sha256(content.encode()).hexdigest()
- if not url:
- url = f"local://{media_type}/{content_hash}"
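+        # No source URL supplied: synthesise a stable placeholder keyed on the content hash so
+        # the row still gets a unique URL (it can be canonicalised later if a real URL arrives).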
+ url = url or f"local://{media_type}/{content_hash}"
+
+ # Determine the final chunk status before any DB operation
+ final_chunk_status = "completed" if chunks is not None else "pending"
+
+ logging.info("add_media_with_keywords: url=%s, title=%s, client=%s", url, title, client_id)
+
+ # ------------------------------------------------------------------
+ # Helper builders
+ # ------------------------------------------------------------------
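+        # _media_payload builds the column dict shared by the INSERT/UPDATE statements and the
+        # sync log; _persist_chunks manages the UnvectorizedMediaChunks rows for this media item.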
+ def _media_payload(uuid_: str, version_: int, *, chunk_status: str) -> Dict[str, Any]:
+ """Return a dict suitable for INSERT/UPDATE parameters and for sync logging."""
+ return {
+ "url": url,
+ "title": title,
+ "type": media_type,
+ "content": content,
+ "author": author,
+ "ingestion_date": ingestion_date,
+ "transcription_model": transcription_model,
+ "content_hash": content_hash,
+ "is_trash": 0,
+ "trash_date": None,
+ "chunking_status": chunk_status,
+ "vector_processing": 0,
+ "uuid": uuid_,
+ "last_modified": now,
+ "version": version_,
+ "client_id": client_id,
+ "deleted": 0,
+ }
- logging.info(f"Processing add/update for: URL='{url}', Title='{title}', Client='{client_id}'")
+ def _persist_chunks(cnx: sqlite3.Connection, media_id: int) -> None:
+ """Delete/insert un-vectorized chunks as requested. DOES NOT update parent Media."""
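+            # chunks=None  -> leave any existing chunks untouched
+            # chunks=[]    -> wipe existing chunks (when overwrite) and insert nothing
+            # chunks=[...] -> insert the new list (existing chunks are wiped first when overwrite)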
+ if chunks is None:
+ return # caller did not touch chunks
- media_id: Optional[int] = None
- media_uuid: Optional[str] = None
- action = "skipped" # Default action
+ if overwrite:
+ cnx.execute("DELETE FROM UnvectorizedMediaChunks WHERE media_id = ?", (media_id,))
+
+ if not chunks: # empty list means just wipe
+ return
- # Determine initial chunking_status based on presence of 'chunks' argument
- # This will be used when preparing insert_data or update_data for the Media table
- initial_chunking_status = "pending"
- if chunks is not None: # chunks is an empty list or a list with items
- initial_chunking_status = "processing"
+ created = self._get_current_utc_timestamp_str()
+ for idx, ch in enumerate(chunks):
+ if not isinstance(ch, dict) or ch.get("text") is None:
+ logging.warning("Skipping invalid chunk index %s for media_id %s", idx, media_id)
+ continue
+
+ chunk_uuid = self._generate_uuid()
+ cnx.execute(
+ """INSERT INTO UnvectorizedMediaChunks (media_id, chunk_text, chunk_index, start_char, end_char,
+ chunk_type, creation_date, last_modified_orig, is_processed,
+ metadata, uuid, last_modified, version, client_id, deleted,
+ prev_version, merge_parent_uuid)
+ VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""",
+ (
+ media_id, ch["text"], idx, ch.get("start_char"), ch.get("end_char"), ch.get("chunk_type"),
+ created, created, False,
+ json.dumps(ch.get("metadata")) if isinstance(ch.get("metadata"), dict) else None,
+ chunk_uuid, created, 1, client_id, 0, None, None,
+ ),
+ )
+ self._log_sync_event(
+ cnx, "UnvectorizedMediaChunks", chunk_uuid, "create", 1,
+ {
+ **ch, "media_id": media_id, "uuid": chunk_uuid, "chunk_index": idx,
+ "creation_date": created, "last_modified": created, "version": 1,
+ "client_id": client_id, "deleted": 0,
+ },
+ )
+ # ------------------------------------------------------------------
+ # 2. Main transactional block
+ # ------------------------------------------------------------------
try:
with self.transaction() as conn:
- cursor = conn.cursor()
- cursor.execute('SELECT id, uuid, version FROM Media WHERE (url = ? OR content_hash = ?) AND deleted = 0 LIMIT 1', (url, content_hash))
- existing_media = cursor.fetchone()
- media_id, media_uuid, action = None, None, "skipped"
+ cur = conn.cursor()
+
+ # Find existing record by URL or content_hash
+ cur.execute(
+ "SELECT id, uuid, version, url, content_hash FROM Media WHERE url = ? AND deleted = 0 LIMIT 1",
+ (url,),
+ )
+ row = cur.fetchone()
+
+ if not row:
+ cur.execute(
+ "SELECT id, uuid, version, url, content_hash FROM Media WHERE content_hash = ? AND deleted = 0 LIMIT 1",
+ (content_hash,),
+ )
+ row = cur.fetchone()
+
+ # --- Path A: Record exists, handle UPDATE, CANONICALIZATION, or SKIP ---
+ if row:
+ media_id = row["id"]
+ media_uuid = row["uuid"]
+ current_ver = row["version"]
+ existing_url = row["url"]
+ existing_hash = row["content_hash"]
- if existing_media:
- media_id, media_uuid, current_version = existing_media['id'], existing_media['uuid'], existing_media['version']
+ # Case A.1: Overwrite is requested.
if overwrite:
- action = "updated"
- new_version = current_version + 1
- logger.info(f"Updating existing media ID {media_id} (UUID: {media_uuid}) to version {new_version}.")
- update_data = { # Prepare dict for easier payload generation
- 'url': url,
- 'title': title,
- 'type': media_type,
- 'content': content,
- 'author': author,
- 'ingestion_date': ingestion_date_str,
- 'transcription_model': transcription_model,
- 'content_hash': content_hash,
- 'is_trash': 0,
- 'trash_date': None, # Ensure trash_date is None here
- 'chunking_status': "pending",
- 'vector_processing': 0,
- 'last_modified': current_time, # Set last_modified
- 'version': new_version,
- 'client_id': client_id,
- 'deleted': 0,
- 'uuid': media_uuid
- }
- cursor.execute(
- """UPDATE Media SET url=?, title=?, type=?, content=?, author=?, ingestion_date=?,
- transcription_model=?, content_hash=?, is_trash=?, trash_date=?, chunking_status=?,
- vector_processing=?, last_modified=?, version=?, client_id=?, deleted=?
- WHERE id=? AND version=?""",
- (update_data['url'], update_data['title'], update_data['type'], update_data['content'],
- update_data['author'], update_data['ingestion_date'], update_data['transcription_model'],
- update_data['content_hash'], update_data['is_trash'], update_data['trash_date'], # Pass None for trash_date
- update_data['chunking_status'], update_data['vector_processing'],
- update_data['last_modified'], # Pass current_time
- update_data['version'], update_data['client_id'], update_data['deleted'],
- media_id, current_version)
+ # Case A.1.a: Content is identical. No version bump needed for main content.
+ if content_hash == existing_hash:
+ logging.info(f"Media content for ID {media_id} is identical. Updating metadata/chunks only.")
+
+ # Update keywords and chunks without changing the main Media record yet.
+ self.update_keywords_for_media(media_id, keywords_norm)
+ _persist_chunks(conn, media_id)
+
+ # If new chunks were provided, the media's chunking status has changed,
+ # which justifies a version bump on the parent Media record.
+ if chunks is not None:
+ logging.info(f"Chunks provided for identical media; updating media chunk_status and version for ID {media_id}.")
+ new_ver = current_ver + 1
+ cur.execute(
+ """UPDATE Media SET chunking_status = 'completed', version = ?, last_modified = ?
+ WHERE id = ? AND version = ?""",
+ (new_ver, now, media_id, current_ver)
+ )
+ if cur.rowcount == 0:
+ raise ConflictError(f"Media (updating chunk status for identical content id={media_id})", media_id)
+
+ self._log_sync_event(conn, "Media", media_uuid, "update", new_ver, {"chunking_status": "completed", "last_modified": now})
+
+ return media_id, media_uuid, f"Media '{title}' is already up-to-date."
+
+ # Case A.1.b: Content is different. Proceed with a full versioned update.
+ new_ver = current_ver + 1
+ payload = _media_payload(media_uuid, new_ver, chunk_status=final_chunk_status)
+ cur.execute(
+ """UPDATE Media
+ SET url=:url, title=:title, type=:type, content=:content, author=:author,
+ ingestion_date=:ingestion_date, transcription_model=:transcription_model,
+ content_hash=:content_hash, is_trash=:is_trash, trash_date=:trash_date,
+ chunking_status=:chunking_status, vector_processing=:vector_processing,
+ last_modified=:last_modified, version=:version, client_id=:client_id, deleted=:deleted
+ WHERE id = :id AND version = :ver""",
+ {**payload, "id": media_id, "ver": current_ver},
)
- if cursor.rowcount == 0:
- raise ConflictError("Media", media_id)
+ if cur.rowcount == 0:
+ raise ConflictError(f"Media (full update id={media_id})", media_id)
- # Use the update_data dict directly for the payload
- self._log_sync_event(conn, 'Media', media_uuid, 'update', new_version, update_data)
- self._update_fts_media(conn, media_id, update_data['title'], update_data['content'])
-
- # Consolidate keyword and version creation here for "updated"
- self.update_keywords_for_media(media_id, keywords_list) # Manages its own logs
- # Create a new document version representing this update
+ self._log_sync_event(conn, "Media", media_uuid, "update", new_ver, payload)
+ self._update_fts_media(conn, media_id, payload["title"], payload["content"])
+ self.update_keywords_for_media(media_id, keywords_norm)
self.create_document_version(
- media_id=media_id,
- content=content, # Use the new content for the version
- prompt=prompt,
- analysis_content=analysis_content
+ media_id=media_id, content=content, prompt=prompt, analysis_content=analysis_content
)
+ _persist_chunks(conn, media_id)
+ return media_id, media_uuid, f"Media '{title}' updated to new version."
+
+ # Case A.2: Overwrite is FALSE.
else:
- action = "already_exists_skipped"
- else: # Not existing_media
- action = "added"
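+                        # Canonicalisation: the stored row only has a placeholder local:// URL; if
+                        # the caller now supplies a real URL for byte-identical content, adopt it
+                        # even though overwrite is False.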
+ is_canonicalisation = (
+ existing_url.startswith("local://")
+ and not url.startswith("local://")
+ and content_hash == existing_hash
+ )
+ if is_canonicalisation:
+ logging.info(f"Canonicalizing URL for media_id {media_id} to {url}")
+ new_ver = current_ver + 1
+ cur.execute(
+ "UPDATE Media SET url = ?, last_modified = ?, version = ?, client_id = ? WHERE id = ? AND version = ?",
+ (url, now, new_ver, client_id, media_id, current_ver),
+ )
+ if cur.rowcount == 0:
+ raise ConflictError(f"Media (canonicalization id={media_id})", media_id)
+
+ self._log_sync_event(
+ conn, "Media", media_uuid, "update", new_ver, {"url": url, "last_modified": now}
+ )
+ return media_id, media_uuid, f"Media '{title}' URL canonicalized."
+
+ return None, None, f"Media '{title}' already exists. Overwrite not enabled."
+
+ # --- Path B: Record does not exist, perform INSERT ---
+ else:
media_uuid = self._generate_uuid()
- new_version = 1
- logger.info(f"Inserting new media '{title}' with UUID {media_uuid}.")
- insert_data = { # Prepare dict for easier payload generation
- 'url': url,
- 'title': title,
- 'type': media_type,
- 'content': content,
- 'author': author,
- 'ingestion_date': ingestion_date_str, # Use generated/passed ingestion_date
- 'transcription_model': transcription_model,
- 'content_hash': content_hash,
- 'is_trash': 0,
- 'trash_date': None, # trash_date is NULL on creation
- 'chunking_status': initial_chunking_status,
- 'vector_processing': 0,
- 'uuid': media_uuid,
- 'last_modified': current_time, # Set last_modified
- 'version': new_version,
- 'client_id': client_id,
- 'deleted': 0
- }
- cursor.execute(
- """INSERT INTO Media (url, title, type, content, author, ingestion_date, transcription_model,
- content_hash, is_trash, trash_date, chunking_status, vector_processing, uuid,
- last_modified, version, client_id, deleted)
- VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""",
- (insert_data['url'], insert_data['title'], insert_data['type'], insert_data['content'],
- insert_data['author'], insert_data['ingestion_date'], # Pass ingestion_date_str
- insert_data['transcription_model'], insert_data['content_hash'], insert_data['is_trash'],
- insert_data['trash_date'], # Pass None for trash_date
- insert_data['chunking_status'], insert_data['vector_processing'], insert_data['uuid'],
- insert_data['last_modified'], # Pass current_time
- insert_data['version'], insert_data['client_id'], insert_data['deleted'])
+ payload = _media_payload(media_uuid, 1, chunk_status=final_chunk_status)
+
+ cur.execute(
+ """INSERT INTO Media (url, title, type, content, author, ingestion_date,
+ transcription_model, content_hash, is_trash, trash_date,
+ chunking_status, vector_processing, uuid, last_modified,
+ version, client_id, deleted)
+ VALUES (:url, :title, :type, :content, :author, :ingestion_date,
+ :transcription_model, :content_hash, :is_trash, :trash_date,
+ :chunking_status, :vector_processing, :uuid, :last_modified,
+ :version, :client_id, :deleted)""",
+ payload,
)
- media_id = cursor.lastrowid
+ media_id = cur.lastrowid
if not media_id:
- raise DatabaseError("Failed to get last row ID for new media.")
-
- # Use the insert_data dict directly for the payload
- self._log_sync_event(conn, 'Media', media_uuid, 'create', new_version, insert_data)
- self._update_fts_media(conn, media_id, insert_data['title'], insert_data['content'])
+ raise DatabaseError("Failed to obtain new media ID.")
- # Consolidate keyword and version creation here for "added"
- self.update_keywords_for_media(media_id, keywords_list)
+ self._log_sync_event(conn, "Media", media_uuid, "create", 1, payload)
+ self._update_fts_media(conn, media_id, payload["title"], payload["content"])
+ self.update_keywords_for_media(media_id, keywords_norm)
self.create_document_version(
- media_id=media_id,
- content=content,
- prompt=prompt,
- analysis_content=analysis_content
+ media_id=media_id, content=content, prompt=prompt, analysis_content=analysis_content
)
-
- # --- Handle Unvectorized Chunks ---
- if chunks is not None: # chunks argument was provided (could be empty or list of dicts)
- if action == "updated":
- # If overwriting and new chunks are provided, clear old ones.
- # If `chunks` is an empty list, it also means clear old ones.
- if overwrite: # Only delete if overwrite is true
- logging.debug(
- f"Hard deleting existing UnvectorizedMediaChunks for updated media_id {media_id} due to overwrite and new chunks being provided.")
- conn.execute("DELETE FROM UnvectorizedMediaChunks WHERE media_id = ?", (media_id,))
-
- num_chunks_saved = 0
- if chunks: # If chunks list is not empty
- chunk_creation_time = self._get_current_utc_timestamp_str()
- for i, chunk_data in enumerate(chunks):
- if not isinstance(chunk_data, dict) or 'text' not in chunk_data or chunk_data['text'] is None:
- logging.warning(
- f"Skipping invalid chunk data at index {i} for media_id {media_id} "
- f"(missing 'text' or not a dict): {str(chunk_data)[:100]}"
- )
- continue
-
- chunk_item_uuid = self._generate_uuid()
- metadata_payload = chunk_data.get('metadata')
-
- chunk_record_tuple = (
- media_id,
- chunk_data['text'],
- i, # chunk_index
- chunk_data.get('start_char'),
- chunk_data.get('end_char'),
- chunk_data.get('chunk_type'),
- chunk_creation_time, # creation_date
- chunk_creation_time, # last_modified_orig
- False, # is_processed
- json.dumps(metadata_payload) if isinstance(metadata_payload, dict) else None, # metadata
- chunk_item_uuid,
- chunk_creation_time, # last_modified
- 1, # version
- self.client_id,
- 0, # deleted
- None, # prev_version
- None # merge_parent_uuid
- )
- try:
- conn.execute("""
- INSERT INTO UnvectorizedMediaChunks (
- media_id, chunk_text, chunk_index, start_char, end_char, chunk_type,
- creation_date, last_modified_orig, is_processed, metadata, uuid,
- last_modified, version, client_id, deleted, prev_version, merge_parent_uuid
- ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
- """, chunk_record_tuple)
-
- # Construct payload for sync log (SQLite row is not directly usable as dict here)
- sync_payload_chunk = {
- 'media_id': media_id, 'chunk_text': chunk_data['text'], 'chunk_index': i,
- 'start_char': chunk_data.get('start_char'), 'end_char': chunk_data.get('end_char'),
- 'chunk_type': chunk_data.get('chunk_type'), 'creation_date': chunk_creation_time,
- 'last_modified_orig': chunk_creation_time, 'is_processed': False,
- 'metadata': metadata_payload, # Store original dict, not JSON string, in sync log if possible
- 'uuid': chunk_item_uuid, 'last_modified': chunk_creation_time,
- 'version': 1, 'client_id': self.client_id, 'deleted': 0
- }
- self._log_sync_event(conn, 'UnvectorizedMediaChunks', chunk_item_uuid, 'create', 1, sync_payload_chunk)
- num_chunks_saved += 1
- except sqlite3.IntegrityError as e:
- logging.error(f"Integrity error saving chunk {i} (type: {chunk_data.get('chunk_type')}) for media_id {media_id}: {e}. Data: {chunk_data['text'][:50]}...")
- raise DatabaseError(f"Failed to save chunk {i} due to integrity constraint: {e}") from e
- logging.info(f"Saved {num_chunks_saved} unvectorized chunks for media_id {media_id}.")
-
- # Update Media chunking_status
- # If chunks were provided (even an empty list, meaning "clear existing and add these (none)"),
- # then chunking is considered 'completed' from the perspective of this operation.
- # If `chunks` was None (meaning "don't touch existing chunks"), status remains as is or 'pending'.
- final_chunking_status_for_media = 'completed' # if chunks is not None
- # If the main `perform_chunking` flag (from request, not DB field) was false,
- # then perhaps status should be different. For now, if chunks data is passed, it's 'completed'.
- # This might need more nuanced logic based on the `perform_chunking` flag from the original request.
- conn.execute("UPDATE Media SET chunking_status = ? WHERE id = ?",
- (final_chunking_status_for_media, media_id,))
- logging.debug(
- f"Updated Media chunking_status to '{final_chunking_status_for_media}' for media_id {media_id} after chunk processing.")
-
- # Original chunk_options placeholder log
+ _persist_chunks(conn, media_id)
if chunk_options:
- logging.info(f"Chunking logic placeholder (using chunk_options) for media {media_id}")
+ logging.info("chunk_options ignored (placeholder): %s", chunk_options)
+
+ return media_id, media_uuid, f"Media '{title}' added."
+
+        except (InputError, ConflictError, DatabaseError, sqlite3.IntegrityError) as e:
+            # Known failure modes (including IntegrityError raised by validation triggers):
+            # log, let the transaction roll back, and re-raise unchanged.
+            logging.error(f"Transaction failed, rolling back: {type(e).__name__} - {e}")
+            raise
+ except Exception as exc:
+ logging.error(f"Unexpected error in transaction: {type(exc).__name__} - {exc}")
+ raise DatabaseError(f"Unexpected error processing media: {exc}") from exc
- # Determine message based on action (outside transaction)
- if action == "updated":
- message = f"Media '{title}' updated."
- elif action == "added":
- message = f"Media '{title}' added."
- else:
- message = f"Media '{title}' exists, not overwritten."
- return media_id, media_uuid, message
- except (InputError, ConflictError, DatabaseError, sqlite3.Error) as e:
- logger.error(f"Error processing media (URL: {url}): {e}", exc_info=isinstance(e, (DatabaseError, sqlite3.Error)))
- if isinstance(e, (InputError, ConflictError, DatabaseError)):
- raise e
- else:
- raise DatabaseError(f"Failed to process media: {e}") from e
- except Exception as e:
- logger.error(f"Unexpected error processing media (URL: {url}): {e}", exc_info=True)
- raise DatabaseError(f"Unexpected error processing media: {e}") from e
def create_document_version(self, media_id: int, content: str, prompt: Optional[str] = None, analysis_content: Optional[str] = None) -> Dict[str, Any]:
"""
@@ -3855,7 +3792,7 @@ def check_database_integrity(db_path): # Standalone check is fine
db_path (str): The path to the SQLite database file.
Returns:
- bool: True if the integrity check returns 'ok', False otherwise or if
+ bool: True if the integrity check returns 'ok', False otherwise, or if
an error occurs during the check.
"""
logger.info(f"Checking integrity of database: {db_path}")
diff --git a/tldw_chatbook/DB/Prompts_DB.py b/tldw_chatbook/DB/Prompts_DB.py
index d46cf9ab..a91a15e8 100644
--- a/tldw_chatbook/DB/Prompts_DB.py
+++ b/tldw_chatbook/DB/Prompts_DB.py
@@ -1121,7 +1121,8 @@ def soft_delete_keyword(self, keyword_text: str) -> bool:
def get_prompt_by_id(self, prompt_id: int, include_deleted: bool = False) -> Optional[Dict]:
query = "SELECT * FROM Prompts WHERE id = ?"
params = [prompt_id]
- if not include_deleted: query += " AND deleted = 0"
+ if not include_deleted:
+ query += " AND deleted = 0"
try:
cursor = self.execute_query(query, tuple(params))
result = cursor.fetchone()
@@ -1133,7 +1134,8 @@ def get_prompt_by_id(self, prompt_id: int, include_deleted: bool = False) -> Opt
def get_prompt_by_uuid(self, prompt_uuid: str, include_deleted: bool = False) -> Optional[Dict]:
query = "SELECT * FROM Prompts WHERE uuid = ?"
params = [prompt_uuid]
- if not include_deleted: query += " AND deleted = 0"
+ if not include_deleted:
+ query += " AND deleted = 0"
try:
cursor = self.execute_query(query, tuple(params))
result = cursor.fetchone()
@@ -1145,7 +1147,8 @@ def get_prompt_by_uuid(self, prompt_uuid: str, include_deleted: bool = False) ->
def get_prompt_by_name(self, name: str, include_deleted: bool = False) -> Optional[Dict]:
query = "SELECT * FROM Prompts WHERE name = ?"
params = [name]
- if not include_deleted: query += " AND deleted = 0"
+ if not include_deleted:
+ query += " AND deleted = 0"
try:
cursor = self.execute_query(query, tuple(params))
result = cursor.fetchone()
@@ -1231,7 +1234,7 @@ def fetch_keywords_for_prompt(self, prompt_id: int, include_deleted: bool = Fals
def search_prompts(self,
search_query: Optional[str],
- search_fields: Optional[List[str]] = None, # e.g. ['name', 'details', 'keywords']
+ search_fields: Optional[List[str]] = None, # e.g. ['name', 'details', 'keywords']
page: int = 1,
results_per_page: int = 20,
include_deleted: bool = False
@@ -1240,96 +1243,86 @@ def search_prompts(self,
if results_per_page < 1: raise ValueError("Results per page must be >= 1")
if search_query and not search_fields:
- search_fields = ["name", "details", "system_prompt", "user_prompt", "author"] # Default FTS fields
+ search_fields = ["name", "details", "system_prompt", "user_prompt", "author"]
elif not search_fields:
search_fields = []
offset = (page - 1) * results_per_page
- base_select_parts = ["p.id", "p.uuid", "p.name", "p.author", "p.details",
- "p.system_prompt", "p.user_prompt", "p.last_modified", "p.version", "p.deleted"]
- count_select = "COUNT(DISTINCT p.id)"
- base_from = "FROM Prompts p"
- joins = []
+ base_select = "SELECT p.*"
+ count_select = "SELECT COUNT(p.id)"
+ from_clause = "FROM Prompts p"
conditions = []
params = []
if not include_deleted:
conditions.append("p.deleted = 0")
- fts_search_active = False
- if search_query:
- fts_query_parts = []
- if "name" in search_fields: fts_query_parts.append("name")
- if "author" in search_fields: fts_query_parts.append("author")
- if "details" in search_fields: fts_query_parts.append("details")
- if "system_prompt" in search_fields: fts_query_parts.append("system_prompt")
- if "user_prompt" in search_fields: fts_query_parts.append("user_prompt")
-
- # FTS on prompt fields
- if fts_query_parts:
- fts_search_active = True
- if not any("prompts_fts fts_p" in j_item for j_item in joins):
- joins.append("JOIN prompts_fts fts_p ON fts_p.rowid = p.id")
- # Build FTS query: field1:query OR field2:query ...
- # For simple matching, just use the query directly if FTS table covers all these.
- # The FTS table definition needs to match these fields.
- # Assuming prompts_fts has 'name', 'author', 'details', 'system_prompt', 'user_prompt'
- conditions.append("fts_p.prompts_fts MATCH ?")
- params.append(search_query) # User provides FTS syntax or simple terms
-
- # FTS on keywords (if specified in search_fields)
+ # --- Robust FTS search using subqueries ---
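+        # Resolve matching prompt ids via the FTS tables first, then filter the main Prompts
+        # query by id; this avoids duplicate rows and brittle ranking from multi-table JOINs.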
+ if search_query and search_fields:
+ matching_prompt_ids = set()
+ text_search_fields = {"name", "author", "details", "system_prompt", "user_prompt"}
+
+ # Search in prompt text fields
+ if any(field in text_search_fields for field in search_fields):
+ try:
+ cursor = self.execute_query("SELECT rowid FROM prompts_fts WHERE prompts_fts MATCH ?", (search_query,))
+ matching_prompt_ids.update(row['rowid'] for row in cursor.fetchall())
+ except sqlite3.Error as e:
+ logging.error(f"FTS search on prompts failed: {e}", exc_info=True)
+ raise DatabaseError(f"FTS search on prompts failed: {e}") from e
+
+ # Search in keywords
if "keywords" in search_fields:
- fts_search_active = True
- # Join for keywords
- if not any("PromptKeywordLinks pkl" in j_item for j_item in joins):
- joins.append("JOIN PromptKeywordLinks pkl ON p.id = pkl.prompt_id")
- if not any("PromptKeywordsTable pkw" in j_item for j_item in joins):
- joins.append("JOIN PromptKeywordsTable pkw ON pkl.keyword_id = pkw.id AND pkw.deleted = 0")
- if not any("prompt_keywords_fts fts_k" in j_item for j_item in joins):
- joins.append("JOIN prompt_keywords_fts fts_k ON fts_k.rowid = pkw.id")
-
- conditions.append("fts_k.prompt_keywords_fts MATCH ?")
- params.append(search_query) # Match against keywords
-
- order_by_clause_str = "ORDER BY p.last_modified DESC, p.id DESC"
- if fts_search_active:
- # FTS results are naturally sorted by relevance (rank) by SQLite.
- # We can select rank if needed for explicit sorting or display.
- if "fts_p.rank AS relevance_score" not in " ".join(base_select_parts) and "fts_p" in " ".join(joins) :
- base_select_parts.append("fts_p.rank AS relevance_score") # Add if fts_p is used
- elif "fts_k.rank AS relevance_score_kw" not in " ".join(base_select_parts) and "fts_k" in " ".join(joins):
- base_select_parts.append("fts_k.rank AS relevance_score_kw") # Add if fts_k is used
- # A more complex ranking might be needed if both prompt and keyword FTS are active.
- # For now, default sort or rely on SQLite's combined FTS rank if multiple MATCH clauses are used.
- order_by_clause_str = "ORDER BY p.last_modified DESC, p.id DESC" # Fallback, FTS rank is implicit
-
- final_select_stmt = f"SELECT DISTINCT {', '.join(base_select_parts)}"
- join_clause = " ".join(list(dict.fromkeys(joins))) # Unique joins
- where_clause = "WHERE " + " AND ".join(conditions) if conditions else ""
+ try:
+ # 1. Find keyword IDs matching the query
+ kw_cursor = self.execute_query("SELECT rowid FROM prompt_keywords_fts WHERE prompt_keywords_fts MATCH ?", (search_query,))
+ matching_keyword_ids = {row['rowid'] for row in kw_cursor.fetchall()}
+
+ # 2. Find prompt IDs linked to those keywords
+ if matching_keyword_ids:
+ placeholders = ','.join('?' * len(matching_keyword_ids))
+ link_cursor = self.execute_query(
+ f"SELECT DISTINCT prompt_id FROM PromptKeywordLinks WHERE keyword_id IN ({placeholders})",
+ tuple(matching_keyword_ids)
+ )
+ matching_prompt_ids.update(row['prompt_id'] for row in link_cursor.fetchall())
+ except sqlite3.Error as e:
+ logging.error(f"FTS search on keywords failed: {e}", exc_info=True)
+ raise DatabaseError(f"FTS search on keywords failed: {e}") from e
+
+ if not matching_prompt_ids:
+ return [], 0 # No matches found, short-circuit
+
+ # Add the final ID list to the main query conditions
+ id_placeholders = ','.join('?' * len(matching_prompt_ids))
+ conditions.append(f"p.id IN ({id_placeholders})")
+ params.extend(list(matching_prompt_ids))
+
+ # --- Build and Execute Final Query ---
+ where_clause = f"WHERE {' AND '.join(conditions)}" if conditions else ""
+ order_by_clause = "ORDER BY p.last_modified DESC, p.id DESC"
try:
- count_sql = f"SELECT {count_select} {base_from} {join_clause} {where_clause}"
- count_cursor = self.execute_query(count_sql, tuple(params))
- total_matches = count_cursor.fetchone()[0]
+ # Get total count
+ count_sql = f"{count_select} {from_clause} {where_clause}"
+ total_matches = self.execute_query(count_sql, tuple(params)).fetchone()[0]
results_list = []
- if total_matches > 0 and offset < total_matches:
- results_sql = f"{final_select_stmt} {base_from} {join_clause} {where_clause} {order_by_clause_str} LIMIT ? OFFSET ?"
+ if total_matches > 0:
+ # Get paginated results
+ results_sql = f"{base_select} {from_clause} {where_clause} {order_by_clause} LIMIT ? OFFSET ?"
paginated_params = tuple(params + [results_per_page, offset])
results_cursor = self.execute_query(results_sql, paginated_params)
results_list = [dict(row) for row in results_cursor.fetchall()]
- # If keywords need to be attached to each result
+ # Attach keywords to each result
for res_dict in results_list:
res_dict['keywords'] = self.fetch_keywords_for_prompt(res_dict['id'], include_deleted=False)
return results_list, total_matches
- except sqlite3.Error as e:
- if "no such table: prompts_fts" in str(e).lower() or "no such table: prompt_keywords_fts" in str(e).lower():
- logging.error(f"FTS table missing in {self.db_path_str}. Search may fail or be incomplete.")
- # Fallback to LIKE search or raise error
- # For now, let it fail and be caught by generic error.
- logging.error(f"DB error during prompt search in '{self.db_path_str}': {e}", exc_info=True)
+ except (DatabaseError, sqlite3.Error) as e:
+ logging.error(f"DB error during prompt search: {e}", exc_info=True)
raise DatabaseError(f"Failed to search prompts: {e}") from e
# --- Sync Log Access Methods ---
@@ -1355,6 +1348,7 @@ def get_sync_log_entries(self, since_change_id: int = 0, limit: Optional[int] =
logger.error(f"Error fetching sync_log entries: {e}")
raise DatabaseError("Failed to fetch sync_log entries") from e
+
def delete_sync_log_entries(self, change_ids: List[int]) -> int:
if not change_ids: return 0
if not all(isinstance(cid, int) for cid in change_ids):
diff --git a/tldw_chatbook/DB/Sync_Client.py b/tldw_chatbook/DB/Sync_Client.py
index 3ef19b52..d13f0add 100644
--- a/tldw_chatbook/DB/Sync_Client.py
+++ b/tldw_chatbook/DB/Sync_Client.py
@@ -13,10 +13,7 @@
# Third-Party Imports
#
# Local Imports
-try:
- from tldw_cli.tldw_app.DB.Media_DB import Database, ConflictError, DatabaseError, InputError
-except ImportError:
- logger.error("ERROR: Could not import the 'Media_DB' library. Make sure Media_DB.py is accessible.")
+from tldw_chatbook.DB.Client_Media_DB_v2 import MediaDatabase as Database, ConflictError, DatabaseError, InputError
#
#######################################################################################################################
#
diff --git a/tldw_chatbook/Event_Handlers/LLM_Management_Events/llm_management_events_ollama.py b/tldw_chatbook/Event_Handlers/LLM_Management_Events/llm_management_events_ollama.py
index c60cf202..4ff3c099 100644
--- a/tldw_chatbook/Event_Handlers/LLM_Management_Events/llm_management_events_ollama.py
+++ b/tldw_chatbook/Event_Handlers/LLM_Management_Events/llm_management_events_ollama.py
@@ -16,7 +16,7 @@
from pathlib import Path
from typing import TYPE_CHECKING, Optional
-from cv2 import data
+#from cv2 import data
from textual.containers import Container
from textual.css.query import QueryError
from textual.widgets import Input, TextArea, RichLog
@@ -33,7 +33,8 @@
__all__ = [
# ─── Ollama ───────────────────────────────────────────────────────────────
"handle_ollama_nav_button_pressed",
- "handle_ollama_list_models_button_pressed",
+ # FIXME
+ #"handle_ollama_list_models_button_pressed",
"handle_ollama_show_model_button_pressed",
"handle_ollama_delete_model_button_pressed",
"handle_ollama_copy_model_button_pressed",
@@ -106,57 +107,58 @@ async def handle_ollama_nav_button_pressed(app: "TldwCli") -> None:
app.notify("An unexpected error occurred while switching to Ollama view.", severity="error")
-async def handle_ollama_list_models_button_pressed(app: "TldwCli") -> None:
- """Handles the 'List Models' button press for Ollama."""
- logger = getattr(app, "loguru_logger", logging.getLogger(__name__))
- logger.debug("Ollama 'List Models' button pressed.")
- try:
- base_url_input = app.query_one("#ollama-server-url", Input)
- log_output_widget = app.query_one("#ollama-combined-output", RichLog)
-
- base_url = base_url_input.value.strip()
- if not base_url:
- app.notify("Ollama Server URL is required.", severity="error")
- base_url_input.focus()
- return
-
- log_output_widget.clear()
- _update_ollama_combined_output(app, f"Attempting to list models from: {base_url}...")
-
- app.run_worker(
- _worker_ollama_list_models,
- base_url,
- thread=True,
- name=f"ollama_list_models_{time.monotonic()}",
- group="ollama_api",
- description="Listing Ollama local models",
- on_success=partial(_on_list_models_success, app),
- on_error=partial(_on_ollama_worker_error, app, "list_models")
- )
- if logging.error:
- log_output_widget.write(f"Error listing models: {logging.error}")
-
- if logging.error: # This is the original error check, the one above is newly added by script
- log_output_widget.write(f"Error listing models: {logging.error}")
- app.notify("Error listing Ollama models.", severity="error")
- elif data and data.get('models'):
- try:
- # Assuming 'data' is the JSON response, and 'models' is a list within it.
- formatted_models = json.dumps(data['models'], indent=2)
- log_output_widget.write(formatted_models)
- app.notify(f"Successfully listed {len(data['models'])} Ollama models.")
- except (TypeError, KeyError, json.JSONDecodeError) as e:
- log_output_widget.write(f"Error processing model list response: {e}\nRaw data: {data}")
- app.notify("Error processing model list from Ollama.", severity="error")
- else:
- log_output_widget.write("No models found or unexpected response.")
- app.notify("No Ollama models found or unexpected response.", severity="warning")
- except QueryError as e: # pragma: no cover
- logger.error(f"QueryError in handle_ollama_list_models_button_pressed: {e}", exc_info=True)
- app.notify("Error accessing Ollama UI elements for listing models.", severity="error")
- except Exception as e: # pragma: no cover
- logger.error(f"Unexpected error in handle_ollama_list_models_button_pressed: {e}", exc_info=True)
- app.notify("An unexpected error occurred while listing Ollama models.", severity="error")
+# FIXME
+# async def handle_ollama_list_models_button_pressed(app: "TldwCli") -> None:
+# """Handles the 'List Models' button press for Ollama."""
+# logger = getattr(app, "loguru_logger", logging.getLogger(__name__))
+# logger.debug("Ollama 'List Models' button pressed.")
+# try:
+# base_url_input = app.query_one("#ollama-server-url", Input)
+# log_output_widget = app.query_one("#ollama-combined-output", RichLog)
+#
+# base_url = base_url_input.value.strip()
+# if not base_url:
+# app.notify("Ollama Server URL is required.", severity="error")
+# base_url_input.focus()
+# return
+#
+# log_output_widget.clear()
+# _update_ollama_combined_output(app, f"Attempting to list models from: {base_url}...")
+#
+# app.run_worker(
+# _worker_ollama_list_models,
+# base_url,
+# thread=True,
+# name=f"ollama_list_models_{time.monotonic()}",
+# group="ollama_api",
+# description="Listing Ollama local models",
+# on_success=partial(_on_list_models_success, app),
+# on_error=partial(_on_ollama_worker_error, app, "list_models")
+# )
+# if logging.error:
+# log_output_widget.write(f"Error listing models: {logging.error}")
+#
+# if logging.error: # This is the original error check, the one above is newly added by script
+# log_output_widget.write(f"Error listing models: {logging.error}")
+# app.notify("Error listing Ollama models.", severity="error")
+# elif data and data.get('models'):
+# try:
+# # Assuming 'data' is the JSON response, and 'models' is a list within it.
+# formatted_models = json.dumps(data['models'], indent=2)
+# log_output_widget.write(formatted_models)
+# app.notify(f"Successfully listed {len(data['models'])} Ollama models.")
+# except (TypeError, KeyError, json.JSONDecodeError) as e:
+# log_output_widget.write(f"Error processing model list response: {e}\nRaw data: {data}")
+# app.notify("Error processing model list from Ollama.", severity="error")
+# else:
+# log_output_widget.write("No models found or unexpected response.")
+# app.notify("No Ollama models found or unexpected response.", severity="warning")
+# except QueryError as e: # pragma: no cover
+# logger.error(f"QueryError in handle_ollama_list_models_button_pressed: {e}", exc_info=True)
+# app.notify("Error accessing Ollama UI elements for listing models.", severity="error")
+# except Exception as e: # pragma: no cover
+# logger.error(f"Unexpected error in handle_ollama_list_models_button_pressed: {e}", exc_info=True)
+# app.notify("An unexpected error occurred while listing Ollama models.", severity="error")
async def handle_ollama_show_model_button_pressed(app: "TldwCli") -> None:
diff --git a/tldw_chatbook/Event_Handlers/ingest_events.py b/tldw_chatbook/Event_Handlers/ingest_events.py
index eca057aa..b05d3375 100644
--- a/tldw_chatbook/Event_Handlers/ingest_events.py
+++ b/tldw_chatbook/Event_Handlers/ingest_events.py
@@ -3,6 +3,7 @@
#
# Imports
import json
+from os import getenv
from pathlib import Path
from typing import TYPE_CHECKING, Optional, List, Any, Dict, Callable, Union
#
@@ -12,6 +13,7 @@
ListView, Collapsible, LoadingIndicator, Button
from textual.css.query import QueryError
from textual.containers import Container, VerticalScroll
+from textual.worker import Worker
from ..Constants import ALL_TLDW_API_OPTION_CONTAINERS
#
@@ -20,6 +22,7 @@
import tldw_chatbook.Event_Handlers.conv_char_events as ccp_handlers
from .Chat_Events import chat_events as chat_handlers
from tldw_chatbook.Event_Handlers.Chat_Events.chat_events import populate_chat_conversation_character_filter_select
+from ..config import get_cli_setting
from ..tldw_api import (
TLDWAPIClient, ProcessVideoRequest, ProcessAudioRequest,
APIConnectionError, APIRequestError, APIResponseError, AuthenticationError,
@@ -1121,6 +1124,12 @@ async def handle_tldw_api_submit_button_pressed(app: 'TldwCli', event: Button.Pr
submit_button.disabled = True
# app.notify is already called at the start of the function
+ def _reset_ui():
+ """Return the widgets to their idle state after a hard failure."""
+ loading_indicator.display = False
+ submit_button.disabled = False
+ status_area.load_text("Submission halted.")
+
# --- Get Auth Token (after basic validations pass) ---
auth_token: Optional[str] = None
try:
@@ -1135,14 +1144,23 @@ async def handle_tldw_api_submit_button_pressed(app: 'TldwCli', event: Button.Pr
submit_button.disabled = False
status_area.load_text("Custom token required. Submission halted.")
return
+
elif auth_method == "config_token":
- auth_token = app.app_config.get("tldw_api", {}).get("auth_token_config")
+ # 1. Look in the active config, then in the environment.
+ auth_token = (
+ get_cli_setting("tldw_api", "auth_token") # ~/.config/tldw_cli/config.toml
+ or getenv("TDLW_AUTH_TOKEN") # optional override
+ )
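+ # Config sketch (illustrative) for ~/.config/tldw_cli/config.toml:
+ #   [tldw_api]
+ #   auth_token = "<your-token>"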
+
+ # 2. Abort early if we still have nothing.
if not auth_token:
- app.notify("Auth Token not found in tldw_api.auth_token_config. Please configure or use custom.", severity="error")
- # Revert UI loading state
- loading_indicator.display = False
- submit_button.disabled = False
- status_area.load_text("Config token missing. Submission halted.")
+ msg = (
+ "Auth token not found — add it to the [tldw_api] section as "
+ "`auth_token = \"\"` or export TDLW_AUTH_TOKEN."
+ )
+ logger.error(msg)
+ app.notify(msg, severity="error")
+ _reset_ui()
return
except QueryError as e:
logger.error(f"UI component not found for TLDW API auth token for {selected_media_type}: {e}")
@@ -1486,12 +1504,18 @@ def on_worker_failure(error: Exception):
app.notify(brief_notify_message, severity="error", timeout=8)
+ # Store the request context so the worker success handler can ingest results later.
+ app._last_tldw_api_request_context = {
+ "request_model": request_model,
+ "overwrite_db": overwrite_db,
+ }
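+ # exit_on_error=False keeps the app alive if the worker raises; the failure is
+ # delivered to on_worker_state_changed as WorkerState.ERROR and handled there.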
app.run_worker(
process_media_worker,
name=f"tldw_api_processing_{selected_media_type}", # Unique worker name per tab
group="api_calls",
- description=f"Processing {selected_media_type} media via TLDW API"
+ description=f"Processing {selected_media_type} media via TLDW API",
+ exit_on_error=False
)
@@ -1815,6 +1839,244 @@ def on_import_failure_notes(error: Exception):
description="Importing selected note files."
)
+async def handle_tldw_api_worker_failure(app: 'TldwCli', event: 'Worker.StateChanged'):
+ """Handles the failure of a TLDW API worker and updates the UI."""
+ worker_name = event.worker.name or ""
+ media_type = worker_name.replace("tldw_api_processing_", "")
+ error = event.worker.error
+
+ logger.error(f"TLDW API request worker failed for {media_type}: {error}", exc_info=True)
+
+ try:
+ loading_indicator = app.query_one(f"#tldw-api-loading-indicator-{media_type}", LoadingIndicator)
+ submit_button = app.query_one(f"#tldw-api-submit-{media_type}", Button)
+ status_area = app.query_one(f"#tldw-api-status-area-{media_type}", TextArea)
+
+ loading_indicator.display = False
+ submit_button.disabled = False
+ except QueryError as e_ui:
+ logger.error(f"UI component not found in on_worker_failure for {media_type}: {e_ui}")
+ return
+
+ error_message_parts = [f"## API Request Failed! ({media_type.title()})\n\n"]
+ brief_notify_message = f"{media_type.title()} API Request Failed."
+
+ # This mirrors the error formatting from the original inline on_worker_failure callback.
+ if isinstance(error, APIConnectionError):
+ error_type = "Connection Error"
+ error_message_parts.append(f"**Type:** {error_type}\n")
+ error_message_parts.append(f"**Message:** `{str(error)}`\n")
+ brief_notify_message = f"Connection Error: {str(error)[:100]}"
+ elif isinstance(error, APIResponseError):
+ error_type = "API Error"
+ error_message_parts.append(f"**Type:** API Error\n**Status Code:** {error.status_code}\n**Message:** `{str(error)}`\n")
+ brief_notify_message = f"API Error {error.status_code}: {str(error)[:100]}"
+ # ... extend with other specific error types (e.g. APIRequestError, AuthenticationError) as needed ...
+ else:
+ error_type = "General Error"
+ error_message_parts.append(f"**Type:** {type(error).__name__}\n")
+ error_message_parts.append(f"**Message:** `{str(error)}`\n")
+ brief_notify_message = f"Processing failed: {str(error)[:100]}"
+
+ status_area.clear()
+ status_area.load_text("".join(error_message_parts))
+ status_area.display = True
+ app.notify(brief_notify_message, severity="error", timeout=8)
+
+
+async def handle_tldw_api_worker_success(app: 'TldwCli', event: 'Worker.StateChanged'):
+ """Handles the success of a TLDW API worker and ingests the results."""
+ worker_name = event.worker.name or ""
+ media_type = worker_name.replace("tldw_api_processing_", "")
+ response_data = event.worker.result
+
+ logger.info(f"TLDW API worker for {media_type} succeeded. Processing results.")
+
+ try:
+ # Reset UI state (disable loading, enable button)
+ app.query_one(f"#tldw-api-loading-indicator-{media_type}", LoadingIndicator).display = False
+ app.query_one(f"#tldw-api-submit-{media_type}", Button).disabled = False
+ status_area = app.query_one(f"#tldw-api-status-area-{media_type}", TextArea)
+ status_area.clear()
+
+ except QueryError as e_ui:
+ logger.error(f"UI component not found in on_worker_success for {media_type}: {e_ui}")
+ return
+ except Exception as e:
+ logger.error(f"Error resetting UI state in worker success handler: {e}", exc_info=True)
+ return
+
+ # --- Pre-flight Checks and Context Retrieval ---
+ if not app.media_db:
+ logger.error("Media_DB_v2 not initialized. Cannot ingest API results.")
+ app.notify("Error: Local media database not available.", severity="error")
+ status_area.load_text("## Error\n\nLocal media database is not available. Cannot save results.")
+ return
+
+ # Retrieve the context we saved before starting the worker
+ request_context = getattr(app, "_last_tldw_api_request_context", {})
+ request_model = request_context.get("request_model")
+ overwrite_db = request_context.get("overwrite_db", False)
+
+ if not request_model:
+ logger.error("Could not retrieve request_model from app context. Cannot properly ingest results.")
+ status_area.load_text("## Internal Error\n\nCould not retrieve original request context. Ingestion aborted.")
+ return
+
+ # --- Data Processing and Ingestion ---
+ processed_count = 0
+ error_count = 0
+ successful_ingestions_details = []
+ results_to_ingest: List[MediaItemProcessResult] = []
+
+ # Normalize different response types into a single list of MediaItemProcessResult
+ if isinstance(response_data, BatchMediaProcessResponse):
+ results_to_ingest = response_data.results
+ elif isinstance(response_data, list) and all(isinstance(item, ProcessedMediaWikiPage) for item in response_data):
+ for mw_page in response_data:
+ if mw_page.status == "Error":
+ error_count += 1
+ logger.error(f"MediaWiki page '{mw_page.title}' processing error: {mw_page.error_message}")
+ continue
+ # Adapt ProcessedMediaWikiPage to the common result structure
+ results_to_ingest.append(MediaItemProcessResult(
+ status="Success",
+ input_ref=mw_page.input_ref or mw_page.title,
+ processing_source=mw_page.title,
+ media_type="mediawiki_page",
+ metadata={"title": mw_page.title, "page_id": mw_page.page_id, "namespace": mw_page.namespace},
+ content=mw_page.content,
+ chunks=[{"text": chunk.get("text", ""), "metadata": chunk.get("metadata", {})} for chunk in mw_page.chunks] if mw_page.chunks else None,
+ ))
+ elif isinstance(response_data, BatchProcessXMLResponse):
+ for xml_item in response_data.results:
+ if xml_item.status == "Error":
+ error_count += 1
+ continue
+ results_to_ingest.append(MediaItemProcessResult(
+ status="Success", input_ref=xml_item.input_ref, media_type="xml",
+ metadata={"title": xml_item.title, "author": xml_item.author, "keywords": xml_item.keywords},
+ content=xml_item.content, analysis=xml_item.summary,
+ ))
+ else:
+ logger.error(f"Unexpected TLDW API response data type for {media_type}: {type(response_data)}.")
+ status_area.load_text(f"## API Request Processed\n\nUnexpected response format. Raw response logged.")
+ app.notify("Error: Received unexpected data format from API.", severity="error")
+ return
+
+ # --- Ingestion Loop ---
+ for item_result in results_to_ingest:
+ if item_result.status == "Success":
+ try:
+ # Prepare chunks for database insertion
+ unvectorized_chunks_to_save = []
+ if item_result.chunks:
+ for chunk_item in item_result.chunks:
+ if isinstance(chunk_item, dict) and "text" in chunk_item:
+ unvectorized_chunks_to_save.append({
+ "text": chunk_item.get("text"), "metadata": chunk_item.get("metadata", {})
+ })
+ elif isinstance(chunk_item, str):
+ unvectorized_chunks_to_save.append({"text": chunk_item, "metadata": {}})
+
+ # Call the DB function with data from both the API response and original request
+ media_id, _, msg = app.media_db.add_media_with_keywords(
+ url=item_result.input_ref,
+ title=item_result.metadata.get("title", item_result.input_ref),
+ media_type=item_result.media_type,
+ content=item_result.content or item_result.transcript,
+ keywords=item_result.metadata.get("keywords", []) or request_model.keywords,
+ prompt=request_model.custom_prompt,
+ analysis_content=item_result.analysis or item_result.summary,
+ author=item_result.metadata.get("author") or request_model.author,
+ overwrite=overwrite_db,
+ chunks=unvectorized_chunks_to_save
+ )
+
+ if media_id:
+ logger.info(f"Successfully ingested '{item_result.input_ref}' into local DB. Media ID: {media_id}. Msg: {msg}")
+ processed_count += 1
+ successful_ingestions_details.append({
+ "input_ref": item_result.input_ref,
+ "title": item_result.metadata.get("title", "N/A"),
+ "media_type": item_result.media_type,
+ "db_id": media_id
+ })
+ else:
+ logger.error(f"Failed to ingest '{item_result.input_ref}' into local DB. Message: {msg}")
+ error_count += 1
+
+ except Exception as e_ingest:
+ logger.error(f"Error ingesting item '{item_result.input_ref}' into local DB: {e_ingest}", exc_info=True)
+ error_count += 1
+ else:
+ logger.error(f"API processing error for '{item_result.input_ref}': {item_result.error}")
+ error_count += 1
+
+ # --- Build and Display Summary ---
+ summary_parts = [f"## TLDW API Request Successful ({media_type.title()})\n\n"]
+ if not results_to_ingest and error_count == 0:
+ summary_parts.append("API request successful, but no items were provided or found for processing.\n")
+ else:
+ summary_parts.append(f"- **Successfully Processed & Ingested:** {processed_count}\n")
+ summary_parts.append(f"- **Errors (API or DB):** {error_count}\n\n")
+
+ if error_count > 0:
+ summary_parts.append("**Please check the application logs for details on any errors.**\n\n")
+
+ if successful_ingestions_details:
+ summary_parts.append("### Ingested Items:\n")
+ for detail in successful_ingestions_details[:10]: # Show max 10 details
+ title_str = f" (Title: `{detail['title']}`)" if detail['title'] != 'N/A' else ""
+ summary_parts.append(f"- **Input:** `{detail['input_ref']}`{title_str}\n")
+ summary_parts.append(f" - **Type:** {detail['media_type']}, **DB ID:** {detail['db_id']}\n")
+ if len(successful_ingestions_details) > 10:
+ summary_parts.append(f"\n...and {len(successful_ingestions_details) - 10} more items.")
+
+ status_area.load_text("".join(summary_parts))
+ status_area.display = True
+ status_area.scroll_home(animate=False)
+
+ notify_msg = f"{media_type.title()} Ingestion: {processed_count} done, {error_count} errors."
+ app.notify(notify_msg, severity="information" if error_count == 0 else "warning", timeout=7)
# --- Button Handler Map ---
INGEST_BUTTON_HANDLERS = {
diff --git a/tldw_chatbook/UI/Ingest_Window.py b/tldw_chatbook/UI/Ingest_Window.py
index 1cb0a68a..b17215df 100644
--- a/tldw_chatbook/UI/Ingest_Window.py
+++ b/tldw_chatbook/UI/Ingest_Window.py
@@ -57,7 +57,7 @@ def compose_tldw_api_form(self, media_type: str) -> ComposeResult:
if not analysis_provider_options:
analysis_provider_options = [("No Providers Configured", Select.BLANK)]
- with VerticalScroll(classes="ingest-form-scrollable"): # TODO: Consider if this scrollable itself needs a unique ID if we have nested ones. For now, assuming not.
+ with VerticalScroll(classes="ingest-form-scrollable"): # FIXME/TODO: needs a unique header ID since this is the template for whichever media type is selected
yield Static("TLDW API Configuration", classes="sidebar-title")
yield Label("API Endpoint URL:")
yield Input(default_api_url, id=f"tldw-api-endpoint-url-{media_type}", placeholder="http://localhost:8000")
diff --git a/tldw_chatbook/app.py b/tldw_chatbook/app.py
index 47dd7574..b7f87026 100644
--- a/tldw_chatbook/app.py
+++ b/tldw_chatbook/app.py
@@ -334,6 +334,9 @@ class TldwCli(App[None]): # Specify return type for run() if needed, None is co
current_chat_note_id: Optional[str] = None
current_chat_note_version: Optional[int] = None
+ # Shared state for tldw API requests
+ _last_tldw_api_request_context: Dict[str, Any] = {}
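+ # Populated by handle_tldw_api_submit_button_pressed just before the worker
+ # starts and read back in handle_tldw_api_worker_success to ingest results.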
+
def __init__(self):
super().__init__()
self.MediaDatabase = MediaDatabase
@@ -2001,6 +2004,14 @@ async def on_button_pressed(self, event: Button.Pressed) -> None:
self.loguru_logger.debug(f"Button pressed: ID='{button_id}' on Tab='{self.current_tab}'")
+ if button_id.startswith("tldw-api-browse-local-files-button-"):
+ try:
+ ingest_window = self.query_one(IngestWindow)
+ await ingest_window.on_button_pressed(event)
+ return # Event handled, stop further processing
+ except QueryError:
+ self.loguru_logger.error("Could not find IngestWindow to delegate browse button press.")
+
# 1. Handle global tab switching first
if button_id.startswith("tab-"):
await tab_events.handle_tab_button_pressed(self, event)
@@ -2070,19 +2081,16 @@ def _update_mlx_log(self, message: str) -> None:
async def on_input_changed(self, event: Input.Changed) -> None:
input_id = event.input.id
current_active_tab = self.current_tab
- # --- Chat Sidebar Prompt Search ---
- if input_id == "chat-prompt-search-input" and current_active_tab == TAB_CHAT:
- await chat_handlers.handle_chat_sidebar_prompt_search_input_changed(self, event.value)
# --- Notes Search ---
- elif input_id == "notes-search-input" and current_active_tab == TAB_NOTES:
+ if input_id == "notes-search-input" and current_active_tab == TAB_NOTES: # Changed from elif to if
await notes_handlers.handle_notes_search_input_changed(self, event.value)
# --- Chat Sidebar Conversation Search ---
elif input_id == "chat-conversation-search-bar" and current_active_tab == TAB_CHAT:
await chat_handlers.handle_chat_conversation_search_bar_changed(self, event.value)
elif input_id == "conv-char-search-input" and current_active_tab == TAB_CCP:
- await ccp_handlers.handle_ccp_conversation_search_input_changed(self, event.value)
+ await ccp_handlers.handle_ccp_conversation_search_input_changed(self, event)
elif input_id == "ccp-prompt-search-input" and current_active_tab == TAB_CCP:
- await ccp_handlers.handle_ccp_prompt_search_input_changed(self, event.value)
+ await ccp_handlers.handle_ccp_prompt_search_input_changed(self, event)
elif input_id == "chat-prompt-search-input" and current_active_tab == TAB_CHAT: # New condition
if self._chat_sidebar_prompt_search_timer: # Use the new timer variable
self._chat_sidebar_prompt_search_timer.stop()
@@ -2268,6 +2276,15 @@ async def on_worker_state_changed(self, event: Worker.StateChanged) -> None:
else:
self.loguru_logger.debug(f"Chat-related worker '{worker_name_attr}' in other state: {worker_state}")
+ #######################################################################
+ # --- Handle tldw server API Calls Worker (tldw API Ingestion) ---
+ #######################################################################
+ elif worker_group == "api_calls":
+ self.loguru_logger.info(f"TLDW API worker '{event.worker.name}' finished with state {event.state}.")
+ if worker_state == WorkerState.SUCCESS:
+ await ingest_events.handle_tldw_api_worker_success(self, event)
+ elif worker_state == WorkerState.ERROR:
+ await ingest_events.handle_tldw_api_worker_failure(self, event)
#######################################################################
# --- Handle Ollama API Worker ---
diff --git a/tldw_chatbook/config.py b/tldw_chatbook/config.py
index 480e6717..8826e5f1 100644
--- a/tldw_chatbook/config.py
+++ b/tldw_chatbook/config.py
@@ -867,7 +867,7 @@ def get_api_key(toml_key: str, env_var: str, section: Dict = api_section_legacy)
[tldw_api]
base_url = "http://127.0.0.1:8000" # Or your actual default remote endpoint
# Default auth token can be stored here, or leave empty if user must always provide
-# auth_token = "your_secret_token_if_you_have_a_default"
+auth_token = "default-secret-key-for-single-user"
[logging]
# Log file will be placed in the same directory as the chachanotes_db_path below.
@@ -894,6 +894,11 @@ def get_api_key(toml_key: str, env_var: str, section: Dict = api_section_legacy)
vLLM = "http://localhost:8000" # Check if your API provider uses this address
Custom = "http://localhost:1234/v1"
Custom_2 = "http://localhost:5678/v1"
+Custom_3 = "http://localhost:5678/v1"
+Custom_4 = "http://localhost:5678/v1"
+Custom_5 = "http://localhost:5678/v1"
+Custom_6 = "http://localhost:5678/v1"
+
# Add other local URLs if needed
[providers]
@@ -1397,6 +1402,84 @@ def load_cli_config_and_ensure_existence(force_reload: bool = False) -> Dict[str
return _CONFIG_CACHE
+def save_setting_to_cli_config(section: str, key: str, value: Any) -> bool:
+ """
+ Saves a specific setting to the user's CLI TOML configuration file.
+
+ This function reads the current config, updates a specific key within a
+ section (handling nested sections like 'api_settings.openai'), and writes
+ the entire configuration back to the file. It then forces a reload of the
+ config cache.
+
+ Args:
+ section: The name of the TOML section (e.g., "general", "api_settings.openai").
+ key: The key within the section to update.
+ value: The new value for the key.
+
+ Returns:
+ True if the setting was saved successfully, False otherwise.
+ """
+ global _CONFIG_CACHE, settings
+ logger.info(f"Attempting to save setting: [{section}].{key} = {repr(value)}")
+
+ # Ensure the parent directory for the config file exists.
+ try:
+ DEFAULT_CONFIG_PATH.parent.mkdir(parents=True, exist_ok=True)
+ except OSError as e:
+ logger.error(f"Could not create config directory {DEFAULT_CONFIG_PATH.parent}: {e}")
+ return False
+
+ # Step 1: Read the current configuration from the user's file.
+ # If the file doesn't exist, we start with an empty dictionary.
+ config_data: Dict[str, Any] = {}
+ if DEFAULT_CONFIG_PATH.exists():
+ try:
+ with open(DEFAULT_CONFIG_PATH, "rb") as f:
+ config_data = tomllib.load(f)
+ except tomllib.TOMLDecodeError as e:
+ logger.error(f"Corrupted config file at {DEFAULT_CONFIG_PATH}. Cannot save. Please fix or delete it. Error: {e}")
+ # Consider creating a backup of the corrupt file for the user.
+ return False
+ except Exception as e:
+ logger.error(f"Unexpected error reading {DEFAULT_CONFIG_PATH}: {e}", exc_info=True)
+ return False
+
+ # Step 2: Modify the configuration data in memory.
+ # This handles nested sections by splitting the section string.
+ keys = section.split('.')
+ current_level = config_data
+
+ try:
+ for part in keys:
+ # Traverse or create the nested dictionary structure.
+ current_level = current_level.setdefault(part, {})
+ # Assign the new value to the key in the target section.
+ current_level[key] = value
+ except (TypeError, AttributeError):
+ # This error occurs if a key in the path (e.g., 'api_settings') is a value, not a table.
+ logger.error(
+ f"Configuration structure conflict. Could not set '{key}' in section '{section}' "
+ f"because a part of the path is not a table/dictionary. Please check your config file."
+ )
+ return False
+
+ # Step 3: Write the updated configuration back to the TOML file.
+ try:
+ with open(DEFAULT_CONFIG_PATH, "w", encoding="utf-8") as f:
+ toml.dump(config_data, f)
+ logger.success(f"Successfully saved setting to {DEFAULT_CONFIG_PATH}")
+
+ # Step 4: Invalidate and reload global config caches to reflect changes immediately.
+ load_cli_config_and_ensure_existence(force_reload=True)
+ settings = load_settings()
+ logger.info("Global configuration caches reloaded.")
+
+ return True
+ except (IOError, toml.TomlDecodeError) as e:
+ logger.error(f"Failed to write updated config to {DEFAULT_CONFIG_PATH}: {e}", exc_info=True)
+ return False
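+
+ # Example usage (illustrative sketch; the key name is hypothetical):
+ #   save_setting_to_cli_config("api_settings.openai", "temperature", 0.7)
+ #   creates [api_settings.openai] if needed, writes temperature = 0.7, and
+ #   reloads the cached settings so the new value is visible immediately.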
+
+
# --- CLI Setting Getter ---
def get_cli_setting(section: str, key: str, default: Any = None) -> Any:
"""Helper to get a specific setting from the loaded CLI configuration."""
diff --git a/tldw_chatbook/tldw_api/client.py b/tldw_chatbook/tldw_api/client.py
index f966fa97..434ce28a 100644
--- a/tldw_chatbook/tldw_api/client.py
+++ b/tldw_chatbook/tldw_api/client.py
@@ -28,14 +28,19 @@ class TLDWAPIClient:
def __init__(self, base_url: str, token: Optional[str] = None, timeout: float = 300.0):
self.base_url = base_url.rstrip('/')
self.token = token
+ self.bearer_token = None
self.timeout = timeout
self._client: Optional[httpx.AsyncClient] = None
async def _get_client(self) -> httpx.AsyncClient:
if self._client is None or self._client.is_closed:
headers = {}
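+ # Auth scheme selection: a bearer_token is sent as "Authorization: Bearer ...",
+ # while a plain API token is sent via the "X-API-KEY" header instead.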
+ if self.bearer_token:
+ # Bearer Auth
+ headers["Authorization"] = f"Bearer {self.bearer_token}"
if self.token:
- headers["Authorization"] = f"Bearer {self.token}"
+ # Token Auth
+ headers["X-API-KEY"] = self.token
self._client = httpx.AsyncClient(
base_url=self.base_url,
headers=headers,
@@ -130,31 +135,31 @@ async def _stream_request(
async def process_video(self, request_data: ProcessVideoRequest, file_paths: Optional[List[str]] = None) -> BatchMediaProcessResponse:
form_data = model_to_form_data(request_data)
httpx_files = prepare_files_for_httpx(file_paths, upload_field_name="files")
- response_dict = await self._request("POST", "/api/v1/process-videos", data=form_data, files=httpx_files)
+ response_dict = await self._request("POST", "/api/v1/media/process-videos", data=form_data, files=httpx_files)
return BatchMediaProcessResponse(**response_dict)
async def process_audio(self, request_data: ProcessAudioRequest, file_paths: Optional[List[str]] = None) -> BatchMediaProcessResponse:
form_data = model_to_form_data(request_data)
httpx_files = prepare_files_for_httpx(file_paths, upload_field_name="files")
- response_dict = await self._request("POST", "/api/v1/process-audios", data=form_data, files=httpx_files)
+ response_dict = await self._request("POST", "/api/v1/media/process-audios", data=form_data, files=httpx_files)
return BatchMediaProcessResponse(**response_dict)
async def process_pdf(self, request_data: ProcessPDFRequest, file_paths: Optional[List[str]] = None) -> BatchMediaProcessResponse:
form_data = model_to_form_data(request_data)
httpx_files = prepare_files_for_httpx(file_paths, upload_field_name="files")
- response_dict = await self._request("POST", "/api/v1/process-pdfs", data=form_data, files=httpx_files)
+ response_dict = await self._request("POST", "/api/v1/media/process-pdfs", data=form_data, files=httpx_files)
return BatchMediaProcessResponse(**response_dict)
async def process_ebook(self, request_data: ProcessEbookRequest, file_paths: Optional[List[str]] = None) -> BatchMediaProcessResponse:
form_data = model_to_form_data(request_data)
httpx_files = prepare_files_for_httpx(file_paths, upload_field_name="files")
- response_dict = await self._request("POST", "/api/v1/process-ebooks", data=form_data, files=httpx_files)
+ response_dict = await self._request("POST", "/api/v1/media/process-ebooks", data=form_data, files=httpx_files)
return BatchMediaProcessResponse(**response_dict)
async def process_document(self, request_data: ProcessDocumentRequest, file_paths: Optional[List[str]] = None) -> BatchMediaProcessResponse:
form_data = model_to_form_data(request_data)
httpx_files = prepare_files_for_httpx(file_paths, upload_field_name="files")
- response_dict = await self._request("POST", "/api/v1/process-documents", data=form_data, files=httpx_files)
+ response_dict = await self._request("POST", "/api/v1/media/process-documents", data=form_data, files=httpx_files)
return BatchMediaProcessResponse(**response_dict)
async def process_xml(self, request_data: ProcessXMLRequest, file_path: str) -> BatchProcessXMLResponse: # XML expects single file