diff --git a/.github/workflows/python-app.yml b/.github/workflows/python-app.yml index f65d788c..d1be84b0 100644 --- a/.github/workflows/python-app.yml +++ b/.github/workflows/python-app.yml @@ -26,7 +26,7 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - pip install pytest + pip install pytest pytest_asyncio hypothesis if [ -f requirements.txt ]; then pip install -r requirements.txt; fi - name: Test with pytest run: | diff --git a/README.md b/README.md index b6f349c6..06f20ca2 100644 --- a/README.md +++ b/README.md @@ -26,6 +26,7 @@ Current status: Working/In-Progress - Launch with `python3 -m tldw_chatbook.app` (working on pip packaging) - Inspiration from [Elia Chat](https://github.com/darrenburns/elia) +#### Chat Features
Chat Features @@ -57,6 +58,7 @@ Current status: Working/In-Progress - Support for searching/loading/editing/saving Notes from the Notes Database.
+#### Notes & Media Features
Notes & Media Features @@ -75,7 +77,7 @@ Current status: Working/In-Progress -
- +#### Local LLM Management Features
Local LLM Inference diff --git a/Tests/DB/__init__.py b/STests/Sync/__init__.py similarity index 100% rename from Tests/DB/__init__.py rename to STests/Sync/__init__.py diff --git a/Tests/Sync/test_sync_client-1.py b/STests/Sync/test_sync_client-1.py similarity index 100% rename from Tests/Sync/test_sync_client-1.py rename to STests/Sync/test_sync_client-1.py diff --git a/Tests/Sync/test_sync_client-2.py b/STests/Sync/test_sync_client-2.py similarity index 99% rename from Tests/Sync/test_sync_client-2.py rename to STests/Sync/test_sync_client-2.py index 96255185..94aed185 100644 --- a/Tests/Sync/test_sync_client-2.py +++ b/STests/Sync/test_sync_client-2.py @@ -14,11 +14,11 @@ import requests -from tldw_cli.tldw_app.DB.Media_DB import Database, DatabaseError +from tldw_chatbook.DB.Client_Media_DB_v2 import MediaDatabase as Database, DatabaseError sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) -from tldw_cli.tldw_app.DB.Sync_Client import ClientSyncEngine, SYNC_BATCH_SIZE # Your client sync engine +from tldw_chatbook.DB.Sync_Client import ClientSyncEngine, SYNC_BATCH_SIZE # Your client sync engine diff --git a/Tests/MediaDB2/test_sync_client.py b/STests/Sync/test_sync_client.py similarity index 100% rename from Tests/MediaDB2/test_sync_client.py rename to STests/Sync/test_sync_client.py diff --git a/Tests/MediaDB2/test_sync_server.py b/STests/Sync/test_sync_server.py similarity index 100% rename from Tests/MediaDB2/test_sync_server.py rename to STests/Sync/test_sync_server.py diff --git a/Tests/MediaDB2/Sync/__init__.py b/STests/__init__.py similarity index 100% rename from Tests/MediaDB2/Sync/__init__.py rename to STests/__init__.py diff --git a/Tests/ChaChaNotesDB/test_chachanotes_db.py b/Tests/ChaChaNotesDB/test_chachanotes_db.py index 9a8f69a0..7167f7b3 100644 --- a/Tests/ChaChaNotesDB/test_chachanotes_db.py +++ b/Tests/ChaChaNotesDB/test_chachanotes_db.py @@ -1,4 +1,4 @@ -# test_chacha_notes_db.py +# test_chachanotes_db.py # # # Imports @@ -8,27 +8,34 @@ import uuid from datetime import datetime, timezone from pathlib import Path -import os # For :memory: check + # # Third-Party Imports +from hypothesis import strategies as st +from hypothesis import given, settings, HealthCheck # # Local Imports -from tldw_Server_API.app.core.DB_Management.ChaChaNotes_DB import ( +# --- UPDATED IMPORT PATH --- +from tldw_chatbook.DB.ChaChaNotes_DB import ( CharactersRAGDB, CharactersRAGDBError, SchemaError, InputError, ConflictError ) + + # ####################################################################################################################### # # Functions: + # --- Fixtures --- @pytest.fixture def client_id(): + """Provides a consistent client ID for tests.""" return "test_client_001" @@ -39,11 +46,11 @@ def db_path(tmp_path): @pytest.fixture(scope="function") -def db_instance(db_path, client_id): # Add db_path and tmp_path back +def db_instance(db_path, client_id): """Creates a DB instance for each test, ensuring a fresh database.""" current_db_path = Path(db_path) - # Clean up any existing files + # Clean up any existing files from previous runs to be safe for suffix in ["", "-wal", "-shm"]: p = Path(str(current_db_path) + suffix) if p.exists(): @@ -54,12 +61,12 @@ def db_instance(db_path, client_id): # Add db_path and tmp_path back db = None try: - db = CharactersRAGDB(current_db_path, client_id) # Use current_db_path + db = CharactersRAGDB(current_db_path, client_id) yield db finally: if db: db.close_connection() - # Additional cleanup + # 
Additional cleanup after test completes for suffix in ["", "-wal", "-shm"]: p = Path(str(current_db_path) + suffix) if p.exists(): @@ -71,18 +78,39 @@ def db_instance(db_path, client_id): # Add db_path and tmp_path back @pytest.fixture def mem_db_instance(client_id): - """Creates an in-memory DB instance.""" + """Creates an in-memory DB instance for tests that don't need file persistence.""" db = CharactersRAGDB(":memory:", client_id) yield db db.close_connection() +@pytest.fixture +def sample_card(db_instance: CharactersRAGDB) -> dict: + """A fixture that adds a sample card to the DB and returns its data.""" + card_data = _create_sample_card_data("FromFixture") + card_id = db_instance.add_character_card(card_data) + # Return the full record from the DB, which includes ID, version, etc. + return db_instance.get_character_card_by_id(card_id) + +# You can create similar fixtures for conversations, messages, etc. +@pytest.fixture +def sample_conv(db_instance: CharactersRAGDB, sample_card: dict) -> dict: + """Adds a sample conversation linked to the sample_card.""" + conv_data = { + "character_id": sample_card['id'], + "title": "Conversation From Fixture" + } + conv_id = db_instance.add_conversation(conv_data) + return db_instance.get_conversation_by_id(conv_id) + # --- Helper Functions --- def get_current_utc_timestamp_iso(): + """Returns the current UTC time in ISO 8601 format, as used by the DB.""" return datetime.now(timezone.utc).isoformat(timespec='milliseconds').replace('+00:00', 'Z') def _create_sample_card_data(name_suffix="", client_id_override=None): + """Creates a sample character card data dictionary.""" return { "name": f"Test Character {name_suffix}", "description": "A test character.", @@ -90,18 +118,18 @@ def _create_sample_card_data(name_suffix="", client_id_override=None): "scenario": "A test scenario.", "image": b"testimagebytes", "first_message": "Hello, test!", - "alternate_greetings": json.dumps(["Hi", "Hey"]), # Ensure JSON strings for direct use + "alternate_greetings": json.dumps(["Hi", "Hey"]), "tags": json.dumps(["test", "sample"]), "extensions": json.dumps({"custom_field": "value"}), - "client_id": client_id_override # For testing specific client_id scenarios + "client_id": client_id_override } # --- Test Cases --- class TestDBInitialization: - def test_db_creation(self, db_path, client_id): - current_db_path = Path(db_path) # Ensure it's a Path object + def test_db_creation_and_schema_version(self, db_path, client_id): + current_db_path = Path(db_path) assert not current_db_path.exists() db = CharactersRAGDB(current_db_path, client_id) assert current_db_path.exists() @@ -109,33 +137,32 @@ def test_db_creation(self, db_path, client_id): # Check schema version conn = db.get_connection() - version = \ - conn.execute("SELECT version FROM db_schema_version WHERE schema_name = ?", (db._SCHEMA_NAME,)).fetchone()[ - 'version'] - assert version == db._CURRENT_SCHEMA_VERSION + version_row = conn.execute("SELECT version FROM db_schema_version WHERE schema_name = ?", + (db._SCHEMA_NAME,)).fetchone() + assert version_row is not None + assert version_row['version'] == db._CURRENT_SCHEMA_VERSION db.close_connection() - def test_in_memory_db(self, client_id): + def test_in_memory_db_initialization(self, client_id): db = CharactersRAGDB(":memory:", client_id) assert db.is_memory_db assert db.client_id == client_id - # Check schema version for in-memory conn = db.get_connection() - version = \ - conn.execute("SELECT version FROM db_schema_version WHERE schema_name = ?", 
(db._SCHEMA_NAME,)).fetchone()[ - 'version'] - assert version == db._CURRENT_SCHEMA_VERSION + version_row = conn.execute("SELECT version FROM db_schema_version WHERE schema_name = ?", + (db._SCHEMA_NAME,)).fetchone() + assert version_row is not None + assert version_row['version'] == db._CURRENT_SCHEMA_VERSION db.close_connection() - def test_missing_client_id(self, db_path): + def test_initialization_with_missing_client_id(self, db_path): with pytest.raises(ValueError, match="Client ID cannot be empty or None."): CharactersRAGDB(db_path, "") with pytest.raises(ValueError, match="Client ID cannot be empty or None."): CharactersRAGDB(db_path, None) - def test_reopen_db(self, db_path, client_id): + def test_reopening_db_preserves_schema(self, db_path, client_id): db1 = CharactersRAGDB(db_path, client_id) - v1 = db1._get_db_version(db1.get_connection()) # Assuming _get_db_version is still available for tests + v1 = db1._get_db_version(db1.get_connection()) db1.close_connection() db2 = CharactersRAGDB(db_path, "another_client") @@ -144,17 +171,16 @@ def test_reopen_db(self, db_path, client_id): assert v2 == CharactersRAGDB._CURRENT_SCHEMA_VERSION db2.close_connection() - def test_schema_newer_than_code(self, db_path, client_id): + def test_opening_db_with_newer_schema_raises_error(self, db_path, client_id): db = CharactersRAGDB(db_path, client_id) conn = db.get_connection() - # Manually set a newer version + newer_version = CharactersRAGDB._CURRENT_SCHEMA_VERSION + 1 conn.execute("UPDATE db_schema_version SET version = ? WHERE schema_name = ?", - (CharactersRAGDB._CURRENT_SCHEMA_VERSION + 1, CharactersRAGDB._SCHEMA_NAME)) + (newer_version, CharactersRAGDB._SCHEMA_NAME)) conn.commit() db.close_connection() - # Match the wrapped message from CharactersRAGDBError - expected_message_part = "Database initialization failed: Schema initialization/migration for 'rag_char_chat_schema' failed: Database schema 'rag_char_chat_schema' version .* is newer than supported by code" + expected_message_part = f"version \\({newer_version}\\) is newer than supported by code \\({CharactersRAGDB._CURRENT_SCHEMA_VERSION}\\)" with pytest.raises(CharactersRAGDBError, match=expected_message_part): CharactersRAGDB(db_path, client_id) @@ -163,27 +189,26 @@ class TestCharacterCards: def test_add_character_card(self, db_instance: CharactersRAGDB): card_data = _create_sample_card_data("Add") card_id = db_instance.add_character_card(card_data) - assert card_id is not None assert isinstance(card_id, int) retrieved = db_instance.get_character_card_by_id(card_id) assert retrieved is not None assert retrieved["name"] == card_data["name"] assert retrieved["description"] == card_data["description"] - assert retrieved["image"] == card_data["image"] # BLOB check - assert isinstance(retrieved["alternate_greetings"], list) # Check deserialization + assert retrieved["image"] == card_data["image"] + assert isinstance(retrieved["alternate_greetings"], list) assert retrieved["alternate_greetings"] == json.loads(card_data["alternate_greetings"]) - assert retrieved["client_id"] == db_instance.client_id # Ensure instance client_id is used + assert retrieved["client_id"] == db_instance.client_id assert retrieved["version"] == 1 assert not retrieved["deleted"] - def test_add_character_card_missing_name(self, db_instance: CharactersRAGDB): + def test_add_character_card_with_missing_name_raises_error(self, db_instance: CharactersRAGDB): card_data = _create_sample_card_data("MissingName") del card_data["name"] with pytest.raises(InputError, 
match="Required field 'name' is missing"): db_instance.add_character_card(card_data) - def test_add_character_card_duplicate_name(self, db_instance: CharactersRAGDB): + def test_add_character_card_with_duplicate_name_raises_error(self, db_instance: CharactersRAGDB): card_data = _create_sample_card_data("Duplicate") db_instance.add_character_card(card_data) with pytest.raises(ConflictError, match=f"Character card with name '{card_data['name']}' already exists"): @@ -194,220 +219,135 @@ def test_get_character_card_by_id_not_found(self, db_instance: CharactersRAGDB): def test_get_character_card_by_name(self, db_instance: CharactersRAGDB): card_data = _create_sample_card_data("ByName") - db_instance.add_character_card(card_data) + card_id = db_instance.add_character_card(card_data) retrieved = db_instance.get_character_card_by_name(card_data["name"]) assert retrieved is not None - assert retrieved["description"] == card_data["description"] + assert retrieved["id"] == card_id def test_list_character_cards(self, db_instance: CharactersRAGDB): - assert db_instance.list_character_cards() == [] + # A new DB instance should contain exactly one default card. + initial_cards = db_instance.list_character_cards() + assert len(initial_cards) == 1 + assert initial_cards[0]['name'] == 'Default Assistant' + card_data1 = _create_sample_card_data("List1") card_data2 = _create_sample_card_data("List2") db_instance.add_character_card(card_data1) db_instance.add_character_card(card_data2) + + # The list should now contain 3 cards (1 default + 2 new) cards = db_instance.list_character_cards() - assert len(cards) == 2 - # Sort by name for predictable order if names are unique and sortable - sorted_cards = sorted(cards, key=lambda c: c['name']) - assert sorted_cards[0]["name"] == card_data1["name"] - assert sorted_cards[1]["name"] == card_data2["name"] - - def test_update_character_card(self, db_instance: CharactersRAGDB): - card_data_initial = _create_sample_card_data("Update") - card_id = db_instance.add_character_card(card_data_initial) - assert card_id is not None - - original_card = db_instance.get_character_card_by_id(card_id) - assert original_card is not None - initial_expected_version = original_card['version'] # Should be 1 - - update_payload = { - "description": "Updated Description", # Keep this as it was working - "personality": "More Testy" # NEW: Very simple string for personality - } - - # Determine how many version bumps to expect from this payload - # (This is a simplified count for this test; the DB method handles the actual bumps) - num_updatable_fields_in_payload = 0 - if "description" in update_payload: - num_updatable_fields_in_payload += 1 - if "personality" in update_payload: - num_updatable_fields_in_payload += 1 - # Add other fields from update_payload if they are individually updated - - # If no actual fields are updated, metadata update still bumps version once - final_expected_version_bump = num_updatable_fields_in_payload if num_updatable_fields_in_payload > 0 else 1 - if not update_payload: # If payload is empty, version should still bump once due to metadata update - final_expected_version_bump = 1 - - updated = db_instance.update_character_card(card_id, update_payload, expected_version=initial_expected_version) + assert len(cards) == 3 + + # You can still sort and check your added cards if you filter out the default one. 
+ added_card_names = {c['name'] for c in cards if c['name'] != 'Default Assistant'} + assert added_card_names == {card_data1['name'], card_data2['name']} + + def test_update_character_card(self, db_instance: CharactersRAGDB, sample_card: dict): + update_payload = {"description": "Updated Description"} + updated = db_instance.update_character_card( + sample_card['id'], + update_payload, + expected_version=sample_card['version'] + ) assert updated is True - retrieved = db_instance.get_character_card_by_id(card_id) - assert retrieved is not None + retrieved = db_instance.get_character_card_by_id(sample_card['id']) assert retrieved["description"] == "Updated Description" - assert retrieved["personality"] == "More Testy" - assert retrieved["name"] == card_data_initial["name"] # Unchanged + assert retrieved["version"] == sample_card['version'] + 1 - # Adjust expected version based on sequential updates - assert retrieved["version"] == initial_expected_version + 1 - - def test_update_character_card_version_conflict(self, db_instance: CharactersRAGDB): - card_data = _create_sample_card_data("VersionConflict") - card_id = db_instance.add_character_card(card_data) - assert card_id is not None - - # Original version is 1 - client_expected_version = 1 + def test_update_character_card_with_version_conflict_raises_error(self, db_instance: CharactersRAGDB): + card_id = db_instance.add_character_card(_create_sample_card_data("VersionConflict")) # Simulate another client's update, bumping DB version to 2 - conn = db_instance.get_connection() - conn.execute("UPDATE character_cards SET version = 2, client_id = 'other_client' WHERE id = ?", (card_id,)) - conn.commit() + db_instance.update_character_card(card_id, {"description": "First update"}, expected_version=1) + # Client tries to update with old expected_version=1 update_payload = {"description": "Conflict Update"} - # Updated match string to be more flexible or exact - expected_error_regex = r"Update failed: version mismatch \(db has 2, client expected 1\) for character_cards ID 1\." 
+ expected_error_regex = r"version mismatch \(db has 2, client expected 1\)" with pytest.raises(ConflictError, match=expected_error_regex): - db_instance.update_character_card(card_id, update_payload, expected_version=client_expected_version) + db_instance.update_character_card(card_id, update_payload, expected_version=1) - def test_update_character_card_not_found(self, db_instance: CharactersRAGDB): - with pytest.raises(ConflictError, - match="Record not found in character_cards."): # Match new _get_current_db_version error + def test_update_character_card_not_found_raises_error(self, db_instance: CharactersRAGDB): + with pytest.raises(ConflictError, match="Record not found in character_cards"): db_instance.update_character_card(999, {"description": "Not Found"}, expected_version=1) - def test_soft_delete_character_card(self, db_instance: CharactersRAGDB): - card_data = _create_sample_card_data("Delete") - card_id = db_instance.add_character_card(card_data) - assert card_id is not None - - original_card = db_instance.get_character_card_by_id(card_id) - assert original_card is not None - expected_version_for_first_delete = original_card['version'] # Should be 1 - - deleted = db_instance.soft_delete_character_card(card_id, expected_version=expected_version_for_first_delete) + def test_soft_delete_character_card(self, db_instance: CharactersRAGDB, sample_card: dict): + deleted = db_instance.soft_delete_character_card( + sample_card['id'], + expected_version=sample_card['version'] + ) assert deleted is True + assert db_instance.get_character_card_by_id(sample_card['id']) is None - retrieved_after_first_delete = db_instance.get_character_card_by_id(card_id) # Should be None - assert retrieved_after_first_delete is None + def test_soft_delete_is_idempotent(self, db_instance: CharactersRAGDB): + card_id = db_instance.add_character_card(_create_sample_card_data("IdempotentDelete")) + db_instance.soft_delete_character_card(card_id, expected_version=1) + # Calling delete again on an already deleted record should succeed + assert db_instance.soft_delete_character_card(card_id, expected_version=1) is True + # Verify version didn't change again conn = db_instance.get_connection() - raw_retrieved_after_first_delete = conn.execute("SELECT * FROM character_cards WHERE id = ?", - (card_id,)).fetchone() - assert raw_retrieved_after_first_delete is not None - assert raw_retrieved_after_first_delete["deleted"] == 1 - assert raw_retrieved_after_first_delete["version"] == expected_version_for_first_delete + 1 # Version is now 2 - - # Attempt to delete again with the *original* expected_version (which is now incorrect: 1). - # The soft_delete_character_card method should recognize the card is already deleted - # and treat this as an idempotent success, returning True. - # The internal _get_current_db_version would raise "Record is soft-deleted", - # which soft_delete_character_card catches and handles. 
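Aside: the `expected_version` arguments used throughout these tests implement optimistic concurrency. A minimal sketch of the guard they exercise, assuming a `version` column and an `UPDATE ... WHERE version = ?` check (the actual CharactersRAGDB internals may differ):

```python
from tldw_chatbook.DB.ChaChaNotes_DB import ConflictError

def update_description(conn, card_id: int, new_description: str, expected_version: int, client_id: str):
    # Bump the version only if the row still has the version the caller last saw.
    cursor = conn.execute(
        "UPDATE character_cards "
        "SET description = ?, version = version + 1, client_id = ? "
        "WHERE id = ? AND version = ? AND deleted = 0",
        (new_description, client_id, card_id, expected_version),
    )
    if cursor.rowcount == 0:
        # Missing or soft-deleted row, or another client already bumped the version.
        raise ConflictError(f"version mismatch or missing record for character_cards ID {card_id}")
```

This is why each successful update in the tests is expected to increment `version` by exactly one, and why retrying with a stale `expected_version` raises `ConflictError`.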
- assert db_instance.soft_delete_character_card(card_id, - expected_version=expected_version_for_first_delete) is True - - # Verify version didn't change again (it's still 2 from the first delete) - still_deleted_card_info = conn.execute("SELECT version, deleted FROM character_cards WHERE id = ?", - (card_id,)).fetchone() - assert still_deleted_card_info is not None - assert still_deleted_card_info["deleted"] == 1 - assert still_deleted_card_info['version'] == expected_version_for_first_delete + 1 # Still version 2 - - # Test idempotent success: calling soft_delete on an already deleted record - # with its *current correct version* should also succeed. - current_version_of_deleted_card = raw_retrieved_after_first_delete['version'] # This is 2 - assert db_instance.soft_delete_character_card(card_id, expected_version=current_version_of_deleted_card) is True + raw_record = conn.execute("SELECT version FROM character_cards WHERE id = ?", (card_id,)).fetchone() + assert raw_record["version"] == 2 def test_search_character_cards(self, db_instance: CharactersRAGDB): - card1_data = _create_sample_card_data("Searchable Alpha") + card1_data = _create_sample_card_data("Search Alpha") card1_data["description"] = "Unique keyword: ZYX" - card2_data = _create_sample_card_data("Searchable Beta") - card2_data["system_prompt"] = "Contains ZYX too" + card2_data = _create_sample_card_data("Search Beta") + card2_data["system_prompt"] = "Also has ZYX" card3_data = _create_sample_card_data("Unsearchable") db_instance.add_character_card(card1_data) - db_instance.add_character_card(card2_data) + card2_id = db_instance.add_character_card(card2_data) db_instance.add_character_card(card3_data) results = db_instance.search_character_cards("ZYX") assert len(results) == 2 - names = [r["name"] for r in results] + names = {r["name"] for r in results} assert card1_data["name"] in names assert card2_data["name"] in names - # Test search after delete - card1 = db_instance.get_character_card_by_name(card1_data["name"]) - assert card1 is not None - db_instance.soft_delete_character_card(card1["id"], expected_version=card1["version"]) + # Test search after soft-deleting one of the results + card2 = db_instance.get_character_card_by_id(card2_id) + db_instance.soft_delete_character_card(card2["id"], expected_version=card2["version"]) + results_after_delete = db_instance.search_character_cards("ZYX") assert len(results_after_delete) == 1 - assert results_after_delete[0]["name"] == card2_data["name"] - + assert results_after_delete[0]["name"] == card1_data["name"] + + @pytest.mark.parametrize( + "field_to_remove, expected_error, error_match", + [ + ("name", InputError, "Required field 'name' is missing"), + # Assuming you add a required 'creator' field later + # ("creator", InputError, "Required field 'creator' is missing"), + ] + ) + def test_add_card_missing_required_fields(self, db_instance, field_to_remove, expected_error, error_match): + card_data = _create_sample_card_data("MissingFields") + del card_data[field_to_remove] + with pytest.raises(expected_error, match=error_match): + db_instance.add_character_card(card_data) class TestConversationsAndMessages: @pytest.fixture def char_id(self, db_instance): card_id = db_instance.add_character_card(_create_sample_card_data("ConvChar")) - assert card_id is not None return card_id def test_add_conversation(self, db_instance: CharactersRAGDB, char_id): - conv_data = { - "id": str(uuid.uuid4()), - "character_id": char_id, - "title": "Test Conversation" - } + conv_data = {"id": 
str(uuid.uuid4()), "character_id": char_id, "title": "Test Conversation"} conv_id = db_instance.add_conversation(conv_data) assert conv_id == conv_data["id"] retrieved = db_instance.get_conversation_by_id(conv_id) - assert retrieved is not None assert retrieved["title"] == "Test Conversation" assert retrieved["character_id"] == char_id - assert retrieved["root_id"] == conv_id # Default root_id assert retrieved["version"] == 1 assert retrieved["client_id"] == db_instance.client_id - def test_add_conversation_duplicate_id(self, db_instance: CharactersRAGDB, char_id): - conv_id_val = str(uuid.uuid4()) - conv_data = {"id": conv_id_val, "character_id": char_id, "title": "First"} - db_instance.add_conversation(conv_data) - - conv_data_dup = {"id": conv_id_val, "character_id": char_id, "title": "Duplicate"} - with pytest.raises(ConflictError, match=f"Conversation with ID '{conv_id_val}' already exists"): - db_instance.add_conversation(conv_data_dup) - - def test_add_message(self, db_instance: CharactersRAGDB, char_id): + def test_add_message_and_get_for_conversation(self, db_instance: CharactersRAGDB, char_id): conv_id = db_instance.add_conversation({"character_id": char_id, "title": "MsgConv"}) - assert conv_id is not None - - msg_data = { - "conversation_id": conv_id, - "sender": "user", - "content": "Hello there!" - } - msg_id = db_instance.add_message(msg_data) - assert msg_id is not None - - retrieved_msg = db_instance.get_message_by_id(msg_id) - assert retrieved_msg is not None - assert retrieved_msg["sender"] == "user" - assert retrieved_msg["content"] == "Hello there!" - assert retrieved_msg["conversation_id"] == conv_id - assert retrieved_msg["version"] == 1 - assert retrieved_msg["client_id"] == db_instance.client_id - - # Test adding message to non-existent conversation - msg_data_bad_conv = { - "conversation_id": str(uuid.uuid4()), - "sender": "user", - "content": "Test" - } - with pytest.raises(InputError, match="Cannot add message: Conversation ID .* not found or deleted"): - db_instance.add_message(msg_data_bad_conv) - - def test_get_messages_for_conversation_ordering(self, db_instance: CharactersRAGDB, char_id): - conv_id = db_instance.add_conversation({"character_id": char_id, "title": "OrderedMsgConv"}) - assert conv_id is not None msg1_id = db_instance.add_message( {"conversation_id": conv_id, "sender": "user", "content": "First", "timestamp": "2023-01-01T10:00:00Z"}) msg2_id = db_instance.add_message( @@ -423,517 +363,206 @@ def test_get_messages_for_conversation_ordering(self, db_instance: CharactersRAG assert messages_desc[0]["id"] == msg2_id assert messages_desc[1]["id"] == msg1_id - def test_update_conversation(self, db_instance: CharactersRAGDB, char_id: int): - # 1. Setup: Add an initial conversation with a SIMPLE title - initial_title = "AlphaTitleOne" # Simple, unique for this test run - conv_id = db_instance.add_conversation({ - "character_id": char_id, - "title": initial_title - }) - assert conv_id is not None, "Failed to add initial conversation" - - # 2. Verify initial state in main table - original_conv_main = db_instance.get_conversation_by_id(conv_id) - assert original_conv_main is not None, "Failed to retrieve initial conversation" - assert original_conv_main['title'] == initial_title, "Initial title mismatch in main table" - initial_expected_version = original_conv_main['version'] - assert initial_expected_version == 1, "Initial version should be 1" - - # 3. 
Verify initial FTS state (new title is searchable) - try: - initial_fts_results = db_instance.search_conversations_by_title(initial_title) - assert len(initial_fts_results) == 1, \ - f"FTS Pre-Update: Expected 1 result for '{initial_title}', got {len(initial_fts_results)}. Results: {initial_fts_results}" - assert initial_fts_results[0]['id'] == conv_id, \ - f"FTS Pre-Update: Conversation ID {conv_id} not found in initial search results for '{initial_title}'." - except Exception as e: - pytest.fail(f"Failed during initial FTS check for '{initial_title}': {e}") - - # 4. Define update payload with a SIMPLE, DIFFERENT title - updated_title = "BetaTitleTwo" # Simple, unique, different from initial_title - updated_rating = 5 - update_payload = {"title": updated_title, "rating": updated_rating} - - # 5. Perform the update - updated = db_instance.update_conversation(conv_id, update_payload, expected_version=initial_expected_version) - assert updated is True, "update_conversation returned False" - - # 6. Verify updated state in main table - retrieved_after_update = db_instance.get_conversation_by_id(conv_id) - assert retrieved_after_update is not None, "Failed to retrieve conversation after update" - assert retrieved_after_update["title"] == updated_title, "Title was not updated correctly in main table" - assert retrieved_after_update["rating"] == updated_rating, "Rating was not updated correctly" - assert retrieved_after_update[ - "version"] == initial_expected_version + 1, "Version did not increment correctly after update" - - # 7. Verify FTS state after update - # Search for the NEW title - try: - search_results_new_title = db_instance.search_conversations_by_title(updated_title) - assert len(search_results_new_title) == 1, \ - f"FTS Post-Update: Expected 1 result for new title '{updated_title}', got {len(search_results_new_title)}. Results: {search_results_new_title}" - assert search_results_new_title[0]['id'] == conv_id, \ - f"FTS Post-Update: Conversation ID {conv_id} not found in search results for new title '{updated_title}'." - except Exception as e: - pytest.fail(f"Failed during FTS check for new title '{updated_title}': {e}") - - # Search for the OLD title - try: - search_results_old_title = db_instance.search_conversations_by_title(initial_title) - found_old_title_for_this_conv_via_match = any(r['id'] == conv_id for r in search_results_old_title) - - if found_old_title_for_this_conv_via_match: - print( - f"\nINFO (FTS Nuance): FTS MATCH found old title '{initial_title}' for conv_id {conv_id} immediately after update.") - print(f" Search results for old title via MATCH: {search_results_old_title}") - - # Verify the actual content stored in FTS table for this rowid - # This confirms the trigger updated the FTS table's data record. 
- conn_debug = db_instance.get_connection() # Get a connection - - # Get the rowid of the conversation from the main table first - main_table_rowid_cursor = conn_debug.execute("SELECT rowid FROM conversations WHERE id = ?", (conv_id,)) - main_table_rowid_row = main_table_rowid_cursor.fetchone() - assert main_table_rowid_row is not None, f"Could not fetch rowid from main 'conversations' table for id {conv_id}" - target_conv_rowid = main_table_rowid_row['rowid'] - - fts_content_cursor = conn_debug.execute( - "SELECT title FROM conversations_fts WHERE rowid = ?", - (target_conv_rowid,) # Use the actual rowid from the main table - ) - fts_content_row = fts_content_cursor.fetchone() - - current_fts_content_title = "FTS ROW NOT FOUND (SHOULD EXIST)" - if fts_content_row: - current_fts_content_title = fts_content_row['title'] - else: # This case should ideally not happen if the new title was inserted - print( - f"ERROR: FTS row for rowid {target_conv_rowid} not found directly after update, but MATCH found it.") - - print( - f" Actual content in conversations_fts.title for rowid {target_conv_rowid}: '{current_fts_content_title}'") - - assert current_fts_content_title == updated_title, \ - f"FTS CONTENT CHECK FAILED: Stored FTS content for rowid {target_conv_rowid} of conv_id {conv_id} was '{current_fts_content_title}', expected '{updated_title}'." - - # The following assertion is expected to FAIL due to FTS5 MATCH "stickiness" - # It demonstrates that while the FTS data record is updated, MATCH might still find old terms immediately. - # To make the overall test "pass" while acknowledging this, this line would be commented out or adjusted. - # assert not found_old_title_for_this_conv_via_match, \ - # (f"FTS MATCH BEHAVIOR: Old title '{initial_title}' was STILL MATCHED for conversation ID {conv_id} " - # f"after update, even though FTS content for its rowid ({target_conv_rowid}) is now '{current_fts_content_title}'. " - # f"This highlights FTS5's eventual consistency for MATCH queries post-update.") - else: - # This is the ideal immediate outcome: old title is not found by MATCH. 
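Aside: these FTS assertions presume that `conversations_fts` is kept in sync with the `conversations` table by triggers. A sketch of the usual external-content FTS5 wiring, using the table and column names the tests query (the project's real schema and triggers may differ, and soft deletes would need an extra trigger or a query-time filter):

```python
import sqlite3

# Assumed names: conversations(title, ...) with an implicit rowid, conversations_fts(title).
FTS_SETUP_SQL = """
CREATE VIRTUAL TABLE IF NOT EXISTS conversations_fts
    USING fts5(title, content='conversations', content_rowid='rowid');

CREATE TRIGGER IF NOT EXISTS conversations_fts_after_insert AFTER INSERT ON conversations BEGIN
    INSERT INTO conversations_fts(rowid, title) VALUES (new.rowid, new.title);
END;

CREATE TRIGGER IF NOT EXISTS conversations_fts_after_update AFTER UPDATE OF title ON conversations BEGIN
    -- External-content tables are updated by emitting a 'delete' row, then re-inserting.
    INSERT INTO conversations_fts(conversations_fts, rowid, title) VALUES ('delete', old.rowid, old.title);
    INSERT INTO conversations_fts(rowid, title) VALUES (new.rowid, new.title);
END;
"""

def install_fts(conn: sqlite3.Connection) -> None:
    conn.executescript(FTS_SETUP_SQL)
```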
- assert not found_old_title_for_this_conv_via_match # This will pass if branch is taken - - except Exception as e: - pytest.fail(f"Failed during FTS check for old title '{initial_title}': {e}") - - def test_soft_delete_conversation_and_messages(self, db_instance: CharactersRAGDB, char_id): - # Setup: Conversation with messages - conv_title_for_delete_test = "DeleteConvForFTS" - conv_id = db_instance.add_conversation({"character_id": char_id, "title": conv_title_for_delete_test}) - assert conv_id is not None - msg1_id = db_instance.add_message({"conversation_id": conv_id, "sender": "user", "content": "Msg1"}) - assert msg1_id is not None - + def test_update_conversation_and_fts(self, db_instance: CharactersRAGDB, char_id: int): + initial_title = "AlphaTitleOne" + conv_id = db_instance.add_conversation({"character_id": char_id, "title": initial_title}) original_conv = db_instance.get_conversation_by_id(conv_id) - assert original_conv is not None - expected_version = original_conv['version'] - # Verify it's in FTS before delete - results_before_delete = db_instance.search_conversations_by_title(conv_title_for_delete_test) - assert len(results_before_delete) == 1, "Conversation should be in FTS before soft delete" - assert results_before_delete[0]['id'] == conv_id + # Verify FTS state before update + assert len(db_instance.search_conversations_by_title(initial_title)) == 1 - # Soft delete conversation - deleted = db_instance.soft_delete_conversation(conv_id, expected_version=expected_version) - assert deleted is True - assert db_instance.get_conversation_by_id(conv_id) is None + # Perform update + updated_title = "BetaTitleTwo" + db_instance.update_conversation(conv_id, {"title": updated_title}, expected_version=original_conv['version']) - msg1 = db_instance.get_message_by_id(msg1_id) - assert msg1 is not None - assert msg1["conversation_id"] == conv_id + # Verify FTS state after update + assert len(db_instance.search_conversations_by_title(updated_title)) == 1 + assert len(db_instance.search_conversations_by_title(initial_title)) == 0, "FTS should not find the old title" - # FTS search for conversation should not find it - # UNCOMMENTED AND VERIFIED: - results_after_delete = db_instance.search_conversations_by_title(conv_title_for_delete_test) - assert len(results_after_delete) == 0, "FTS search should not find the soft-deleted conversation" + def test_soft_delete_conversation_and_fts(self, db_instance: CharactersRAGDB, char_id): + conv_title_for_delete_test = "DeleteConvForFTS" + conv_id = db_instance.add_conversation({"character_id": char_id, "title": conv_title_for_delete_test}) + original_conv = db_instance.get_conversation_by_id(conv_id) - def test_conversation_fts_search(self, db_instance: CharactersRAGDB, char_id): - conv_id1 = db_instance.add_conversation({"character_id": char_id, "title": "Unique Alpha Search Term"}) - conv_id2 = db_instance.add_conversation({"character_id": char_id, "title": "Another Alpha For Test"}) - db_instance.add_conversation({"character_id": char_id, "title": "Beta Content Only"}) + assert len(db_instance.search_conversations_by_title(conv_title_for_delete_test)) == 1 - results_alpha = db_instance.search_conversations_by_title("Alpha") - assert len(results_alpha) == 2 - found_ids_alpha = {r['id'] for r in results_alpha} - assert conv_id1 in found_ids_alpha - assert conv_id2 in found_ids_alpha + db_instance.soft_delete_conversation(conv_id, expected_version=original_conv['version']) - results_unique = db_instance.search_conversations_by_title("Unique") - 
assert len(results_unique) == 1 - assert results_unique[0]['id'] == conv_id1 + assert db_instance.get_conversation_by_id(conv_id) is None + assert len(db_instance.search_conversations_by_title( + conv_title_for_delete_test)) == 0, "FTS should not find soft-deleted conversation" - def test_search_messages_by_content_FIXED_JOIN(self, db_instance: CharactersRAGDB, char_id): - # This test specifically validates the FTS join fix for messages (TEXT PK) + def test_search_messages_by_content(self, db_instance: CharactersRAGDB, char_id): conv_id = db_instance.add_conversation({"character_id": char_id, "title": "MessageSearchConv"}) - assert conv_id is not None msg1_data = {"id": str(uuid.uuid4()), "conversation_id": conv_id, "sender": "user", "content": "UniqueMessageContentAlpha"} - msg2_data = {"id": str(uuid.uuid4()), "conversation_id": conv_id, "sender": "ai", "content": "Another phrase"} - db_instance.add_message(msg1_data) - db_instance.add_message(msg2_data) results = db_instance.search_messages_by_content("UniqueMessageContentAlpha") assert len(results) == 1 assert results[0]["id"] == msg1_data["id"] - assert results[0]["content"] == msg1_data["content"] - - # Test search within a specific conversation - results_conv_specific = db_instance.search_messages_by_content("UniqueMessageContentAlpha", - conversation_id=conv_id) - assert len(results_conv_specific) == 1 - assert results_conv_specific[0]["id"] == msg1_data["id"] - - # Test search for content in another conversation (should not be found if conv_id is specified) - other_conv_id = db_instance.add_conversation({"character_id": char_id, "title": "Other MessageSearchConv"}) - assert other_conv_id is not None - db_instance.add_message({"id": str(uuid.uuid4()), "conversation_id": other_conv_id, "sender": "user", - "content": "UniqueMessageContentAlpha In Other"}) - results_other_conv = db_instance.search_messages_by_content("UniqueMessageContentAlpha", - conversation_id=other_conv_id) - assert len(results_other_conv) == 1 - assert results_other_conv[0]["content"] == "UniqueMessageContentAlpha In Other" - results_original_conv_again = db_instance.search_messages_by_content("UniqueMessageContentAlpha", - conversation_id=conv_id) - assert len(results_original_conv_again) == 1 - assert results_original_conv_again[0]["id"] == msg1_data["id"] - - -class TestNotes: - def test_add_note(self, db_instance: CharactersRAGDB): - note_id = db_instance.add_note("Test Note Title", "This is the content of the note.") - assert isinstance(note_id, str) # UUID - - retrieved = db_instance.get_note_by_id(note_id) - assert retrieved is not None - assert retrieved["title"] == "Test Note Title" - assert retrieved["content"] == "This is the content of the note." 
- assert retrieved["version"] == 1 - assert not retrieved["deleted"] - - def test_add_note_empty_title(self, db_instance: CharactersRAGDB): - with pytest.raises(InputError, match="Note title cannot be empty."): - db_instance.add_note("", "Content") - - def test_add_note_duplicate_id(self, db_instance: CharactersRAGDB): - fixed_id = str(uuid.uuid4()) - db_instance.add_note("First Note", "Content1", note_id=fixed_id) - with pytest.raises(ConflictError, match=f"Note with ID '{fixed_id}' already exists."): - db_instance.add_note("Second Note", "Content2", note_id=fixed_id) - - def test_update_note(self, db_instance: CharactersRAGDB): + # @pytest.mark.parametrize( + # "msg_data, raises_error", + # [ + # ({"content": "Hello", "image_data": None, "image_mime_type": None}, False), + # ({"content": "", "image_data": b'img', "image_mime_type": "image/png"}, False), + # ({"content": "Hello", "image_data": b'img', "image_mime_type": "image/png"}, False), + # # Failure cases + # ({"content": "", "image_data": None, "image_mime_type": None}, True), # Both missing + # ({"content": None, "image_data": None, "image_mime_type": None}, True), # Both missing + # ({"content": "", "image_data": b'img', "image_mime_type": None}, True), # Mime type missing + # ] + # ) + # def test_add_message_content_requirements(self, db_instance, sample_conv, msg_data, raises_error): + # full_payload = { + # "conversation_id": sample_conv['id'], + # "sender": "user", + # **msg_data + # } + # + # if raises_error: + # with pytest.raises((InputError, TypeError)): # TypeError if content is None + # db_instance.add_message(full_payload) + # else: + # msg_id = db_instance.add_message(full_payload) + # assert msg_id is not None + + + +class TestNotesAndKeywords: + def test_add_and_update_note(self, db_instance: CharactersRAGDB): note_id = db_instance.add_note("Original Title", "Original Content") - assert note_id is not None + assert isinstance(note_id, str) original_note = db_instance.get_note_by_id(note_id) - assert original_note is not None - expected_version = original_note['version'] # Should be 1 - - updated = db_instance.update_note(note_id, {"title": "Updated Title", "content": "Updated Content"}, - expected_version=expected_version) + updated = db_instance.update_note(note_id, {"title": "Updated Title"}, + expected_version=original_note['version']) assert updated is True retrieved = db_instance.get_note_by_id(note_id) - assert retrieved is not None assert retrieved["title"] == "Updated Title" - assert retrieved["content"] == "Updated Content" - assert retrieved["version"] == expected_version + 1 - - def test_list_notes(self, db_instance: CharactersRAGDB): - assert db_instance.list_notes() == [] - id1 = db_instance.add_note("Note A", "Content A") - # Introduce a slight delay or ensure timestamps are distinct if relying on last_modified for order - # For this test, assuming add_note sets distinct last_modified or order is by insertion for simple tests - id2 = db_instance.add_note("Note B", "Content B") - notes = db_instance.list_notes() - assert len(notes) == 2 - # Default order is last_modified DESC - # To make it robust, fetch and compare timestamps or ensure test data forces order - # For simplicity, if Note B is added after Note A, it should appear first - note_ids_in_order = [n['id'] for n in notes] - if id1 and id2: # Ensure they were created - # This assertion depends on the exact timing of creation and how last_modified is set. 
- # A more robust test would explicitly set created_at/last_modified if possible, - # or query and sort by a reliable field. - # For now, we assume recent additions are first due to DESC order. - assert note_ids_in_order[0] == id2 - assert note_ids_in_order[1] == id1 - - def test_search_notes(self, db_instance: CharactersRAGDB): - db_instance.add_note("Alpha Note", "Contains a keyword ZYX") - db_instance.add_note("Beta Note", "Another one with ZYX in title") - db_instance.add_note("Gamma Note", "Nothing special") - - # DEBUGGING: - # conn = db_instance.get_connection() - # fts_content = conn.execute("SELECT rowid, title, content FROM notes_fts;").fetchall() - # print("\nNotes FTS Content:") - # for row in fts_content: - # print(dict(row)) - # END DEBUGGING - - results = db_instance.search_notes("ZYX") - assert len(results) == 2 - titles = sorted([r['title'] for r in results]) # Sort for predictable assertion - assert titles == ["Alpha Note", "Beta Note"] - + assert retrieved["version"] == original_note['version'] + 1 -class TestKeywordsAndCollections: - def test_add_keyword(self, db_instance: CharactersRAGDB): - keyword_id = db_instance.add_keyword(" TestKeyword ") # Test stripping - assert keyword_id is not None - retrieved = db_instance.get_keyword_by_id(keyword_id) - assert retrieved is not None - assert retrieved["keyword"] == "TestKeyword" - assert retrieved["version"] == 1 + def test_add_keyword_and_undelete(self, db_instance: CharactersRAGDB): + keyword_id = db_instance.add_keyword("TestKeyword") + kw_v1 = db_instance.get_keyword_by_id(keyword_id) - def test_add_keyword_duplicate_active(self, db_instance: CharactersRAGDB): - db_instance.add_keyword("UniqueKeyword") - with pytest.raises(ConflictError, match="'UniqueKeyword' already exists and is active"): - db_instance.add_keyword("UniqueKeyword") + db_instance.soft_delete_keyword(keyword_id, expected_version=kw_v1['version']) + assert db_instance.get_keyword_by_id(keyword_id) is None - def test_add_keyword_undelete(self, db_instance: CharactersRAGDB): - keyword_id = db_instance.add_keyword("ToDeleteAndReadd") - assert keyword_id is not None - - # Get current version for soft delete - keyword_v1 = db_instance.get_keyword_by_id(keyword_id) - assert keyword_v1 is not None - - db_instance.soft_delete_keyword(keyword_id, expected_version=keyword_v1['version']) # v2, deleted - - # Adding same keyword should undelete and update - # The add_keyword method's undelete logic might not need an explicit expected_version - # from the client for this specific "add which might undelete" scenario. - # It internally handles its own version check if it finds a deleted record. 
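Aside: the `version == 3` expectation below only makes sense if `add_keyword` treats re-adding a soft-deleted keyword as an undelete that bumps the version again. A rough sketch of that branch, assuming a `keywords(keyword, client_id, version, deleted)` table and `sqlite3.Row` access (illustrative, not the actual implementation):

```python
from tldw_chatbook.DB.ChaChaNotes_DB import ConflictError

def add_keyword(conn, client_id: str, keyword_text: str) -> int:
    keyword_text = keyword_text.strip()
    row = conn.execute(
        "SELECT id, version, deleted FROM keywords WHERE keyword = ?", (keyword_text,)
    ).fetchone()
    if row is None:
        cur = conn.execute(
            "INSERT INTO keywords (keyword, client_id, version) VALUES (?, ?, 1)",
            (keyword_text, client_id),
        )
        return cur.lastrowid          # version 1 on first insert
    if row["deleted"]:
        conn.execute(
            "UPDATE keywords SET deleted = 0, version = version + 1, client_id = ? WHERE id = ?",
            (client_id, row["id"]),
        )
        return row["id"]              # undelete bumps the version once more
    raise ConflictError(f"'{keyword_text}' already exists and is active")
```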
- new_keyword_id = db_instance.add_keyword("ToDeleteAndReadd") - assert new_keyword_id == keyword_id # Should be the same ID + # Adding same keyword again should undelete it + new_keyword_id = db_instance.add_keyword("TestKeyword") + assert new_keyword_id == keyword_id retrieved = db_instance.get_keyword_by_id(keyword_id) - assert retrieved is not None assert not retrieved["deleted"] - # Version logic: - # 1 (initial add) - # 2 (soft_delete_keyword with expected_version=1) - # 3 (add_keyword causing undelete, which itself bumps version) - assert retrieved["version"] == 3 - - def test_add_keyword_collection(self, db_instance: CharactersRAGDB): - coll_id = db_instance.add_keyword_collection("My Collection") - assert coll_id is not None - retrieved = db_instance.get_keyword_collection_by_id(coll_id) - assert retrieved is not None - assert retrieved["name"] == "My Collection" - assert retrieved["parent_id"] is None + assert retrieved["version"] == 3 # 1(add) -> 2(delete) -> 3(undelete/update) - child_coll_id = db_instance.add_keyword_collection("Child Collection", parent_id=coll_id) - assert child_coll_id is not None - retrieved_child = db_instance.get_keyword_collection_by_id(child_coll_id) - assert retrieved_child is not None - assert retrieved_child["parent_id"] == coll_id - - def test_link_conversation_to_keyword(self, db_instance: CharactersRAGDB): + def test_link_and_unlink_conversation_to_keyword(self, db_instance: CharactersRAGDB): char_id = db_instance.add_character_card(_create_sample_card_data("LinkChar")) - assert char_id is not None conv_id = db_instance.add_conversation({"character_id": char_id, "title": "LinkConv"}) - assert conv_id is not None kw_id = db_instance.add_keyword("Linkable") - assert kw_id is not None assert db_instance.link_conversation_to_keyword(conv_id, kw_id) is True keywords = db_instance.get_keywords_for_conversation(conv_id) assert len(keywords) == 1 assert keywords[0]["id"] == kw_id - # Test idempotency - assert db_instance.link_conversation_to_keyword(conv_id, kw_id) is False # Already linked + # Test idempotency of linking + assert db_instance.link_conversation_to_keyword(conv_id, kw_id) is False - # Test unlinking assert db_instance.unlink_conversation_from_keyword(conv_id, kw_id) is True assert len(db_instance.get_keywords_for_conversation(conv_id)) == 0 - assert db_instance.unlink_conversation_from_keyword(conv_id, kw_id) is False # Already unlinked - # Similar tests for other link types: - # link_collection_to_keyword, link_note_to_keyword + # Test idempotency of unlinking + assert db_instance.unlink_conversation_from_keyword(conv_id, kw_id) is False class TestSyncLog: - def test_sync_log_entry_on_add_character(self, db_instance: CharactersRAGDB): + def test_sync_log_entry_on_add_and_update_character(self, db_instance: CharactersRAGDB): initial_log_max_id = db_instance.get_latest_sync_log_change_id() card_data = _create_sample_card_data("SyncLogChar") card_id = db_instance.add_character_card(card_data) - assert card_id is not None - - log_entries = db_instance.get_sync_log_entries(since_change_id=initial_log_max_id) # Get new entries - - char_log_entry = None - for entry in log_entries: # Search among new entries - if entry["entity"] == "character_cards" and entry["entity_id"] == str(card_id) and entry[ - "operation"] == "create": - char_log_entry = entry - break - - assert char_log_entry is not None - assert char_log_entry["payload"]["name"] == card_data["name"] - assert char_log_entry["payload"]["version"] == 1 - assert char_log_entry["client_id"] 
== db_instance.client_id - - def test_sync_log_entry_on_update_character(self, db_instance: CharactersRAGDB): - card_id = db_instance.add_character_card(_create_sample_card_data("SyncUpdateChar")) - assert card_id is not None - - original_card = db_instance.get_character_card_by_id(card_id) - assert original_card is not None - expected_version = original_card['version'] - - latest_change_id = db_instance.get_latest_sync_log_change_id() - - db_instance.update_character_card(card_id, {"description": "Updated for Sync"}, - expected_version=expected_version) - - new_entries = db_instance.get_sync_log_entries(since_change_id=latest_change_id) - assert len(new_entries) >= 1 - update_log_entry = None - for entry in new_entries: - if entry["entity"] == "character_cards" and entry["entity_id"] == str(card_id) and entry[ - "operation"] == "update": - update_log_entry = entry - break - - assert update_log_entry is not None - assert update_log_entry["payload"]["description"] == "Updated for Sync" - assert update_log_entry["payload"]["version"] == expected_version + 1 - - def test_sync_log_entry_on_soft_delete_character(self, db_instance: CharactersRAGDB): + log_entries = db_instance.get_sync_log_entries(since_change_id=initial_log_max_id) + create_entry = next((e for e in log_entries if e["entity"] == "character_cards" and e["operation"] == "create"), + None) + assert create_entry is not None + assert create_entry["entity_id"] == str(card_id) + assert create_entry["payload"]["name"] == card_data["name"] + + # Test update + latest_change_id_after_add = db_instance.get_latest_sync_log_change_id() + db_instance.update_character_card(card_id, {"description": "Updated for Sync"}, expected_version=1) + + update_log_entries = db_instance.get_sync_log_entries(since_change_id=latest_change_id_after_add) + update_entry = next( + (e for e in update_log_entries if e["entity"] == "character_cards" and e["operation"] == "update"), None) + assert update_entry is not None + assert update_entry["payload"]["description"] == "Updated for Sync" + assert update_entry["payload"]["version"] == 2 + + def test_sync_log_on_soft_delete_character(self, db_instance: CharactersRAGDB): card_id = db_instance.add_character_card(_create_sample_card_data("SyncDeleteChar")) - assert card_id is not None - - original_card = db_instance.get_character_card_by_id(card_id) - assert original_card is not None - expected_version = original_card['version'] - latest_change_id = db_instance.get_latest_sync_log_change_id() - db_instance.soft_delete_character_card(card_id, expected_version=expected_version) + db_instance.soft_delete_character_card(card_id, expected_version=1) new_entries = db_instance.get_sync_log_entries(since_change_id=latest_change_id) - delete_log_entry = None - for entry in new_entries: - if entry["entity"] == "character_cards" and entry["entity_id"] == str(card_id) and entry[ - "operation"] == "delete": - delete_log_entry = entry - break - - assert delete_log_entry is not None - # assert delete_log_entry["payload"]["deleted"] is True # Original failing line - assert delete_log_entry["payload"]["deleted"] == 1 # If JSON payload has integer 1 for true - # OR, if you expect a boolean true after json.loads and your DB stores it in a way that json.loads makes it bool: - # assert delete_log_entry["payload"]["deleted"] is True - # For SQLite storing boolean as 0/1, json.loads(payload_with_integer_1) will keep it as integer 1. 
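Aside: a two-line illustration of why the assertion below compares against the integer 1 rather than using `is True`. SQLite stores booleans as 0/1, and `json.loads` of the sync-log payload keeps that value as an int, so an identity check against the bool singleton fails even though `== 1` passes:

```python
import json

payload = json.loads('{"deleted": 1, "version": 2}')
assert payload["deleted"] == 1          # what the test asserts
assert payload["deleted"] is not True   # json yields an int, not the bool singleton
```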
- assert delete_log_entry["payload"]["version"] == expected_version + 1 + delete_entry = next((e for e in new_entries if e["entity"] == "character_cards" and e["operation"] == "delete"), + None) + assert delete_entry is not None + assert delete_entry["entity_id"] == str(card_id) + assert delete_entry["payload"]["deleted"] == 1 # Stored as integer + assert delete_entry["payload"]["version"] == 2 def test_sync_log_for_link_tables(self, db_instance: CharactersRAGDB): char_id = db_instance.add_character_card(_create_sample_card_data("SyncLinkChar")) - assert char_id is not None conv_id = db_instance.add_conversation({"character_id": char_id, "title": "SyncLinkConv"}) - assert conv_id is not None kw_id = db_instance.add_keyword("SyncLinkable") - assert kw_id is not None - latest_change_id = db_instance.get_latest_sync_log_change_id() + db_instance.link_conversation_to_keyword(conv_id, kw_id) - new_entries = db_instance.get_sync_log_entries(since_change_id=latest_change_id) - link_log_entry = None - expected_entity_id = f"{conv_id}_{kw_id}" - for entry in new_entries: - if entry["entity"] == "conversation_keywords" and entry["entity_id"] == expected_entity_id and entry[ - "operation"] == "create": - link_log_entry = entry - break - - assert link_log_entry is not None - assert link_log_entry["payload"]["conversation_id"] == conv_id - assert link_log_entry["payload"]["keyword_id"] == kw_id + link_entries = db_instance.get_sync_log_entries(since_change_id=latest_change_id) + link_entry = next( + (e for e in link_entries if e["entity"] == "conversation_keywords" and e["operation"] == "create"), None) + assert link_entry is not None + assert link_entry["payload"]["conversation_id"] == conv_id + assert link_entry["payload"]["keyword_id"] == kw_id + # Test unlink latest_change_id_after_link = db_instance.get_latest_sync_log_change_id() db_instance.unlink_conversation_from_keyword(conv_id, kw_id) - new_entries_unlink = db_instance.get_sync_log_entries(since_change_id=latest_change_id_after_link) - unlink_log_entry = None - for entry in new_entries_unlink: - if entry["entity"] == "conversation_keywords" and entry["entity_id"] == expected_entity_id and entry[ - "operation"] == "delete": - unlink_log_entry = entry - break - assert unlink_log_entry is not None + unlink_entries = db_instance.get_sync_log_entries(since_change_id=latest_change_id_after_link) + unlink_entry = next( + (e for e in unlink_entries if e["entity"] == "conversation_keywords" and e["operation"] == "delete"), None) + assert unlink_entry is not None + assert unlink_entry["entity_id"] == f"{conv_id}_{kw_id}" class TestTransactions: def test_transaction_commit(self, db_instance: CharactersRAGDB): - card_data1_name = "Trans1 Character" - card_data2_name = "Trans2 Character" - - with db_instance.transaction() as conn: # Get conn from context for direct execution - # Direct execution to test transaction atomicity without involving add_character_card's own transaction - conn.execute( - "INSERT INTO character_cards (name, description, client_id, last_modified, version) VALUES (?, ?, ?, ?, ?)", - (card_data1_name, "Desc1", db_instance.client_id, get_current_utc_timestamp_iso(), 1) - ) - id1_row = conn.execute("SELECT id FROM character_cards WHERE name = ?", (card_data1_name,)).fetchone() - assert id1_row is not None - id1 = id1_row['id'] - - conn.execute( - "INSERT INTO character_cards (name, description, client_id, last_modified, version) VALUES (?, ?, ?, ?, ?)", - (card_data2_name, "Desc2", db_instance.client_id, 
get_current_utc_timestamp_iso(), 1) ) - retrieved1 = db_instance.get_character_card_by_id(id1) - retrieved2 = db_instance.get_character_card_by_name(card_data2_name) - assert retrieved1 is not None - assert retrieved2 is not None + with db_instance.transaction() as conn: + conn.execute("INSERT INTO character_cards (name, client_id) VALUES (?, ?)", + ("Trans1", db_instance.client_id)) + conn.execute("INSERT INTO character_cards (name, client_id) VALUES (?, ?)", + ("Trans2", db_instance.client_id)) + + assert db_instance.get_character_card_by_name("Trans1") is not None + assert db_instance.get_character_card_by_name("Trans2") is not None def test_transaction_rollback(self, db_instance: CharactersRAGDB): - card_data_name = "TransRollback Character" initial_count = len(db_instance.list_character_cards()) - with pytest.raises(sqlite3.IntegrityError): - with db_instance.transaction() as conn: # Get conn from context - # First insert (will be part of transaction) - conn.execute( - "INSERT INTO character_cards (name, description, client_id, last_modified, version) VALUES (?, ?, ?, ?, ?)", - (card_data_name, "DescRollback", db_instance.client_id, get_current_utc_timestamp_iso(), 1) - ) - # Second insert that causes an error (duplicate unique name) - conn.execute( - "INSERT INTO character_cards (name, description, client_id, last_modified, version) VALUES (?, ?, ?, ?, ?)", - (card_data_name, "DescRollbackFail", db_instance.client_id, get_current_utc_timestamp_iso(), 1) - ) - - # Check that the first insert was rolled back + with pytest.raises(sqlite3.IntegrityError): + with db_instance.transaction() as conn: + conn.execute("INSERT INTO character_cards (name, client_id) VALUES (?, ?)", + ("TransRollback", db_instance.client_id)) + # This will fail due to duplicate name, causing a rollback + conn.execute("INSERT INTO character_cards (name, client_id) VALUES (?, ?)", + ("TransRollback", db_instance.client_id)) + assert len(db_instance.list_character_cards()) == initial_count - assert db_instance.get_character_card_by_name(card_data_name) is None - -# More tests can be added for: -# - Specific FTS trigger behavior (though search tests cover them indirectly) -# - Behavior of ON DELETE CASCADE / ON UPDATE CASCADE where applicable (e.g., true deletion of character should cascade to conversations IF hard delete was used and schema supported it) -# - More complex conflict scenarios with multiple clients (harder to simulate perfectly in unit tests without multiple DB instances writing to the same file). -# - All permutations of linking and unlinking for all link tables. -# - All specific error conditions for each method (e.g. InputError for various fields). \ No newline at end of file + assert db_instance.get_character_card_by_name("TransRollback") is None \ No newline at end of file diff --git a/Tests/ChaChaNotesDB/test_chachanotes_db_properties.py b/Tests/ChaChaNotesDB/test_chachanotes_db_properties.py new file mode 100644 index 00000000..23b9ebdd --- /dev/null +++ b/Tests/ChaChaNotesDB/test_chachanotes_db_properties.py @@ -0,0 +1,1128 @@ +# test_chachanotes_db_properties.py +# +# Property-based tests for the ChaChaNotes_DB library using Hypothesis.
+#
+# Imports
+import uuid
+import pytest
+import json
+from pathlib import Path
+import sqlite3
+import threading
+import time
+#
+# Third-Party Imports
+from hypothesis import given, strategies as st, settings, HealthCheck
+from hypothesis.stateful import RuleBasedStateMachine, rule, precondition, Bundle
+#
+# Local Imports
+from tldw_chatbook.DB.ChaChaNotes_DB import (
+    CharactersRAGDB,
+    InputError,
+    CharactersRAGDBError,
+    ConflictError
+)
+#
+########################################################################################################################
+#
+# Functions:
+# --- Hypothesis Tests ---
+
+# To prevent tests from being too slow on complex data, we set a deadline and
+# disable the 'too_slow' health check, as DB operations can sometimes be slow.
+# 'function_scoped_fixture' is suppressed so the function-scoped db_instance
+# fixture can be shared across the examples Hypothesis generates for one test.
+settings.register_profile(
+    "db_friendly",
+    deadline=1000,
+    suppress_health_check=[
+        HealthCheck.too_slow,
+        HealthCheck.function_scoped_fixture
+    ]
+)
+settings.load_profile("db_friendly")
+
+# Define a strategy for a non-zero integer to add to the version
+st_version_offset = st.integers().filter(lambda x: x != 0)
+
+# --- Fixtures (copied from the main unit-test file so this module is self-contained) ---
+
+@pytest.fixture
+def client_id():
+    """Provides a consistent client ID for tests."""
+    return "hypothesis_client"
+
+
+@pytest.fixture
+def db_path(tmp_path):
+    """Provides a temporary path for the database file for each test."""
+    return tmp_path / "prop_test_db.sqlite"
+
+
+@pytest.fixture(scope="function")
+def db_instance(db_path, client_id):
+    """Creates a DB instance for each test, ensuring a fresh database."""
+    current_db_path = Path(db_path)
+    # Ensure no leftover files from a failed previous run
+    for suffix in ["", "-wal", "-shm"]:
+        p = Path(str(current_db_path) + suffix)
+        if p.exists():
+            p.unlink(missing_ok=True)
+
+    db = CharactersRAGDB(current_db_path, client_id)
+    yield db
+    db.close_connection()
+
+
+# --- Hypothesis Strategies ---
+# These strategies define how to generate random, valid data for our database objects.
+
+# A strategy for text fields that cannot be empty or just whitespace.
+st_required_text = st.text(min_size=1, max_size=100).filter(lambda s: s.strip())
+
+# A strategy for optional text or binary fields.
+st_optional_text = st.one_of(st.none(), st.text(max_size=500))
+st_optional_binary = st.one_of(st.none(), st.binary(max_size=1024))
+
+# A strategy for fields that are stored as JSON strings in the DB.
+# We generate a Python list/dict and then map it to a JSON string.
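+# For example, json.dumps(["a", "b"]) yields the text '["a", "b"]', which is what
+# the column stores and what the library deserializes again on retrieval.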
+st_json_list = st.lists(st.text(max_size=50)).map(json.dumps) +st_json_dict = st.dictionaries(st.text(max_size=20), st.text(max_size=100)).map(json.dumps) + + +@st.composite +def st_character_card_data(draw): + """A composite strategy to generate a dictionary of character card data.""" + # `draw` is a function that pulls a value from another strategy. + name = draw(st_required_text) + + # To avoid conflicts with the guaranteed 'Default Assistant', we filter it out. + if name == 'Default Assistant': + name += "_hypothesis" # Just ensure it's not the exact name + + return { + "name": name, + "description": draw(st_optional_text), + "personality": draw(st_optional_text), + "scenario": draw(st_optional_text), + "system_prompt": draw(st_optional_text), + "image": draw(st_optional_binary), + "post_history_instructions": draw(st_optional_text), + "first_message": draw(st_optional_text), + "message_example": draw(st_optional_text), + "creator_notes": draw(st_optional_text), + "alternate_greetings": draw(st.one_of(st.none(), st_json_list)), + "tags": draw(st.one_of(st.none(), st_json_list)), + "creator": draw(st_optional_text), + "character_version": draw(st_optional_text), + "extensions": draw(st.one_of(st.none(), st_json_dict)), + } + + +@st.composite +def st_note_data(draw): + """Generates a dictionary for a note (title and content).""" + return { + "title": draw(st_required_text), + "content": draw(st.text(max_size=2000)) # Content can be empty + } + + +# --- Property Test Classes --- + +class TestCharacterCardProperties: + """Property-based tests for Character Cards.""" + + @settings(suppress_health_check=[HealthCheck.function_scoped_fixture]) + @given(card_data=st_character_card_data()) + def test_character_card_roundtrip(self, db_instance: CharactersRAGDB, card_data: dict): + """ + Property: If we add a character card, retrieving it should return the exact same data. + This is a "round-trip" test. + """ + try: + card_id = db_instance.add_character_card(card_data) + except ConflictError: + # Hypothesis might generate the same name twice. This is not a failure of the + # DB logic, so we just skip this test case. + return + + retrieved_card = db_instance.get_character_card_by_id(card_id) + + assert retrieved_card is not None + assert retrieved_card["version"] == 1 + assert not retrieved_card["deleted"] + + # Compare original data with retrieved data + for key, value in card_data.items(): + if key in db_instance._CHARACTER_CARD_JSON_FIELDS and value is not None: + # JSON fields are deserialized, so we compare to the parsed version + assert retrieved_card[key] == json.loads(value) + else: + assert retrieved_card[key] == value + + @settings(suppress_health_check=[HealthCheck.function_scoped_fixture]) + @given(initial_card=st_character_card_data(), update_payload=st_character_card_data()) + def test_update_increments_version_and_changes_data(self, db_instance: CharactersRAGDB, initial_card: dict, + update_payload: dict): + """ + Property: A successful update must increment the version number by exactly 1 + and correctly apply the new data. + """ + try: + card_id = db_instance.add_character_card(initial_card) + except ConflictError: + return # Skip if initial card name conflicts + + original_card = db_instance.get_character_card_by_id(card_id) + + try: + success = db_instance.update_character_card(card_id, update_payload, + expected_version=original_card['version']) + except ConflictError as e: + # An update can legitimately fail if the new name is already taken. 
+ # We accept this as a valid outcome. + assert "already exists" in str(e) + return + + assert success is True + + updated_card = db_instance.get_character_card_by_id(card_id) + assert updated_card is not None + assert updated_card['version'] == original_card['version'] + 1 + + # Verify the payload was applied + assert updated_card['name'] == update_payload['name'] + assert updated_card['description'] == update_payload['description'] + + @settings(suppress_health_check=[HealthCheck.function_scoped_fixture]) + @given(card_data=st_character_card_data()) + def test_soft_delete_makes_item_unfindable(self, db_instance: CharactersRAGDB, card_data: dict): + """ + Property: After soft-deleting an item, it should not be retrievable by + the standard `get` or `list` methods, but should exist in the DB with deleted=1. + """ + try: + card_id = db_instance.add_character_card(card_data) + except ConflictError: + return + + original_card = db_instance.get_character_card_by_id(card_id) + + # Perform the soft delete + success = db_instance.soft_delete_character_card(card_id, expected_version=original_card['version']) + assert success is True + + # Assert it's no longer findable via public methods + assert db_instance.get_character_card_by_id(card_id) is None + + all_cards = db_instance.list_character_cards() + assert card_id not in [c['id'] for c in all_cards] + + # Assert its raw state in the DB is correct + conn = db_instance.get_connection() + raw_record = conn.execute("SELECT deleted, version FROM character_cards WHERE id = ?", (card_id,)).fetchone() + assert raw_record is not None + assert raw_record['deleted'] == 1 + assert raw_record['version'] == original_card['version'] + 1 + + @settings(suppress_health_check=[HealthCheck.function_scoped_fixture]) + @given( + initial_card=st_character_card_data(), + update_payload=st_character_card_data(), + version_offset=st_version_offset + ) + def test_update_with_stale_version_always_fails(self, db_instance: CharactersRAGDB, initial_card: dict, + update_payload: dict, version_offset: int): + """ + Property: Attempting to update a record with an incorrect `expected_version` + must always raise a ConflictError. + """ + try: + card_id = db_instance.add_character_card(initial_card) + except ConflictError: + return # Skip if initial card name conflicts + + original_card = db_instance.get_character_card_by_id(card_id) + + # Use the generated non-zero offset to create a stale version + stale_version = original_card['version'] + version_offset + + with pytest.raises(ConflictError): + db_instance.update_character_card(card_id, update_payload, expected_version=stale_version) + + @settings(suppress_health_check=[HealthCheck.function_scoped_fixture]) + @given( + initial_card=st_character_card_data(), + update_payload=st_character_card_data() + ) + def test_update_does_not_change_immutable_fields(self, db_instance: CharactersRAGDB, initial_card: dict, + update_payload: dict): + """ + Property: The `update` method must not change immutable fields like `id` and `created_at`, + even if they are passed in the payload. 
+ """ + try: + card_id = db_instance.add_character_card(initial_card) + except ConflictError: + return + + original_card = db_instance.get_character_card_by_id(card_id) + + # Add immutable fields to the update payload to try and change them + malicious_payload = update_payload.copy() + malicious_payload['id'] = 99999 # Try to change the ID + malicious_payload['created_at'] = "1999-01-01T00:00:00Z" # Try to change creation time + + try: + db_instance.update_character_card(card_id, malicious_payload, expected_version=original_card['version']) + except ConflictError: + # This can happen if the update_payload name conflicts, which is a valid outcome. + return + + updated_card = db_instance.get_character_card_by_id(card_id) + + # Assert that the immutable fields did NOT change. + assert updated_card['id'] == original_card['id'] + assert updated_card['created_at'] == original_card['created_at'] + + +class TestNoteAndKeywordProperties: + """Property-based tests for Notes, Keywords, and their linking.""" + + @settings(suppress_health_check=[HealthCheck.function_scoped_fixture]) + @given(note_data=st_note_data()) + def test_note_roundtrip(self, db_instance: CharactersRAGDB, note_data: dict): + """ + Property: A created note, when retrieved, has the same data, + accounting for any sanitization (like stripping whitespace). + """ + note_id = db_instance.add_note(**note_data) + assert note_id is not None + + retrieved = db_instance.get_note_by_id(note_id) + + assert retrieved is not None + # Compare the retrieved title to the STRIPPED version of the original title + assert retrieved['title'] == note_data['title'].strip() # <-- The fix + assert retrieved['content'] == note_data['content'] + assert retrieved['version'] == 1 + + @settings(suppress_health_check=[HealthCheck.function_scoped_fixture]) + @given(keyword=st_required_text) + def test_add_keyword_is_idempotent_on_undelete(self, db_instance: CharactersRAGDB, keyword: str): + """ + Property: Adding a keyword that was previously soft-deleted should reactivate + it, not create a new one, and its version should be correctly incremented. + """ + # 1. Add for the first time + try: + kw_id_v1 = db_instance.add_keyword(keyword) + except ConflictError: + return + + kw_v1 = db_instance.get_keyword_by_id(kw_id_v1) + assert kw_v1['version'] == 1 + + # 2. Soft delete it + db_instance.soft_delete_keyword(kw_id_v1, expected_version=1) + kw_v2_raw = db_instance.get_connection().execute("SELECT * FROM keywords WHERE id = ?", (kw_id_v1,)).fetchone() + assert kw_v2_raw['deleted'] == 1 + assert kw_v2_raw['version'] == 2 + + # 3. Add it again (should trigger undelete) + kw_id_v3 = db_instance.add_keyword(keyword) + + # Assert it's the same record + assert kw_id_v3 == kw_id_v1 + + kw_v3 = db_instance.get_keyword_by_id(kw_id_v3) + assert kw_v3 is not None + assert not kw_v3['deleted'] + # The version should be 3 (1=create, 2=delete, 3=undelete) + assert kw_v3['version'] == 3 + + @settings(suppress_health_check=[HealthCheck.function_scoped_fixture]) + @given(note_data=st_note_data(), keyword_text=st_required_text) + def test_linking_and_unlinking_properties(self, db_instance: CharactersRAGDB, note_data: dict, keyword_text: str): + """ + Property: Linking two items should make them appear in each other's "get_links" + methods, and unlinking should remove them. 
+ """ + try: + note_id = db_instance.add_note(**note_data) + keyword_id = db_instance.add_keyword(keyword_text) + except ConflictError: + return # Skip if hypothesis generates conflicting data + + # Initially, no links should exist + assert db_instance.get_keywords_for_note(note_id) == [] + assert db_instance.get_notes_for_keyword(keyword_id) == [] + + # --- Test Linking --- + link_success = db_instance.link_note_to_keyword(note_id, keyword_id) + assert link_success is True + + # Check that linking again is idempotent (returns False) + link_again_success = db_instance.link_note_to_keyword(note_id, keyword_id) + assert link_again_success is False + + # Verify the link exists from both sides + keywords_for_note = db_instance.get_keywords_for_note(note_id) + assert len(keywords_for_note) == 1 + assert keywords_for_note[0]['id'] == keyword_id + + notes_for_keyword = db_instance.get_notes_for_keyword(keyword_id) + assert len(notes_for_keyword) == 1 + assert notes_for_keyword[0]['id'] == note_id + + # --- Test Unlinking --- + unlink_success = db_instance.unlink_note_from_keyword(note_id, keyword_id) + assert unlink_success is True + + # Check that unlinking again is idempotent (returns False) + unlink_again_success = db_instance.unlink_note_from_keyword(note_id, keyword_id) + assert unlink_again_success is False + + # Verify the link is gone + assert db_instance.get_keywords_for_note(note_id) == [] + assert db_instance.get_notes_for_keyword(keyword_id) == [] + + +class TestAdvancedProperties: + @settings(suppress_health_check=[HealthCheck.function_scoped_fixture]) + @given(note_data=st_note_data()) + def test_soft_deleted_item_is_not_in_fts(self, db_instance: CharactersRAGDB, note_data: dict): + """ + Property: Once an item is soft-deleted, it must not appear in FTS search results. + """ + # Ensure the title has a unique, searchable term. + unique_term = str(uuid.uuid4()) + note_data['title'] = f"{note_data['title']} {unique_term}" + + note_id = db_instance.add_note(**note_data) + original_note = db_instance.get_note_by_id(note_id) + + # 1. Verify it IS searchable before deletion + search_results_before = db_instance.search_notes(unique_term) + assert len(search_results_before) == 1 + assert search_results_before[0]['id'] == note_id + + # 2. Soft-delete the note + db_instance.soft_delete_note(note_id, expected_version=original_note['version']) + + # 3. Verify it is NOT searchable after deletion + search_results_after = db_instance.search_notes(unique_term) + assert len(search_results_after) == 0 + + @settings(suppress_health_check=[HealthCheck.function_scoped_fixture]) + @given(note_data=st_note_data()) + def test_add_creates_sync_log_entry(self, db_instance: CharactersRAGDB, note_data: dict): + """ + Property: Adding a new item must create exactly one 'create' operation + in the sync_log for that item. 
+ """ + latest_change_id_before = db_instance.get_latest_sync_log_change_id() + + # Add the note (this action should be logged by a trigger) + note_id = db_instance.add_note(**note_data) + + new_log_entries = db_instance.get_sync_log_entries(since_change_id=latest_change_id_before) + + # There should be exactly one new entry + assert len(new_log_entries) == 1 + + log_entry = new_log_entries[0] + assert log_entry['entity'] == 'notes' + assert log_entry['entity_id'] == note_id + assert log_entry['operation'] == 'create' + assert log_entry['client_id'] == db_instance.client_id + assert log_entry['version'] == 1 + assert log_entry['payload']['title'] == note_data['title'].strip() # The log stores the stripped version + assert log_entry['payload']['content'] == note_data['content'] + + @settings(suppress_health_check=[HealthCheck.function_scoped_fixture]) + @given(note_data=st_note_data(), update_content=st.text(max_size=500)) + def test_update_creates_sync_log_entry(self, db_instance: CharactersRAGDB, note_data: dict, update_content: str): + """ + Property: Updating an item must create exactly one 'update' operation + in the sync_log for that item. + """ + note_id = db_instance.add_note(**note_data) + latest_change_id_before = db_instance.get_latest_sync_log_change_id() + + # Update the note + update_payload = {'content': update_content} + db_instance.update_note(note_id, update_payload, expected_version=1) + + new_log_entries = db_instance.get_sync_log_entries(since_change_id=latest_change_id_before) + + assert len(new_log_entries) == 1 + log_entry = new_log_entries[0] + assert log_entry['entity'] == 'notes' + assert log_entry['entity_id'] == note_id + assert log_entry['operation'] == 'update' + assert log_entry['version'] == 2 + assert log_entry['payload']['content'] == update_content + + @settings(suppress_health_check=[HealthCheck.function_scoped_fixture]) + @given(note_data=st_note_data()) + def test_delete_creates_sync_log_entry(self, db_instance: CharactersRAGDB, note_data: dict): + """ + Property: Soft-deleting an item must create exactly one 'delete' operation + in the sync_log for that item. + """ + note_id = db_instance.add_note(**note_data) + latest_change_id_before = db_instance.get_latest_sync_log_change_id() + + # Delete the note + db_instance.soft_delete_note(note_id, expected_version=1) + + new_log_entries = db_instance.get_sync_log_entries(since_change_id=latest_change_id_before) + + assert len(new_log_entries) == 1 + log_entry = new_log_entries[0] + assert log_entry['entity'] == 'notes' + assert log_entry['entity_id'] == note_id + assert log_entry['operation'] == 'delete' + assert log_entry['version'] == 2 + assert log_entry['payload']['deleted'] == 1 + + @settings(suppress_health_check=[HealthCheck.function_scoped_fixture]) + @given(note_data=st_note_data(), keyword_text=st_required_text) + def test_link_action_creates_correct_sync_log(self, db_instance: CharactersRAGDB, note_data: dict, + keyword_text: str): + """ + Property: The `link_note_to_keyword` action must create a sync_log entry + with the correct entity, IDs, and operation in its payload. + """ + try: + note_id = db_instance.add_note(**note_data) + kw_id = db_instance.add_keyword(keyword_text) + except ConflictError: + return + + latest_change_id_before = db_instance.get_latest_sync_log_change_id() + + # Action + db_instance.link_note_to_keyword(note_id, kw_id) + + # There should be exactly one new log entry from the linking. + # We ignore the create logs for the note and keyword. 
+ link_log_entries = db_instance.get_sync_log_entries(since_change_id=latest_change_id_before, + entity_type='note_keywords') + assert len(link_log_entries) == 1 + + log_entry = link_log_entries[0] + assert log_entry['entity'] == 'note_keywords' + assert log_entry['operation'] == 'create' + assert log_entry['entity_id'] == f"{note_id}_{kw_id}" + assert log_entry['payload']['note_id'] == note_id + assert log_entry['payload']['keyword_id'] == kw_id + + + @settings(suppress_health_check=[HealthCheck.function_scoped_fixture]) + @given(note_data=st_note_data(), keyword_text=st_required_text) + def test_unlink_action_creates_correct_sync_log(self, db_instance: CharactersRAGDB, note_data: dict, + keyword_text: str): + """ + Property: The `unlink_note_to_keyword` action must create a sync_log entry + with the correct entity, IDs, and operation. + """ + try: + note_id = db_instance.add_note(**note_data) + kw_id = db_instance.add_keyword(keyword_text) + except ConflictError: + return + + db_instance.link_note_to_keyword(note_id, kw_id) + latest_change_id_before = db_instance.get_latest_sync_log_change_id() + + # Action + db_instance.unlink_note_from_keyword(note_id, kw_id) + + unlink_log_entries = db_instance.get_sync_log_entries(since_change_id=latest_change_id_before, + entity_type='note_keywords') + assert len(unlink_log_entries) == 1 + + log_entry = unlink_log_entries[0] + assert log_entry['entity'] == 'note_keywords' + assert log_entry['operation'] == 'delete' + assert log_entry['entity_id'] == f"{note_id}_{kw_id}" + assert log_entry['payload']['note_id'] == note_id + assert log_entry['payload']['keyword_id'] == kw_id + + +@pytest.fixture +def populated_conversation(db_instance: CharactersRAGDB): + """A fixture to create a character, conversation, and message for cascade tests.""" + card_id = db_instance.add_character_card({'name': 'Cascade Test Character'}) + card = db_instance.get_character_card_by_id(card_id) + + conv_id = db_instance.add_conversation({'character_id': card['id'], 'title': 'Cascade Conv'}) + conv = db_instance.get_conversation_by_id(conv_id) + + msg_id = db_instance.add_message({'conversation_id': conv['id'], 'sender': 'user', 'content': 'Cascade Msg'}) + msg = db_instance.get_message_by_id(msg_id) + + return {"card": card, "conv": conv, "msg": msg} + + +class TestCascadeAndLinkingProperties: + def test_soft_deleting_conversation_makes_messages_unfindable(self, db_instance: CharactersRAGDB, + populated_conversation): + """ + Property: After a conversation is soft-deleted, its messages should not be + returned by get_messages_for_conversation. + """ + conv = populated_conversation['conv'] + msg = populated_conversation['msg'] + + # 1. Verify message and conversation exist before + assert db_instance.get_message_by_id(msg['id']) is not None + assert len(db_instance.get_messages_for_conversation(conv['id'])) == 1 + + # 2. Soft-delete the parent conversation + db_instance.soft_delete_conversation(conv['id'], expected_version=conv['version']) + + # 3. Verify the messages are now un-findable via the main query method + messages = db_instance.get_messages_for_conversation(conv['id']) + assert messages == [] + + # 4. As per our current design, the message record itself still exists (orphaned). + # This is an important check of the current behavior. 
+ assert db_instance.get_message_by_id(msg['id']) is not None + + def test_hard_deleting_conversation_cascades_to_messages(self, db_instance: CharactersRAGDB, + populated_conversation): + """ + Property: A hard DELETE on a conversation should cascade and delete its + messages, enforcing the FOREIGN KEY ... ON DELETE CASCADE constraint. + """ + conv = populated_conversation['conv'] + msg = populated_conversation['msg'] + + # 1. Verify message exists before + assert db_instance.get_message_by_id(msg['id']) is not None + + # 2. Perform a hard delete in a clean transaction + with db_instance.transaction() as conn: + conn.execute("DELETE FROM conversations WHERE id = ?", (conv['id'],)) + + # 3. Now the message should be truly gone. + assert db_instance.get_message_by_id(msg['id']) is None + + @settings(suppress_health_check=[HealthCheck.function_scoped_fixture]) + @given(keyword_text=st_required_text) + def test_deleting_keyword_cascades_to_link_tables(self, db_instance: CharactersRAGDB, populated_conversation, + keyword_text: str): + """ + Property: Deleting a keyword should remove all links to it from linking tables + due to ON DELETE CASCADE. + """ + conv = populated_conversation['conv'] + try: + kw_id = db_instance.add_keyword(keyword_text) + except ConflictError: + return + + # Link the conversation and keyword + db_instance.link_conversation_to_keyword(conv['id'], kw_id) + + # Verify link exists + keywords = db_instance.get_keywords_for_conversation(conv['id']) + assert len(keywords) == 1 + assert keywords[0]['id'] == kw_id + + # Soft-delete the keyword + keyword = db_instance.get_keyword_by_id(kw_id) + db_instance.soft_delete_keyword(kw_id, keyword['version']) + + # The link should now be gone when we retrieve it, because the JOIN will fail + # on `k.deleted = 0`. This tests the query logic. + keywords_after_delete = db_instance.get_keywords_for_conversation(conv['id']) + assert keywords_after_delete == [] + + + @settings(suppress_health_check=[HealthCheck.function_scoped_fixture]) + @given(note_data=st_note_data(), keyword_text=st_required_text) + def test_linking_is_idempotent(self, db_instance: CharactersRAGDB, note_data: dict, keyword_text: str): + """ + Property: Calling a link function multiple times has the same effect as calling it once. + The first call should return True (1 row affected), subsequent calls should return False (0 rows affected). + """ + try: + note_id = db_instance.add_note(**note_data) + kw_id = db_instance.add_keyword(keyword_text) + except ConflictError: + return + + # First call should succeed and return True + assert db_instance.link_note_to_keyword(note_id, kw_id) is True + + # Second call should do nothing and return False + assert db_instance.link_note_to_keyword(note_id, kw_id) is False + + # Third call should also do nothing and return False + assert db_instance.link_note_to_keyword(note_id, kw_id) is False + + # Verify there is still only one link + keywords = db_instance.get_keywords_for_note(note_id) + assert len(keywords) == 1 + + + @settings(suppress_health_check=[HealthCheck.function_scoped_fixture]) + @given(note_data=st_note_data(), keyword_text=st_required_text) + def test_unlinking_is_idempotent(self, db_instance: CharactersRAGDB, note_data: dict, keyword_text: str): + """ + Property: Calling an unlink function on a non-existent link does nothing. + Calling it on an existing link works once, then does nothing on subsequent calls. 
+ """ + try: + note_id = db_instance.add_note(**note_data) + kw_id = db_instance.add_keyword(keyword_text) + except ConflictError: + return + + # 1. Unlinking a non-existent link should return False + assert db_instance.unlink_note_from_keyword(note_id, kw_id) is False + + # 2. Create the link + db_instance.link_note_to_keyword(note_id, kw_id) + assert len(db_instance.get_keywords_for_note(note_id)) == 1 + + # 3. First unlink should succeed and return True + assert db_instance.unlink_note_from_keyword(note_id, kw_id) is True + + # 4. Second unlink should fail and return False + assert db_instance.unlink_note_from_keyword(note_id, kw_id) is False + + # Verify the link is gone + assert len(db_instance.get_keywords_for_note(note_id)) == 0 + + +# ========================================================== +# == STATE MACHINE SECTION +# ========================================================== + +class NoteLifecycleMachine(RuleBasedStateMachine): + """ + This class defines the rules and state for our test. + It is not run directly by pytest. + """ + + def __init__(self): + super().__init__() + self.db = None # This will be injected by the test class + self.note_id = None + self.expected_version = 0 + self.is_deleted = True + + notes = Bundle('notes') + + @rule(target=notes, note_data=st_note_data()) + def create_note(self, note_data): + # We only want to test the lifecycle of one note per machine run. + if self.note_id is not None: + return + + self.note_id = self.db.add_note(**note_data) + self.is_deleted = False + self.expected_version = 1 + + retrieved = self.db.get_note_by_id(self.note_id) + assert retrieved is not None + return self.note_id + + @rule(note_id=notes, update_data=st_note_data()) + def update_note(self, note_id, update_data): + if self.note_id is None or self.is_deleted: + return + + success = self.db.update_note(note_id, update_data, self.expected_version) + assert success + self.expected_version += 1 + + retrieved = self.db.get_note_by_id(self.note_id) + assert retrieved is not None + assert retrieved['version'] == self.expected_version + + @rule(note_id=notes) + def soft_delete_note(self, note_id): + if self.note_id is None or self.is_deleted: + return + + success = self.db.soft_delete_note(note_id, self.expected_version) + assert success + self.expected_version += 1 + self.is_deleted = True + + assert self.db.get_note_by_id(self.note_id) is None + + +# This class IS the test. pytest will discover it. +# It inherits our rules and provides the `db_instance` fixture. +@settings(suppress_health_check=[HealthCheck.function_scoped_fixture, HealthCheck.too_slow], max_examples=50) +class TestNoteLifecycleAsTest(NoteLifecycleMachine): + + @pytest.fixture(autouse=True) + def inject_db(self, db_instance): + """Injects the clean db_instance fixture into the state machine for each test run.""" + self.db = db_instance + + +# ========================================================== +# == Character Card State Machine +# ========================================================== + +class CharacterCardLifecycleMachine(RuleBasedStateMachine): + """Models the lifecycle of a CharacterCard.""" + + def __init__(self): + super().__init__() + self.db = None + self.card_id = None + self.card_name = None + self.expected_version = 0 + self.is_deleted = True + + cards = Bundle('cards') + + @rule(target=cards, card_data=st_character_card_data()) + def create_card(self, card_data): + # Only create one card per machine run for simplicity. 
+ if self.card_id is not None: + return + + try: + new_id = self.db.add_character_card(card_data) + except ConflictError: + # It's possible for hypothesis to generate a duplicate name + # in its sequence. We treat this as "no action taken". + return + + self.card_id = new_id + self.card_name = card_data['name'] + self.expected_version = 1 + self.is_deleted = False + + retrieved = self.db.get_character_card_by_id(self.card_id) + assert retrieved is not None + assert retrieved['name'] == self.card_name + return self.card_id + + @rule(card_id=cards, update_data=st_character_card_data()) + def update_card(self, card_id, update_data): + if self.card_id is None or self.is_deleted: + return + + try: + success = self.db.update_character_card(card_id, update_data, self.expected_version) + assert success + self.expected_version += 1 + self.card_name = update_data['name'] # Name can change + except ConflictError as e: + # Update can fail legitimately if the new name is already taken. + assert "already exists" in str(e) + # The state of our card hasn't changed, so we just return. + return + + retrieved = self.db.get_character_card_by_id(self.card_id) + assert retrieved is not None + assert retrieved['version'] == self.expected_version + assert retrieved['name'] == self.card_name + + @rule(card_id=cards) + def soft_delete_card(self, card_id): + if self.card_id is None or self.is_deleted: + return + + success = self.db.soft_delete_character_card(card_id, self.expected_version) + assert success + self.expected_version += 1 + self.is_deleted = True + + assert self.db.get_character_card_by_id(self.card_id) is None + assert self.db.get_character_card_by_name(self.card_name) is None + + +# The pytest test class that runs the machine +@settings(suppress_health_check=[HealthCheck.function_scoped_fixture, HealthCheck.too_slow], max_examples=50) +class TestCharacterCardLifecycle(CharacterCardLifecycleMachine): + + @pytest.fixture(autouse=True) + def inject_db(self, db_instance): + self.db = db_instance + + +st_message_content = st.text(max_size=1000) + + +# ========================================================== +# == Conversation/Message State Machine +# ========================================================== + +class ConversationMachine(RuleBasedStateMachine): + """ + Models creating a conversation and adding messages to it. + This machine tests the integrity of a single conversation over time. 
+ """ + + def __init__(self): + super().__init__() + self.db = None + self.card_id = None + self.conv_id = None + self.message_count = 0 + + @precondition(lambda self: self.card_id is None) + @rule() + def create_character(self): + """Setup step: create a character to host the conversation.""" + self.card_id = self.db.add_character_card({'name': 'Chat Host Character'}) + assert self.card_id is not None + + @precondition(lambda self: self.card_id is not None and self.conv_id is None) + @rule() + def create_conversation(self): + """Create the main conversation for this test run.""" + self.conv_id = self.db.add_conversation({'character_id': self.card_id, 'title': 'Test Chat'}) + assert self.conv_id is not None + + @precondition(lambda self: self.conv_id is not None) + @rule(content=st_message_content, sender=st.sampled_from(['user', 'ai'])) + def add_message(self, content, sender): + """Add a new message to the existing conversation.""" + if not content and not sender: # Ensure message has some substance + return + + msg_id = self.db.add_message({ + 'conversation_id': self.conv_id, + 'sender': sender, + 'content': content + }) + assert msg_id is not None + self.message_count += 1 + + def teardown(self): + """ + This method is called at the end of a state machine run. + We use it to check the final state of the system. + """ + if self.conv_id is not None: + messages = self.db.get_messages_for_conversation(self.conv_id) + assert len(messages) == self.message_count + + +# The pytest test class that runs the machine +@settings(suppress_health_check=[HealthCheck.function_scoped_fixture, HealthCheck.too_slow], max_examples=20, + stateful_step_count=50) +class TestConversationInteractions(ConversationMachine): + + @pytest.fixture(autouse=True) + def inject_db(self, db_instance): + self.db = db_instance + + +class TestDataIntegrity: + def test_add_conversation_with_nonexistent_character_fails(self, db_instance: CharactersRAGDB): + """ + Property: Cannot create a conversation for a character_id that does not exist. + This tests the FOREIGN KEY constraint. + """ + non_existent_char_id = 99999 + conv_data = {"character_id": non_existent_char_id, "title": "Orphan Conversation"} + + # The database will raise an IntegrityError. Your wrapper should catch this + # and raise a custom error. + with pytest.raises(CharactersRAGDBError, match="FOREIGN KEY constraint failed"): + db_instance.add_conversation(conv_data) + + def test_add_message_to_nonexistent_conversation_fails(self, db_instance: CharactersRAGDB): + """ + Property: Cannot add a message to a conversation_id that does not exist. + """ + non_existent_conv_id = "a-fake-uuid-string" + msg_data = { + "conversation_id": non_existent_conv_id, + "sender": "user", + "content": "Message to nowhere" + } + + # Your `add_message` has a pre-flight check for this, which should raise InputError. + # This tests your application-level check. + with pytest.raises(InputError, match="Conversation ID .* not found or deleted"): + db_instance.add_message(msg_data) + + def test_rating_outside_range_fails(self, db_instance: CharactersRAGDB): + """ + Property: Conversation rating must be between 1 and 5. + This tests the CHECK constraint via the public API. 
+ """ + card_id = db_instance.add_character_card({'name': 'Rating Test Character'}) + + # Test the application-level check in `update_conversation` + conv_id = db_instance.add_conversation({'character_id': card_id, "title": "Rating Conv"}) + with pytest.raises(InputError, match="Rating must be between 1 and 5"): + db_instance.update_conversation(conv_id, {"rating": 0}, expected_version=1) + with pytest.raises(InputError, match="Rating must be between 1 and 5"): + db_instance.update_conversation(conv_id, {"rating": 6}, expected_version=1) + + # Test the DB-level CHECK constraint directly by calling the wrapped execute_query + with pytest.raises(CharactersRAGDBError, match="Database constraint violation"): + # We start a transaction to ensure atomicity, but call the DB method + # that handles exception wrapping. + with db_instance.transaction(): + db_instance.execute_query("UPDATE conversations SET rating = 10 WHERE id = ?", (conv_id,)) + + +class TestConcurrency: + def test_each_thread_gets_a_separate_connection(self, db_instance: CharactersRAGDB): + """ + Property: The `_get_thread_connection` method must provide a unique + connection object for each thread. + """ + connection_ids = set() + lock = threading.Lock() + + def get_and_store_conn_id(): + conn = db_instance.get_connection() + with lock: + connection_ids.add(id(conn)) + + threads = [threading.Thread(target=get_and_store_conn_id) for _ in range(5)] + for t in threads: + t.start() + for t in threads: + t.join() + + # If threading.local is working, there should be 5 unique connection IDs. + assert len(connection_ids) == 5 + + def test_wal_mode_allows_concurrent_reads_during_write_transaction(self, db_instance: CharactersRAGDB): + """ + Property: In WAL mode, one thread can read from the DB while another + thread has an open write transaction. + """ + card_id = db_instance.add_character_card({'name': 'Concurrent Read Test'}) + + # A threading.Event to signal when the write transaction has started + write_transaction_started = threading.Event() + read_result = [] + + def writer_thread(): + with db_instance.transaction(): + db_instance.update_character_card(card_id, {'description': 'long update'}, 1) + write_transaction_started.set() # Signal that the transaction is open + time.sleep(0.2) # Hold the transaction open + # Transaction commits here + + def reader_thread(): + write_transaction_started.wait() # Wait until the writer is in its transaction + # This read should succeed immediately and not be blocked by the writer. + card = db_instance.get_character_card_by_id(card_id) + read_result.append(card) + + w = threading.Thread(target=writer_thread) + r = threading.Thread(target=reader_thread) + + w.start() + r.start() + w.join() + r.join() + + # The reader thread should have completed successfully and read the *original* state. + assert len(read_result) == 1 + assert read_result[0] is not None + assert read_result[0]['description'] is None # It read the state before the writer committed. + + +class TestComplexQueries: + def test_get_keywords_for_conversation_filters_deleted_keywords(self, db_instance: CharactersRAGDB): + """ + Property: When fetching keywords for a conversation, soft-deleted + keywords must be excluded from the results. 
+ """ + card_id = db_instance.add_character_card({'name': 'Filter Test'}) + conv_id = db_instance.add_conversation({'character_id': card_id}) + + kw1_id = db_instance.add_keyword("Active Keyword") + kw2_id = db_instance.add_keyword("Keyword to be Deleted") + kw2 = db_instance.get_keyword_by_id(kw2_id) + + db_instance.link_conversation_to_keyword(conv_id, kw1_id) + db_instance.link_conversation_to_keyword(conv_id, kw2_id) + + # Verify both are present initially + assert len(db_instance.get_keywords_for_conversation(conv_id)) == 2 + + # Soft-delete one of the keywords + db_instance.soft_delete_keyword(kw2_id, kw2['version']) + + # Fetch again and verify only the active one remains + remaining_keywords = db_instance.get_keywords_for_conversation(conv_id) + assert len(remaining_keywords) == 1 + assert remaining_keywords[0]['id'] == kw1_id + assert remaining_keywords[0]['keyword'] == "Active Keyword" + + +class TestDBOperations: + def test_backup_and_restore_correctness(self, db_instance: CharactersRAGDB, tmp_path: Path): + """ + Property: A database created from a backup file must contain the exact + same data as the original database at the time of backup. + """ + # 1. Populate the original database with known data + card_data = {'name': 'Backup Test Card', 'description': 'Data to be saved'} + card_id = db_instance.add_character_card(card_data) + original_card = db_instance.get_character_card_by_id(card_id) + + # 2. Perform the backup + backup_path = tmp_path / "test_backup.db" + assert db_instance.backup_database(str(backup_path)) is True + assert backup_path.exists() + + # 3. Close the original DB connection + db_instance.close_connection() + + # 4. Open the backup database as a new instance + backup_db = CharactersRAGDB(backup_path, "restore_client") + + # 5. Verify the data + restored_card = backup_db.get_character_card_by_id(card_id) + assert restored_card is not None + + # Compare the entire dictionaries + # sqlite.Row objects need to be converted to dicts for direct comparison + assert dict(restored_card) == dict(original_card) + + # Also check list methods + all_cards = backup_db.list_character_cards() + # Remember the default card! 
+ assert len(all_cards) == 2 + + backup_db.close_connection() + + +# +# End of test_chachanotes_db_properties.py +######################################################################################################################## diff --git a/Tests/Character_Chat/test_character_chat.py b/Tests/Character_Chat/test_character_chat.py index c1089c0e..e76500eb 100644 --- a/Tests/Character_Chat/test_character_chat.py +++ b/Tests/Character_Chat/test_character_chat.py @@ -1,155 +1,323 @@ -# test_property_character_chat_lib.py - -import unittest -from hypothesis import given, strategies as st, settings, HealthCheck -import re - -import Character_Chat_Lib as ccl - -class TestCharacterChatLibProperty(unittest.TestCase): - - # --- replace_placeholders --- - @given(text=st.text(), - char_name=st.one_of(st.none(), st.text(min_size=1, max_size=50)), - user_name=st.one_of(st.none(), st.text(min_size=1, max_size=50))) - @settings(suppress_health_check=[HealthCheck.too_slow]) - def test_replace_placeholders_properties(self, text, char_name, user_name): - processed = ccl.replace_placeholders(text, char_name, user_name) - self.assertIsInstance(processed, str) - - expected_char = char_name if char_name else "Character" - expected_user = user_name if user_name else "User" - - if "{{char}}" in text: - self.assertIn(expected_char, processed) - if "{{user}}" in text: - self.assertIn(expected_user, processed) - if "" in text: - self.assertIn(expected_char, processed) - if "" in text: - self.assertIn(expected_user, processed) - - # If no placeholders, text should be identical - placeholders = ['{{char}}', '{{user}}', '{{random_user}}', '', ''] - if not any(p in text for p in placeholders): - self.assertEqual(processed, text) - - @given(text=st.one_of(st.none(), st.just("")), char_name=st.text(), user_name=st.text()) - def test_replace_placeholders_empty_or_none_input_text(self, text, char_name, user_name): - self.assertEqual(ccl.replace_placeholders(text, char_name, user_name), "") - - # --- replace_user_placeholder --- - @given(history=st.lists(st.tuples(st.one_of(st.none(), st.text()), st.one_of(st.none(), st.text()))), - user_name=st.one_of(st.none(), st.text(min_size=1, max_size=50))) - @settings(suppress_health_check=[HealthCheck.too_slow]) - def test_replace_user_placeholder_properties(self, history, user_name): - processed_history = ccl.replace_user_placeholder(history, user_name) - self.assertEqual(len(processed_history), len(history)) - expected_user = user_name if user_name else "User" - - for i, (original_user_msg, original_bot_msg) in enumerate(history): - processed_user_msg, processed_bot_msg = processed_history[i] - if original_user_msg is not None: - self.assertIsInstance(processed_user_msg, str) - if "{{user}}" in original_user_msg: - self.assertIn(expected_user, processed_user_msg) - else: - self.assertEqual(processed_user_msg, original_user_msg) - else: - self.assertIsNone(processed_user_msg) - - if original_bot_msg is not None: - self.assertIsInstance(processed_bot_msg, str) - if "{{user}}" in original_bot_msg: - self.assertIn(expected_user, processed_bot_msg) - else: - self.assertEqual(processed_bot_msg, original_bot_msg) - else: - self.assertIsNone(processed_bot_msg) - - # --- extract_character_id_from_ui_choice --- - @given(name=st.text(alphabet=st.characters(min_codepoint=65, max_codepoint=122), min_size=1, max_size=20).filter(lambda x: '(' not in x and ')' not in x), - id_val=st.integers(min_value=0, max_value=10**9)) - def test_extract_id_format_name_id(self, name, id_val): - choice = 
f"{name} (ID: {id_val})" - self.assertEqual(ccl.extract_character_id_from_ui_choice(choice), id_val) - - @given(id_val=st.integers(min_value=0, max_value=10**9)) - def test_extract_id_format_just_id(self, id_val): - choice = str(id_val) - self.assertEqual(ccl.extract_character_id_from_ui_choice(choice), id_val) - - @given(text=st.text().filter(lambda x: not re.search(r'\(\s*ID\s*:\s*\d+\s*\)\s*$', x) and not x.isdigit() and x != "")) - def test_extract_id_invalid_format_raises_valueerror(self, text): - with self.assertRaises(ValueError): - ccl.extract_character_id_from_ui_choice(text) - - @given(choice=st.just("")) - def test_extract_id_empty_string_raises_valueerror(self, choice): - with self.assertRaises(ValueError): - ccl.extract_character_id_from_ui_choice(choice) - - # --- process_db_messages_to_ui_history --- - # This one is complex for property-based testing due to stateful accumulation. - # We can test some basic properties. - @given(db_messages=st.lists(st.fixed_dictionaries({ - 'sender': st.sampled_from(["User", "TestChar", "OtherSender"]), - 'content': st.text(max_size=100) - }), max_size=10), - char_name=st.text(min_size=1, max_size=20), - user_name=st.one_of(st.none(), st.text(min_size=1, max_size=20))) - @settings(suppress_health_check=[HealthCheck.too_slow, HealthCheck.filter_too_much]) - def test_process_db_messages_to_ui_history_output_structure(self, db_messages, char_name, user_name): - if not db_messages: # Avoid issues with empty messages list if logic depends on non-empty - return - - ui_history = ccl.process_db_messages_to_ui_history(db_messages, char_name, user_name, - actual_char_sender_id_in_db=char_name) # Map TestChar to char_name - self.assertIsInstance(ui_history, list) - for item in ui_history: - self.assertIsInstance(item, tuple) - self.assertEqual(len(item), 2) - self.assertTrue(item[0] is None or isinstance(item[0], str)) - self.assertTrue(item[1] is None or isinstance(item[1], str)) - - # If all messages are from User, bot messages in UI should be None - if all(msg['sender'] == "User" for msg in db_messages): - for _, bot_msg in ui_history: - self.assertIsNone(bot_msg) - - # If all messages are from Character, user messages in UI should be None - if all(msg['sender'] == char_name for msg in db_messages): - for user_msg, _ in ui_history: - self.assertIsNone(user_msg) - - # --- Card Validation Properties (Example for validate_character_book_entry) --- - # Strategy for a valid character book entry core - valid_entry_core_st = st.fixed_dictionaries({ - 'keys': st.lists(st.text(min_size=1, max_size=50), min_size=1, max_size=5), - 'content': st.text(min_size=1, max_size=200), - 'enabled': st.booleans(), - 'insertion_order': st.integers() - }) - - @given(entry_core=valid_entry_core_st, - entry_id_set=st.sets(st.integers(min_value=0, max_value=1000))) - def test_validate_character_book_entry_valid_core(self, entry_core, entry_id_set): - is_valid, errors = ccl.validate_character_book_entry(entry_core, 0, entry_id_set) - self.assertTrue(is_valid, f"Errors for supposedly valid core: {errors}") - self.assertEqual(len(errors), 0) - - @given(entry_core=valid_entry_core_st, - bad_key_type=st.integers(), # Make keys not a list of strings - entry_id_set=st.sets(st.integers())) - def test_validate_character_book_entry_invalid_keys_type(self, entry_core, bad_key_type, entry_id_set): - invalid_entry = {**entry_core, 'keys': bad_key_type} - is_valid, errors = ccl.validate_character_book_entry(invalid_entry, 0, entry_id_set) - self.assertFalse(is_valid) - 
self.assertTrue(any("Field 'keys' must be of type 'list'" in e for e in errors)) - - # More properties can be added for other parsing/validation functions if they - # have clear invariants that can be tested with generated data. - - -if __name__ == '__main__': - unittest.main(argv=['first-arg-is-ignored'], exit=False) \ No newline at end of file +# test_character_chat_lib.py + +import pytest +import sqlite3 +import json +import uuid +import base64 +import io +from datetime import datetime, timezone +from pathlib import Path +from unittest.mock import patch, MagicMock + +from PIL import Image + +# Local Imports from this project +from tldw_chatbook.DB.ChaChaNotes_DB import ( + CharactersRAGDB, + CharactersRAGDBError, + ConflictError, + InputError +) +from tldw_chatbook.Character_Chat.Character_Chat_Lib import ( + create_conversation, + get_conversation_details_and_messages, + add_message_to_conversation, + get_character_list_for_ui, + extract_character_id_from_ui_choice, + load_character_and_image, + process_db_messages_to_ui_history, + load_chat_and_character, + load_character_wrapper, + import_and_save_character_from_file, + load_chat_history_from_file_and_save_to_db, + start_new_chat_session, + list_character_conversations, + update_conversation_metadata, + post_message_to_conversation, + retrieve_conversation_messages_for_ui +) + + +# --- Standalone Fixtures (No conftest.py) --- + +@pytest.fixture +def client_id(): + """Provides a consistent client ID for tests.""" + return "test_lib_client_001" + + +@pytest.fixture +def db_path(tmp_path): + """Provides a temporary path for the database file for each test.""" + return tmp_path / "test_lib_db.sqlite" + + +@pytest.fixture(scope="function") +def db_instance(db_path, client_id): + """Creates a DB instance for each test, ensuring a fresh database.""" + db = CharactersRAGDB(db_path, client_id) + yield db + db.close_connection() + + +# --- Helper Functions --- + +def create_dummy_png_bytes() -> bytes: + """Creates a 1x1 black PNG image in memory.""" + img = Image.new('RGB', (1, 1), color='black') + byte_arr = io.BytesIO() + img.save(byte_arr, format='PNG') + return byte_arr.getvalue() + + +# FIXME +# def create_dummy_png_with_chara(chara_json_str: str) -> bytes: +# """Creates a 1x1 PNG with embedded 'chara' metadata.""" +# img = Image.new('RGB', (1, 1), color='red') +# # The 'chara' metadata is a base64 encoded string of the JSON +# chara_b64 = base64.b64encode(chara_json_str.encode('utf-8')).decode('utf-8') +# +# byte_arr = io.BytesIO() +# # Pillow saves metadata in the 'info' dictionary for PNGs +# img.save(byte_arr, format='PNG', pnginfo=Image.PngImagePlugin.PngInfo()) +# byte_arr.seek(0) +# +# # Re-open to add the custom chunk, as 'info' doesn't directly map to chunks +# img_with_info = Image.open(byte_arr) +# img_with_info.info['chara'] = chara_b64 +# +# final_byte_arr = io.BytesIO() +# img_with_info.save(final_byte_arr, format='PNG') +# return final_byte_arr.getvalue() + + +def create_sample_v2_card_json(name: str) -> str: + """Returns a V2 character card as a JSON string.""" + card = { + "spec": "chara_card_v2", + "spec_version": "2.0", + "data": { + "name": name, + "description": "A test character from a V2 card.", + "personality": "Curious", + "scenario": "In a test.", + "first_mes": "Hello from the V2 card!", + "mes_example": "This is an example message.", + "creator_notes": "", + "system_prompt": "", + "post_history_instructions": "", + "tags": ["v2", "test"], + "creator": "Tester", + "character_version": "1.0", + 
"alternate_greetings": ["Hi there!", "Greetings!"] + } + } + return json.dumps(card) + + +# --- Test Classes --- + +class TestConversationManagement: + def test_create_conversation_with_defaults(self, db_instance: CharactersRAGDB): + # Relies on the default character (ID 1) created by the schema + conv_id = create_conversation(db_instance) + assert conv_id is not None + + details = db_instance.get_conversation_by_id(conv_id) + assert details is not None + assert details['character_id'] == 1 # DEFAULT_CHARACTER_ID + assert "Chat with Default Assistant" in details['title'] + + def test_create_conversation_with_initial_messages(self, db_instance: CharactersRAGDB): + char_id = db_instance.add_character_card({"name": "Talker"}) + initial_messages = [ + {'sender': 'User', 'content': 'Hi there!'}, + {'sender': 'AI', 'content': 'Hello, User!'} + ] + conv_id = create_conversation(db_instance, character_id=char_id, initial_messages=initial_messages) + assert conv_id is not None + + messages = db_instance.get_messages_for_conversation(conv_id) + assert len(messages) == 2 + assert messages[0]['sender'] == 'User' + assert messages[1]['sender'] == 'Talker' # Sender 'AI' is mapped to character name + + def test_get_conversation_details_and_messages(self, db_instance: CharactersRAGDB): + char_id = db_instance.add_character_card({"name": "DetailedChar"}) + conv_id = create_conversation(db_instance, character_id=char_id, title="Test Details") + add_message_to_conversation(db_instance, conv_id, "User", "A message") + + details = get_conversation_details_and_messages(db_instance, conv_id) + assert details is not None + assert details['metadata']['title'] == "Test Details" + assert details['character_name'] == "DetailedChar" + assert len(details['messages']) == 1 + assert details['messages'][0]['content'] == "A message" + + +class TestCharacterLoading: + def test_extract_character_id_from_ui_choice(self): + assert extract_character_id_from_ui_choice("My Character (ID: 123)") == 123 + assert extract_character_id_from_ui_choice("456") == 456 + with pytest.raises(ValueError): + extract_character_id_from_ui_choice("Invalid Format") + with pytest.raises(ValueError): + extract_character_id_from_ui_choice("") + + def test_load_character_and_image(self, db_instance: CharactersRAGDB): + image_bytes = create_dummy_png_bytes() + card_data = {"name": "ImgChar", "first_message": "Hello, {{user}}!", "image": image_bytes} + char_id = db_instance.add_character_card(card_data) + + char_data, history, img = load_character_and_image(db_instance, char_id, "Tester") + + assert char_data is not None + assert char_data['name'] == "ImgChar" + assert len(history) == 1 + assert history[0] == (None, "Hello, Tester!") # Placeholder replaced + assert isinstance(img, Image.Image) + + def test_process_db_messages_to_ui_history(self): + db_messages = [ + {"sender": "User", "content": "Msg 1"}, + {"sender": "MyChar", "content": "Reply 1"}, + {"sender": "User", "content": "Msg 2a"}, + {"sender": "User", "content": "Msg 2b"}, + {"sender": "MyChar", "content": "Reply 2"}, + ] + history = process_db_messages_to_ui_history(db_messages, "MyChar", "TestUser") + expected = [ + ("Msg 1", "Reply 1"), + ("Msg 2a", None), + ("Msg 2b", "Reply 2"), + ] + assert history == expected + + def test_load_chat_and_character(self, db_instance: CharactersRAGDB): + char_id = db_instance.add_character_card({"name": "Chatter", "first_message": "Hi"}) + conv_id = create_conversation(db_instance, character_id=char_id) + add_message_to_conversation(db_instance, conv_id, 
"User", "Hello there") + add_message_to_conversation(db_instance, conv_id, "Chatter", "General Kenobi") + + char_data, history, img = load_chat_and_character(db_instance, conv_id, "TestUser") + + assert char_data['name'] == "Chatter" + assert len(history) == 2 # Initial 'Hi' plus the two added messages + assert history[0] == (None, 'Hi') + assert history[1] == ("Hello there", "General Kenobi") + + def test_load_character_wrapper(self, db_instance: CharactersRAGDB): + char_id = db_instance.add_character_card({"name": "WrappedChar"}) + # Test with int ID + char_data_int, _, _ = load_character_wrapper(db_instance, char_id, "User") + assert char_data_int['id'] == char_id + # Test with UI string + char_data_str, _, _ = load_character_wrapper(db_instance, f"WrappedChar (ID: {char_id})", "User") + assert char_data_str['id'] == char_id + + +class TestCharacterImport: + def test_import_from_json_string(self, db_instance: CharactersRAGDB): + card_name = "ImportedV2Char" + v2_json = create_sample_v2_card_json(card_name) + + file_obj = io.BytesIO(v2_json.encode('utf-8')) + char_id = import_and_save_character_from_file(db_instance, file_obj) + + assert char_id is not None + retrieved = db_instance.get_character_card_by_id(char_id) + assert retrieved['name'] == card_name + assert retrieved['description'] == "A test character from a V2 card." + + def test_import_from_png_with_chara_metadata(self, db_instance: CharactersRAGDB): + card_name = "PngChar" + v2_json = create_sample_v2_card_json(card_name) + png_bytes = create_dummy_png_with_chara(v2_json) + + file_obj = io.BytesIO(png_bytes) + char_id = import_and_save_character_from_file(db_instance, file_obj) + + assert char_id is not None + retrieved = db_instance.get_character_card_by_id(char_id) + assert retrieved['name'] == card_name + assert retrieved['image'] is not None + + def test_import_chat_history_and_save(self, db_instance: CharactersRAGDB): + # 1. Create the character that the chat log refers to + char_name = "LogChar" + char_id = db_instance.add_character_card({"name": char_name}) + + # 2. Create a sample chat log JSON + chat_log = { + "char_name": char_name, + "history": { + "internal": [ + ["Hello", "Hi there"], + ["How are you?", "I am a test, I am fine."] + ] + } + } + chat_log_json = json.dumps(chat_log) + file_obj = io.BytesIO(chat_log_json.encode('utf-8')) + + # 3. Import the log + conv_id, new_char_id = load_chat_history_from_file_and_save_to_db(db_instance, file_obj) + + assert conv_id is not None + assert new_char_id == char_id + + # 4. Verify messages were saved + messages = db_instance.get_messages_for_conversation(conv_id) + assert len(messages) == 4 + assert messages[0]['content'] == "Hello" + assert messages[1]['content'] == "Hi there" + assert messages[1]['sender'] == char_name + + +class TestHighLevelChatFlow: + def test_start_new_chat_session(self, db_instance: CharactersRAGDB): + char_id = db_instance.add_character_card({"name": "SessionStart", "first_message": "Greetings!"}) + + conv_id, char_data, ui_history, img = start_new_chat_session(db_instance, char_id, "TestUser") + + assert conv_id is not None + assert char_data['name'] == "SessionStart" + assert ui_history == [(None, "Greetings!")] + + # Verify conversation and first message were saved to DB + conv_details = db_instance.get_conversation_by_id(conv_id) + assert conv_details['character_id'] == char_id + + messages = db_instance.get_messages_for_conversation(conv_id) + assert len(messages) == 1 + assert messages[0]['content'] == "Greetings!" 
+ assert messages[0]['sender'] == "SessionStart" + + def test_post_message_to_conversation(self, db_instance: CharactersRAGDB): + char_id = db_instance.add_character_card({"name": "Poster"}) + conv_id = db_instance.add_conversation({"character_id": char_id}) + + # Post user message + user_msg_id = post_message_to_conversation(db_instance, conv_id, "Poster", "User msg", is_user_message=True) + assert user_msg_id is not None + # Post character message + char_msg_id = post_message_to_conversation(db_instance, conv_id, "Poster", "Char reply", is_user_message=False) + assert char_msg_id is not None + + messages = db_instance.get_messages_for_conversation(conv_id) + assert len(messages) == 2 + assert messages[0]['sender'] == "User" + assert messages[1]['sender'] == "Poster" + + def test_retrieve_conversation_messages_for_ui(self, db_instance: CharactersRAGDB): + char_id = db_instance.add_character_card({"name": "UIRetriever"}) + conv_id = start_new_chat_session(db_instance, char_id, "TestUser")[0] # Get conv_id + post_message_to_conversation(db_instance, conv_id, "UIRetriever", "User message", True) + post_message_to_conversation(db_instance, conv_id, "UIRetriever", "Bot reply", False) + + ui_history = retrieve_conversation_messages_for_ui(db_instance, conv_id, "UIRetriever", "TestUser") + + # Initial message from start_new_chat_session + 1 pair + assert len(ui_history) == 2 + assert ui_history[1] == ("User message", "Bot reply") \ No newline at end of file diff --git a/Tests/Chat/test_chat_functions.py b/Tests/Chat/test_chat_functions.py index 690a6e33..0bec02a6 100644 --- a/Tests/Chat/test_chat_functions.py +++ b/Tests/Chat/test_chat_functions.py @@ -1,739 +1,338 @@ -# tests/unit/core/chat/test_chat_functions.py -# Description: Unit tests for chat functions in the tldw_app.Chat module. +# test_chat_functions.py # # Imports +import pytest import base64 -import re -import os # For tmp_path with post_gen_replacement_dict -import textwrap +import io from unittest.mock import patch, MagicMock # # 3rd-party Libraries -import pytest -from hypothesis import given, strategies as st, settings, HealthCheck -from typing import Optional -import requests # For mocking requests.exceptions +import requests +from PIL import Image # # Local Imports -# Ensure these paths are correct for your project structure +from tldw_chatbook.DB.ChaChaNotes_DB import CharactersRAGDB, ConflictError, InputError from tldw_chatbook.Chat.Chat_Functions import ( chat_api_call, chat, save_chat_history_to_db_wrapper, - API_CALL_HANDLERS, - PROVIDER_PARAM_MAP, - load_characters, save_character, - ChatDictionary, + load_characters, + get_character_names, parse_user_dict_markdown_file, - process_user_input, # Added for direct testing if needed - # Import the actual LLM handler functions if you intend to test them directly (though mostly tested via chat_api_call) - # e.g., _chat_with_openai_compatible_local_server, chat_with_kobold, etc. - # For unit testing chat_api_call, we mock these handlers. - # For unit testing chat, we mock chat_api_call. 
+ process_user_input, + ChatDictionary, + DEFAULT_CHARACTER_NAME ) from tldw_chatbook.Chat.Chat_Deps import ( - ChatAuthenticationError, ChatRateLimitError, ChatBadRequestError, - ChatConfigurationError, ChatProviderError, ChatAPIError + ChatBadRequestError, + ChatAuthenticationError, + ChatRateLimitError, + ChatProviderError, + ChatAPIError ) -from tldw_chatbook.DB.ChaChaNotes_DB import CharactersRAGDB # For mocking - -# Placeholder for load_settings if it's not directly in Chat_Functions or needs specific mocking path -# from tldw_app.Chat.Chat_Functions import load_settings # Already imported effectively - -# Define a common set of known providers for hypothesis strategies -KNOWN_PROVIDERS = list(API_CALL_HANDLERS.keys()) -if not KNOWN_PROVIDERS: # Should not happen if API_CALL_HANDLERS is populated - KNOWN_PROVIDERS = ["openai", "anthropic", "ollama"] # Fallback for safety - -######################################################################################################################## # -# Fixtures - -class ChatErrorBase(Exception): - def __init__(self, provider: str, message: str, status_code: Optional[int] = None): - self.provider = provider - self.message = message - self.status_code = status_code - super().__init__(f"[{provider}] {message}" + (f" (HTTP {status_code})" if status_code else "")) - -@pytest.fixture(autouse=True) -def mock_load_settings_globally(): - """ - Mocks load_settings used by Chat_Functions.py (e.g., in LLM handlers and chat function for post-gen dict). - This ensures that tests don't rely on actual configuration files. - """ - # This default mock should provide minimal valid config for most providers - # to avoid KeyError if a handler tries to access its specific config section. - default_provider_config = { - "api_ip": "http://mock.host:1234", - "api_url": "http://mock.host:1234", - "api_key": "mock_api_key_from_settings", - "model": "mock_model_from_settings", - "temperature": 0.7, - "streaming": False, - "max_tokens": 1024, - "n_predict": 1024, # for llama - "num_predict": 1024, # for ollama - "max_length": 150, # for kobold - "api_timeout": 60, - "api_retries": 1, - "api_retry_delay": 1, - # Add other common keys that might be accessed to avoid KeyErrors in handlers - } - mock_settings_data = { - "chat_dictionaries": { # For `chat` function's post-gen replacement - "post_gen_replacement": "False", # Default to off unless a test enables it - "post_gen_replacement_dict": "dummy_path.md" - }, - # Provide a default config for all known providers to prevent KeyErrors - # when handlers try to load their specific sections via cfg.get(...) 
- **{f"{provider}_api": default_provider_config.copy() for provider in KNOWN_PROVIDERS}, - # Specific overrides if needed by a default test case - "local_llm": default_provider_config.copy(), - "llama_api": default_provider_config.copy(), - "kobold_api": default_provider_config.copy(), - "ooba_api": default_provider_config.copy(), - "tabby_api": default_provider_config.copy(), - "vllm_api": default_provider_config.copy(), - "aphrodite_api": default_provider_config.copy(), - "ollama_api": default_provider_config.copy(), - "custom_openai_api": default_provider_config.copy(), - "custom_openai_api_2": default_provider_config.copy(), - "openai": default_provider_config.copy(), # For actual OpenAI if it uses load_settings - "anthropic": default_provider_config.copy(), - } - # Ensure all keys from PROVIDER_PARAM_MAP also have a generic config section if a handler uses it - for provider_key in KNOWN_PROVIDERS: - if provider_key not in mock_settings_data: # e.g. if provider is 'openai' not 'openai_api' - mock_settings_data[provider_key] = default_provider_config.copy() - - - with patch("tldw_app.Chat.Chat_Functions.load_settings", return_value=mock_settings_data) as mock_load: - yield mock_load +####################################################################################################################### +# +# --- Standalone Fixtures (No conftest.py) --- + +@pytest.fixture +def client_id(): + """Provides a consistent client ID for tests.""" + return "test_chat_func_client" @pytest.fixture -def mock_llm_handlers(): # Renamed for clarity, same functionality - original_handlers = API_CALL_HANDLERS.copy() - mocked_handlers_dict = {} - for provider_name, original_func in original_handlers.items(): - mock_handler = MagicMock(name=f"mock_{getattr(original_func, '__name__', provider_name)}") - # Preserve the signature or relevant attributes if necessary, but MagicMock is often enough - mock_handler.__name__ = getattr(original_func, '__name__', f"mock_{provider_name}_handler") - mocked_handlers_dict[provider_name] = mock_handler - - with patch("tldw_app.Chat.Chat_Functions.API_CALL_HANDLERS", new=mocked_handlers_dict): - yield mocked_handlers_dict - -# --- Tests for chat_api_call --- -@pytest.mark.unit -def test_chat_api_call_routing_and_param_mapping_openai(mock_llm_handlers): - provider = "openai" - mock_openai_handler = mock_llm_handlers[provider] - mock_openai_handler.return_value = "OpenAI success" - - args = { - "api_endpoint": provider, - "messages_payload": [{"role": "user", "content": "Hi OpenAI"}], - "api_key": "test_openai_key", - "temp": 0.5, - "system_message": "Be concise.", - "streaming": False, - "maxp": 0.9, # Generic name, maps to top_p for openai - "model": "gpt-4o-mini", - "tools": [{"type": "function", "function": {"name": "get_weather"}}], - "tool_choice": "auto", - "seed": 123, - "response_format": {"type": "json_object"}, - "logit_bias": {"123": 10}, - } - result = chat_api_call(**args) - assert result == "OpenAI success" - mock_openai_handler.assert_called_once() - called_kwargs = mock_openai_handler.call_args.kwargs - - param_map_for_provider = PROVIDER_PARAM_MAP[provider] - - # Check that generic params were mapped to provider-specific names in the handler call - assert called_kwargs[param_map_for_provider['messages_payload']] == args["messages_payload"] - assert called_kwargs[param_map_for_provider['api_key']] == args["api_key"] - assert called_kwargs[param_map_for_provider['temp']] == args["temp"] - assert called_kwargs[param_map_for_provider['system_message']] == 
args["system_message"] - assert called_kwargs[param_map_for_provider['streaming']] == args["streaming"] - assert called_kwargs[param_map_for_provider['maxp']] == args["maxp"] # 'maxp' generic maps to 'top_p' (OpenAI specific) - assert called_kwargs[param_map_for_provider['model']] == args["model"] - assert called_kwargs[param_map_for_provider['tools']] == args["tools"] - assert called_kwargs[param_map_for_provider['tool_choice']] == args["tool_choice"] - assert called_kwargs[param_map_for_provider['seed']] == args["seed"] - assert called_kwargs[param_map_for_provider['response_format']] == args["response_format"] - assert called_kwargs[param_map_for_provider['logit_bias']] == args["logit_bias"] - - -@pytest.mark.unit -def test_chat_api_call_routing_and_param_mapping_anthropic(mock_llm_handlers): - provider = "anthropic" - mock_anthropic_handler = mock_llm_handlers[provider] - mock_anthropic_handler.return_value = "Anthropic success" - - args = { - "api_endpoint": provider, - "messages_payload": [{"role": "user", "content": "Hi Anthropic"}], - "api_key": "test_anthropic_key", - "temp": 0.6, - "system_message": "Be friendly.", - "streaming": True, - "model": "claude-3-opus-20240229", - "topp": 0.92, # Generic name, maps to top_p for anthropic - "topk": 50, - "max_tokens": 100, # Generic, maps to max_tokens for anthropic - "stop": ["\nHuman:", "\nAssistant:"] # Generic, maps to stop_sequences - } - result = chat_api_call(**args) - assert result == "Anthropic success" - mock_anthropic_handler.assert_called_once() - called_kwargs = mock_anthropic_handler.call_args.kwargs - - param_map = PROVIDER_PARAM_MAP[provider] - assert called_kwargs[param_map['messages_payload']] == args["messages_payload"] - assert called_kwargs[param_map['api_key']] == args["api_key"] - assert called_kwargs[param_map['temp']] == args["temp"] - assert called_kwargs[param_map['system_message']] == args["system_message"] # maps to 'system_prompt' - assert called_kwargs[param_map['streaming']] == args["streaming"] - assert called_kwargs[param_map['model']] == args["model"] - assert called_kwargs[param_map['topp']] == args["topp"] - assert called_kwargs[param_map['topk']] == args["topk"] - assert called_kwargs[param_map['max_tokens']] == args["max_tokens"] - assert called_kwargs[param_map['stop']] == args["stop"] - - -@pytest.mark.unit -def test_chat_api_call_unsupported_provider(): - with pytest.raises(ValueError, match="Unsupported API endpoint: non_existent_provider"): - chat_api_call(api_endpoint="non_existent_provider", messages_payload=[]) - - -@pytest.mark.unit -@pytest.mark.parametrize("raised_exception, expected_custom_error_type, expected_status_code_in_error", [ - (requests.exceptions.HTTPError(response=MagicMock(status_code=401, text="Auth error text")), ChatAuthenticationError, 401), - (requests.exceptions.HTTPError(response=MagicMock(status_code=429, text="Rate limit text")), ChatRateLimitError, 429), - (requests.exceptions.HTTPError(response=MagicMock(status_code=400, text="Bad req text")), ChatBadRequestError, 400), - (requests.exceptions.HTTPError(response=MagicMock(status_code=503, text="Provider down text")), ChatProviderError, 503), - (requests.exceptions.ConnectionError("Network fail"), ChatProviderError, 504), # Default for RequestException - (ValueError("Internal value error"), ChatBadRequestError, None), # Status code might not be set by default for these - (TypeError("Internal type error"), ChatBadRequestError, None), - (KeyError("Internal key error"), ChatBadRequestError, None), - 
(ChatConfigurationError("config issue", provider="openai"), ChatConfigurationError, None), # Direct raise - (Exception("Very generic error"), ChatAPIError, 500), -]) -def test_chat_api_call_exception_mapping( - mock_llm_handlers, - raised_exception, expected_custom_error_type, expected_status_code_in_error -): - provider_to_test = "openai" # Use any valid provider name that is mocked - mock_handler = mock_llm_handlers[provider_to_test] - mock_handler.side_effect = raised_exception - - with pytest.raises(expected_custom_error_type) as exc_info: - chat_api_call(api_endpoint=provider_to_test, messages_payload=[{"role": "user", "content": "test"}]) - - assert exc_info.value.provider == provider_to_test - if expected_status_code_in_error is not None and hasattr(exc_info.value, 'status_code'): - assert exc_info.value.status_code == expected_status_code_in_error - - # Check that original error message part is in the custom error message if applicable - if hasattr(raised_exception, 'response') and hasattr(raised_exception.response, 'text'): - assert raised_exception.response.text[:100] in exc_info.value.message # Check beginning of text - elif not isinstance(raised_exception, ChatErrorBase): # Don't double-check message for already custom errors - assert str(raised_exception) in exc_info.value.message - - -# --- Tests for the `chat` function (multimodal chat coordinator) --- - -@pytest.mark.unit -@patch("tldw_app.Chat.Chat_Functions.chat_api_call") -@patch("tldw_app.Chat.Chat_Functions.process_user_input", side_effect=lambda text, *args, **kwargs: text) -# mock_load_settings_globally is active via autouse=True -def test_chat_function_basic_text_call(mock_process_input, mock_chat_api_call_shim): - mock_chat_api_call_shim.return_value = "LLM Response from chat function" - - response = chat( - message="Hello LLM", - history=[], - media_content=None, selected_parts=[], api_endpoint="test_provider_for_chat", - api_key="test_key_for_chat", custom_prompt="Be very brief.", temperature=0.1, - system_message="You are a test bot for chat.", - llm_seed=42, llm_max_tokens=100, llm_user_identifier="user123" - ) - assert response == "LLM Response from chat function" - mock_chat_api_call_shim.assert_called_once() - call_args = mock_chat_api_call_shim.call_args.kwargs - - assert call_args["api_endpoint"] == "test_provider_for_chat" - assert call_args["api_key"] == "test_key_for_chat" - assert call_args["temp"] == 0.1 - assert call_args["system_message"] == "You are a test bot for chat." 
- assert call_args["seed"] == 42 # Check new llm_param - assert call_args["max_tokens"] == 100 # Check new llm_param - assert call_args["user_identifier"] == "user123" # Check new llm_param - - - payload = call_args["messages_payload"] - assert len(payload) == 1 - assert payload[0]["role"] == "user" - assert isinstance(payload[0]["content"], list) - assert len(payload[0]["content"]) == 1 - assert payload[0]["content"][0]["type"] == "text" - # Custom prompt is prepended to user message - assert payload[0]["content"][0]["text"] == "Be very brief.\n\nHello LLM" - - -@pytest.mark.unit -@patch("tldw_app.Chat.Chat_Functions.chat_api_call") -@patch("tldw_app.Chat.Chat_Functions.process_user_input", side_effect=lambda x, *a, **kw: x) -def test_chat_function_with_text_history(mock_process_input, mock_chat_api_call_shim): - mock_chat_api_call_shim.return_value = "LLM Response with history" - history_for_chat_func = [ - {"role": "user", "content": "Previous question?"}, # Will be wrapped - {"role": "assistant", "content": [{"type": "text", "text": "Previous answer."}]} # Already wrapped - ] - response = chat( - message="New question", history=history_for_chat_func, media_content=None, - selected_parts=[], api_endpoint="hist_provider", api_key="hist_key", - custom_prompt=None, temperature=0.2, system_message="Sys History" - ) - assert response == "LLM Response with history" - mock_chat_api_call_shim.assert_called_once() - payload = mock_chat_api_call_shim.call_args.kwargs["messages_payload"] - assert len(payload) == 3 - assert payload[0]["content"][0]["type"] == "text" - assert payload[0]["content"][0]["text"] == "Previous question?" - assert payload[1]["content"][0]["type"] == "text" - assert payload[1]["content"][0]["text"] == "Previous answer." - assert payload[2]["content"][0]["type"] == "text" - assert payload[2]["content"][0]["text"] == "New question" - - -@pytest.mark.unit -@patch("tldw_app.Chat.Chat_Functions.chat_api_call") -@patch("tldw_app.Chat.Chat_Functions.process_user_input", side_effect=lambda x, *a, **kw: x) -def test_chat_function_with_current_image(mock_process_input, mock_chat_api_call_shim): - mock_chat_api_call_shim.return_value = "LLM image Response" - current_image = {"base64_data": "fakeb64imagedata", "mime_type": "image/png"} - - response = chat( - message="What is this image?", history=[], media_content=None, selected_parts=[], - api_endpoint="img_provider", api_key="img_key", custom_prompt=None, temperature=0.3, - current_image_input=current_image - ) - assert response == "LLM image Response" - mock_chat_api_call_shim.assert_called_once() - payload = mock_chat_api_call_shim.call_args.kwargs["messages_payload"] - assert len(payload) == 1 - user_content_parts = payload[0]["content"] - assert isinstance(user_content_parts, list) - text_part_found = any(p["type"] == "text" and p["text"] == "What is this image?" 
for p in user_content_parts) - image_part_found = any(p["type"] == "image_url" and p["image_url"]["url"] == "data:image/png;base64,fakeb64imagedata" for p in user_content_parts) - assert text_part_found and image_part_found - assert len(user_content_parts) == 2 # one text, one image - - -@pytest.mark.unit -@patch("tldw_app.Chat.Chat_Functions.chat_api_call") -@patch("tldw_app.Chat.Chat_Functions.process_user_input", side_effect=lambda x, *a, **kw: x) -def test_chat_function_image_history_tag_past(mock_process_input, mock_chat_api_call_shim): - mock_chat_api_call_shim.return_value = "Tagged image history response" - history_with_image = [ - {"role": "user", "content": [ - {"type": "text", "text": "Here is an image."}, - {"type": "image_url", "image_url": {"url": "data:image/jpeg;base64,previmgdata"}} - ]}, - {"role": "assistant", "content": "I see the image."} - ] - response = chat( - message="What about that previous image?", - media_content=None, - selected_parts=[], - api_endpoint="tag_provider", - history=history_with_image, - api_key="tag_key", - custom_prompt=None, - temperature=0.4, - image_history_mode="tag_past" - ) - payload = mock_chat_api_call_shim.call_args.kwargs["messages_payload"] - assert len(payload) == 3 # 2 history items processed + 1 current - - # First user message from history - user_hist_content = payload[0]["content"] - assert isinstance(user_hist_content, list) - assert {"type": "text", "text": "Here is an image."} in user_hist_content - assert {"type": "text", "text": ""} in user_hist_content - assert not any(p["type"] == "image_url" for p in user_hist_content) # Image should be replaced by tag - - # Assistant message from history - assistant_hist_content = payload[1]["content"] - assert assistant_hist_content[0]["text"] == "I see the image." - - -@pytest.mark.unit -@patch("tldw_app.Chat.Chat_Functions.chat_api_call") -@patch("tldw_app.Chat.Chat_Functions.process_user_input", side_effect=lambda x, *args, **kwargs: x) -def test_chat_function_streaming_passthrough(mock_process_input, mock_chat_api_call_shim): - def dummy_stream_gen(): - yield "stream chunk 1" - yield "stream chunk 2" - mock_chat_api_call_shim.return_value = dummy_stream_gen() - - response_gen = chat( - message="Stream this", - media_content=None, - selected_parts=[], - api_endpoint="stream_provider", - history=[], - api_key="key", - custom_prompt=None, - temperature=0.1, - streaming=True - ) - assert hasattr(response_gen, '__iter__') - result = list(response_gen) - assert result == ["stream chunk 1", "stream chunk 2"] - mock_chat_api_call_shim.assert_called_once() - assert mock_chat_api_call_shim.call_args.kwargs["streaming"] is True - - -# --- Tests for save_chat_history_to_db_wrapper --- -# These tests seem okay with the TUI context as they mock the DB. 
- -@pytest.mark.unit -@patch("tldw_app.Chat.Chat_Functions.DEFAULT_CHARACTER_NAME", "TestDefaultChar") -def test_save_chat_history_new_conversation_default_char(mock_load_settings_globally): # Ensure settings mock is active if needed by save_chat - mock_db = MagicMock(spec=CharactersRAGDB) - mock_db.client_id = "unit_test_client" - mock_db.get_character_card_by_name.return_value = {"id": 99, "name": "TestDefaultChar", "version":1} - mock_db.add_conversation.return_value = "new_conv_id_123" - mock_db.transaction.return_value.__enter__.return_value = None # for 'with db.transaction():' - - history_to_save = [ - {"role": "user", "content": "Hello, default character!"}, - {"role": "assistant", "content": [{"type": "text", "text": "Hello, user!"}, {"type": "image_url", "image_url": { - "url": "data:image/gif;base64,R0lGODlhAQABAIAAAAAAAP///yH5BAEAAAAALAAAAAABAAEAAAIBRAA7"}}]} - ] - conv_id, message = save_chat_history_to_db_wrapper( - db=mock_db, - chatbot_history=history_to_save, - conversation_id=None, - media_content_for_char_assoc=None, - character_name_for_chat=None - ) - assert conv_id == "new_conv_id_123" - assert "success" in message.lower() - mock_db.get_character_card_by_name.assert_called_once_with("TestDefaultChar") - # ... (rest of assertions from your original test are likely still valid) - first_message_call_args = mock_db.add_message.call_args_list[0].args[0] - assert first_message_call_args["image_data"] is None - - second_message_call_args = mock_db.add_message.call_args_list[1].args[0] - assert second_message_call_args["image_mime_type"] == "image/gif" - assert isinstance(second_message_call_args["image_data"], bytes) - - -@pytest.mark.unit -def test_save_chat_history_resave_conversation_specific_char(mock_load_settings_globally): - mock_db = MagicMock(spec=CharactersRAGDB) - mock_db.client_id = "unit_test_client_resave" - existing_conv_id = "existing_conv_456" - char_id_for_resave = 77 - char_name_for_resave = "SpecificResaveChar" - mock_db.get_character_card_by_name.return_value = {"id": char_id_for_resave, "name": char_name_for_resave, "version": 1} - mock_db.get_conversation_by_id.return_value = {"id": existing_conv_id, "character_id": char_id_for_resave, "title": "Old Title", "version": 2} - mock_db.get_messages_for_conversation.return_value = [{"id": "msg1", "version": 1}, {"id": "msg2", "version": 1}] - mock_db.transaction.return_value.__enter__.return_value = None - - history_to_resave = [{"role": "user", "content": "Updated question for resave."}] - conv_id, message = save_chat_history_to_db_wrapper( - db=mock_db, - chatbot_history=history_to_resave, - conversation_id=existing_conv_id, - media_content_for_char_assoc=None, - character_name_for_chat=char_name_for_resave - ) - assert conv_id == existing_conv_id - assert "success" in message.lower() - # ... (rest of assertions from your original test) - - -# --- Chat Dictionary and Character Save/Load Tests --- -# These tests seem okay with the TUI context. - -@pytest.mark.unit -def test_parse_user_dict_markdown_file_various_formats(tmp_path): - # ... 
(your original test content is good) - md_content = textwrap.dedent(""" +def db_path(tmp_path): + """Provides a temporary path for the database file for each test.""" + return tmp_path / "test_chat_func_db.sqlite" + + +@pytest.fixture(scope="function") +def db_instance(db_path, client_id): + """Creates a DB instance for each test, ensuring a fresh database.""" + db = CharactersRAGDB(db_path, client_id) + yield db + db.close_connection() + + +# --- Helper Functions --- + +def create_base64_image(): + """Creates a dummy 1x1 png and returns its base64 string.""" + img_bytes = io.BytesIO() + Image.new('RGB', (1, 1)).save(img_bytes, format='PNG') + return base64.b64encode(img_bytes.getvalue()).decode('utf-8') + + +# --- Test Classes --- + +@patch('tldw_chatbook.Chat.Chat_Functions.API_CALL_HANDLERS') +class TestChatApiCall: + def test_routes_to_correct_handler(self, mock_handlers, mocker): + mock_openai_handler = mocker.MagicMock(return_value="OpenAI response") + mock_handlers.get.return_value = mock_openai_handler + + response = chat_api_call( + api_endpoint="openai", + messages_payload=[{"role": "user", "content": "test"}], + model="gpt-4" + ) + + mock_handlers.get.assert_called_with("openai") + mock_openai_handler.assert_called_once() + kwargs = mock_openai_handler.call_args.kwargs + assert kwargs['input_data'][0]['content'] == "test" # Mapped to 'input_data' for openai + assert kwargs['model'] == "gpt-4" + assert response == "OpenAI response" + + def test_unsupported_endpoint_raises_error(self, mock_handlers): + mock_handlers.get.return_value = None + with pytest.raises(ValueError, match="Unsupported API endpoint: unsupported"): + chat_api_call("unsupported", messages_payload=[]) + + def test_http_error_401_raises_auth_error(self, mock_handlers, mocker): + mock_response = MagicMock() + mock_response.status_code = 401 + mock_response.text = "Invalid API key" + http_error = requests.exceptions.HTTPError(response=mock_response) + + mock_handler = mocker.MagicMock(side_effect=http_error) + mock_handlers.get.return_value = mock_handler + + with pytest.raises(ChatAuthenticationError): + chat_api_call("openai", messages_payload=[]) + + +class TestChatFunction: + @patch('tldw_chatbook.Chat.Chat_Functions.chat_api_call') + def test_chat_basic_flow(self, mock_chat_api_call): + mock_chat_api_call.return_value = "LLM says hi" + + response = chat( + message="Hello", + history=[], + media_content=None, + selected_parts=[], + api_endpoint="openai", + api_key="sk-123", + model="gpt-4", + temperature=0.7, + custom_prompt="Be brief." 
+ ) + + assert response == "LLM says hi" + mock_chat_api_call.assert_called_once() + kwargs = mock_chat_api_call.call_args.kwargs + + assert kwargs['api_endpoint'] == 'openai' + assert kwargs['model'] == 'gpt-4' + payload = kwargs['messages_payload'] + assert len(payload) == 1 + assert payload[0]['role'] == 'user' + user_content = payload[0]['content'] + assert isinstance(user_content, list) + assert user_content[0]['type'] == 'text' + assert user_content[0]['text'] == "Be brief.\n\nHello" + + @patch('tldw_chatbook.Chat.Chat_Functions.chat_api_call') + def test_chat_with_image_and_rag(self, mock_chat_api_call): + b64_img = create_base64_image() + + chat( + message="Describe this.", + history=[], + media_content={"summary": "This is a summary."}, + selected_parts=["summary"], + api_endpoint="openai", + api_key="sk-123", + model="gpt-4-vision-preview", + temperature=0.5, + current_image_input={'base64_data': b64_img, 'mime_type': 'image/png'}, + custom_prompt=None + ) + + kwargs = mock_chat_api_call.call_args.kwargs + payload = kwargs['messages_payload'] + user_content_parts = payload[0]['content'] + + assert len(user_content_parts) == 2 # RAG text + image + + text_part = next(p for p in user_content_parts if p['type'] == 'text') + image_part = next(p for p in user_content_parts if p['type'] == 'image_url') + + assert "Summary: This is a summary." in text_part['text'] + assert "Describe this." in text_part['text'] + assert image_part['image_url']['url'].startswith("data:image/png;base64,") + + @patch('tldw_chatbook.Chat.Chat_Functions.chat_api_call') + def test_chat_adapts_payload_for_deepseek(self, mock_chat_api_call): + chat( + message="Hello", + history=[ + {"role": "user", "content": [{"type": "text", "text": "Old message"}, + {"type": "image_url", "image_url": {"url": "data:..."}}]}, + {"role": "assistant", "content": "Old reply"} + ], + media_content=None, + selected_parts=[], + api_endpoint="deepseek", # The endpoint that needs adaptation + api_key="sk-123", + model="deepseek-chat", + temperature=0.7, + custom_prompt=None, + image_history_mode="tag_past" + ) + + kwargs = mock_chat_api_call.call_args.kwargs + adapted_payload = kwargs['messages_payload'] + + # Check that all content fields are strings, not lists of parts + assert isinstance(adapted_payload[0]['content'], str) + assert adapted_payload[0]['content'] == "Old message\n" + assert isinstance(adapted_payload[1]['content'], str) + assert adapted_payload[1]['content'] == "Old reply" + assert isinstance(adapted_payload[2]['content'], str) + assert adapted_payload[2]['content'] == "Hello" + + +class TestChatHistorySaving: + def test_save_chat_history_to_db_new_conversation(self, db_instance: CharactersRAGDB): + # The history format is now OpenAI's message objects + chatbot_history = [ + {"role": "user", "content": "Hello there"}, + {"role": "assistant", "content": "General Kenobi"} + ] + + # Uses default character + conv_id, status = save_chat_history_to_db_wrapper( + db=db_instance, + chatbot_history=chatbot_history, + conversation_id=None, + media_content_for_char_assoc=None, + character_name_for_chat=None + ) + + assert "success" in status.lower() + assert conv_id is not None + + messages = db_instance.get_messages_for_conversation(conv_id) + assert len(messages) == 2 + assert messages[0]['sender'] == 'user' + assert messages[1]['sender'] == 'assistant' + + conv_details = db_instance.get_conversation_by_id(conv_id) + assert conv_details['character_id'] == 1 # Default character + + def test_save_chat_history_with_image(self, 
db_instance: CharactersRAGDB): + b64_img = create_base64_image() + chatbot_history = [ + {"role": "user", "content": [ + {"type": "text", "text": "Look at this image"}, + {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{b64_img}"}} + ]}, + {"role": "assistant", "content": "I see a 1x1 black square."} + ] + + conv_id, status = save_chat_history_to_db_wrapper(db_instance, chatbot_history, None, None, None) + assert "success" in status.lower() + + messages = db_instance.get_messages_for_conversation(conv_id) + assert len(messages) == 2 + assert messages[0]['content'] == "Look at this image" + assert messages[0]['image_data'] is not None + assert messages[0]['image_mime_type'] == "image/png" + assert messages[1]['image_data'] is None + + def test_resave_chat_history(self, db_instance: CharactersRAGDB): + char_id = db_instance.add_character_card({"name": "Resaver"}) + initial_history = [{"role": "user", "content": "First message"}] + conv_id, _ = save_chat_history_to_db_wrapper(db_instance, initial_history, None, None, "Resaver") + + updated_history = [ + {"role": "user", "content": "New first message"}, + {"role": "assistant", "content": "New reply"} + ] + + # Resave with same conv_id + resave_id, status = save_chat_history_to_db_wrapper(db_instance, updated_history, conv_id, None, "Resaver") + assert "success" in status.lower() + assert resave_id == conv_id + + messages = db_instance.get_messages_for_conversation(conv_id) + assert len(messages) == 2 + assert messages[0]['content'] == "New first message" + + +class TestCharacterManagement: + def test_save_and_load_character(self, db_instance: CharactersRAGDB): + char_data = { + "name": "Super Coder", + "description": "A character that codes.", + "image": create_base64_image() + } + + char_id = save_character(db_instance, char_data) + assert isinstance(char_id, int) + + loaded_chars = load_characters(db_instance) + assert "Super Coder" in loaded_chars + loaded_char_data = loaded_chars["Super Coder"] + assert loaded_char_data['description'] == "A character that codes." + assert loaded_char_data['image_base64'] is not None + + def test_get_character_names(self, db_instance: CharactersRAGDB): + save_character(db_instance, {"name": "Beta"}) + save_character(db_instance, {"name": "Alpha"}) + + # Default character is also present + names = get_character_names(db_instance) + assert names == ["Alpha", "Beta", DEFAULT_CHARACTER_NAME] + + +class TestChatDictionary: + def test_parse_user_dict_markdown_file(self, tmp_path): + dict_content = """ key1: value1 key2: | - This is a - multi-line value for key2. - It has several lines. - # This comment line is part of key2's value. + This is a + multiline value. ---@@@--- - key_after_term: after_terminator_value - """).strip() - dict_file = tmp_path / "test_dict.md" - dict_file.write_text(md_content) - parsed = parse_user_dict_markdown_file(str(dict_file)) - expected_key2_value = ("This is a\n multi-line value for key2.\n It has several lines.\n# This comment line is part of key2's value.") - assert parsed.get("key1") == "value1" - assert parsed.get("key2") == expected_key2_value - assert parsed.get("key_after_term") == "after_terminator_value" - - -@pytest.mark.unit -def test_chat_dictionary_class_methods(): - # ... 
(your original test content is good) - entry_plain = ChatDictionary(key="hello", content="hi there") - entry_regex = ChatDictionary(key=r"/\bworld\b/i", content="planet") # Added /i for ignore case in regex - assert entry_plain.matches("hello world") - assert entry_regex.matches("Hello World!") - assert isinstance(entry_regex.key, re.Pattern) - assert entry_regex.key.flags & re.IGNORECASE - - -@pytest.mark.unit -@patch("tldw_app.Chat.Chat_Functions.chat_api_call") # Mock the inner chat_api_call used by `chat` -@patch("tldw_app.Chat.Chat_Functions.parse_user_dict_markdown_file") -# mock_load_settings_globally is active -def test_chat_function_with_chat_dictionary_post_replacement( - mock_parse_dict, mock_chat_api_call_inner_shim, tmp_path, mock_load_settings_globally -): - post_gen_dict_path = str(tmp_path / "post_gen.md") - # Override the global mock for this specific test case - mock_load_settings_globally.return_value = { - "chat_dictionaries": { - "post_gen_replacement": "True", - "post_gen_replacement_dict": post_gen_dict_path - }, - # Ensure other necessary default configs for the provider are present if chat_api_call or its handlers need them - "openai_api": {"api_key": "testkey"} # Example - } - - post_gen_dict_file = tmp_path / "post_gen.md" # Actual file creation - post_gen_dict_file.write_text("AI: Artificial Intelligence\nLLM: Large Language Model") - os.path.exists(post_gen_dict_path) # For the check in `chat` - - mock_parse_dict.return_value = {"AI": "Artificial Intelligence", "LLM": "Large Language Model"} - raw_llm_response = "The AI assistant uses an LLM." - mock_chat_api_call_inner_shim.return_value = raw_llm_response - - final_response = chat( - message="Tell me about AI.", - media_content=None, - selected_parts=[], - api_endpoint="openai", - api_key="testkey", - custom_prompt=None, - history=[], - temperature=0.7, - streaming=False - ) - expected_response = "The Artificial Intelligence assistant uses an Large Language Model." - assert final_response == expected_response - mock_parse_dict.assert_called_once_with(post_gen_dict_path) - - -@pytest.mark.unit -def test_save_character_new_and_update(): - # ... (your original test content is good) - mock_db = MagicMock(spec=CharactersRAGDB) - char_data_v1 = {"name": "TestCharacter", "description": "Hero.", "image": "data:image/png;base64,fake"} - mock_db.get_character_card_by_name.return_value = None - mock_db.add_character_card.return_value = 1 - save_character(db=mock_db, character_data=char_data_v1) - mock_db.add_character_card.assert_called_once() - # ... more assertions ... - -@pytest.mark.unit -def test_load_characters_empty_and_with_data(): - # ... (your original test content is good) - mock_db = MagicMock(spec=CharactersRAGDB) - mock_db.list_character_cards.return_value = [] - assert load_characters(db=mock_db) == {} - # ... more assertions for data ... 
- -# --- Property-Based Tests --- - -# Helper strategy for generating message content (simple text or list of parts) -st_text_content = st.text(min_size=1, max_size=50) -st_image_url_part = st.fixed_dictionaries({ - "type": st.just("image_url"), - "image_url": st.fixed_dictionaries({ - "url": st.text(min_size=10, max_size=30).map(lambda s: f"data:image/png;base64,{s}") - }) -}) -st_text_part = st.fixed_dictionaries({"type": st.just("text"), "text": st_text_content}) -st_message_part = st.one_of(st_text_part, st_image_url_part) - -st_message_content_list = st.lists(st_message_part, min_size=1, max_size=3) -st_valid_message_content = st.one_of(st_text_content, st_message_content_list) - -st_message = st.fixed_dictionaries({ - "role": st.sampled_from(["user", "assistant"]), - "content": st_valid_message_content -}) -st_history = st.lists(st_message, max_size=5) - -# Strategy for optional float parameters like temperature, top_p -st_optional_float_0_to_1 = st.one_of(st.none(), st.floats(min_value=0.0, max_value=1.0)) -st_optional_float_0_to_2 = st.one_of(st.none(), st.floats(min_value=0.0, max_value=2.0)) # For penalties -st_optional_int_gt_0 = st.one_of(st.none(), st.integers(min_value=1, max_value=2048)) # For max_tokens, top_k - - -@given( - api_endpoint=st.sampled_from(KNOWN_PROVIDERS), - temp=st_optional_float_0_to_1, - system_message=st.one_of(st.none(), st.text(max_size=50)), - streaming=st.booleans(), - max_tokens=st_optional_int_gt_0, - seed=st.one_of(st.none(), st.integers()), - # Add more strategies for other chat_api_call params if desired -) -@settings(max_examples=20, deadline=None, suppress_health_check=[HealthCheck.function_scoped_fixture]) -def test_property_chat_api_call_param_passing( - mock_llm_handlers, # Fixture to mock the actual handlers - api_endpoint, temp, system_message, streaming, max_tokens, seed -): - """ - Tests that chat_api_call correctly routes to the mocked handler - and passes through known parameters, mapping them if necessary. 
- """ - mock_handler = mock_llm_handlers[api_endpoint] - mock_handler.reset_mock() - mock_handler.return_value = "Property test success" - param_map_for_provider = PROVIDER_PARAM_MAP.get(api_endpoint, {}) - - messages = [{"role": "user", "content": "Hypothesis test"}] - args_to_call = { - "api_endpoint": api_endpoint, - "messages_payload": messages, - "api_key": "prop_test_key", # Assuming all handlers take api_key mapped - "model": "prop_test_model", # Assuming all handlers take model mapped - } - # Add optional params to args_to_call only if they are not None - if temp is not None: args_to_call["temp"] = temp - if system_message is not None: args_to_call["system_message"] = system_message - if streaming is not None: args_to_call["streaming"] = streaming # streaming is not optional in signature, but can be None in map - if max_tokens is not None: args_to_call["max_tokens"] = max_tokens - if seed is not None: args_to_call["seed"] = seed - - result = chat_api_call(**args_to_call) - assert result == "Property test success" - mock_handler.assert_called_once() - called_kwargs = mock_handler.call_args.kwargs - - # Check messages_payload (or its mapped equivalent) - mapped_messages_key = param_map_for_provider.get('messages_payload', 'messages_payload') # Default if not in map - assert called_kwargs.get(mapped_messages_key) == messages - - # Check other params if they were passed and are in the map - if temp is not None and 'temp' in param_map_for_provider: - assert called_kwargs.get(param_map_for_provider['temp']) == temp - if system_message is not None and 'system_message' in param_map_for_provider: - assert called_kwargs.get(param_map_for_provider['system_message']) == system_message - if streaming is not None and 'streaming' in param_map_for_provider: - assert called_kwargs.get(param_map_for_provider['streaming']) == streaming - if max_tokens is not None and 'max_tokens' in param_map_for_provider: - assert called_kwargs.get(param_map_for_provider['max_tokens']) == max_tokens - if seed is not None and 'seed' in param_map_for_provider: - assert called_kwargs.get(param_map_for_provider['seed']) == seed - - -@given( - message=st.text(max_size=100), - history=st_history, - custom_prompt=st.one_of(st.none(), st.text(max_size=50)), - temperature=st.floats(min_value=0.0, max_value=1.0), - system_message=st.one_of(st.none(), st.text(max_size=50)), - streaming=st.booleans(), - llm_max_tokens=st_optional_int_gt_0, - llm_seed=st.one_of(st.none(), st.integers()), - image_history_mode=st.sampled_from(["send_all", "send_last_user_image", "tag_past", "ignore_past"]), - current_image_input=st.one_of( - st.none(), - st.fixed_dictionaries({ - "base64_data": st.text(min_size=5, max_size=20).map(lambda s: base64.b64encode(s.encode()).decode()), - "mime_type": st.sampled_from(["image/png", "image/jpeg"]) - }) - ) -) -@settings(max_examples=20, deadline=None) -@patch("tldw_app.Chat.Chat_Functions.chat_api_call") -@patch("tldw_app.Chat.Chat_Functions.process_user_input", side_effect=lambda x, *a, **kw: x) -# mock_load_settings_globally is active -def test_property_chat_function_payload_construction( - mock_process_input, mock_chat_api_call_shim, # Mocked dependencies first - message, history, custom_prompt, temperature, system_message, streaming, # Generated inputs - llm_max_tokens, llm_seed, image_history_mode, current_image_input -): - mock_chat_api_call_shim.return_value = "Property LLM Response" if not streaming else (lambda: (yield "Stream"))() - - response = chat( - message=message, history=history, 
media_content=None, selected_parts=[], - api_endpoint="prop_provider", api_key="prop_key", - custom_prompt=custom_prompt, temperature=temperature, system_message=system_message, - streaming=streaming, llm_max_tokens=llm_max_tokens, llm_seed=llm_seed, - image_history_mode=image_history_mode, current_image_input=current_image_input - ) - - if streaming: - assert hasattr(response, '__iter__') - list(response) # Consume - else: - assert response == "Property LLM Response" - - mock_chat_api_call_shim.assert_called_once() - call_args = mock_chat_api_call_shim.call_args.kwargs - - assert call_args["api_endpoint"] == "prop_provider" - assert call_args["temp"] == temperature - if system_message is not None: - assert call_args["system_message"] == system_message - assert call_args["streaming"] == streaming - if llm_max_tokens is not None: - assert call_args["max_tokens"] == llm_max_tokens - if llm_seed is not None: - assert call_args["seed"] == llm_seed - - payload = call_args["messages_payload"] - assert isinstance(payload, list) - if not payload: # Should not happen if message is non-empty, but good to check - assert not message and not history # Only if input is truly empty - return - - # Verify structure of the last message (current user input) - last_message_in_payload = payload[-1] - assert last_message_in_payload["role"] == "user" - assert isinstance(last_message_in_payload["content"], list) - - # Check if custom_prompt is prepended - expected_current_text = message - if custom_prompt: - expected_current_text = f"{custom_prompt}\n\n{expected_current_text}" - - text_part_found = any(p["type"] == "text" and p["text"] == expected_current_text.strip() for p in last_message_in_payload["content"]) - if not message and not custom_prompt and not current_image_input: # if no user text and no image - assert any(p["type"] == "text" and "(No user input for this turn)" in p["text"] for p in last_message_in_payload["content"]) - elif expected_current_text.strip() or (not expected_current_text.strip() and not current_image_input): # if only text or no text and no image - assert text_part_found or (not expected_current_text.strip() and not any(p["type"] == "text" for p in last_message_in_payload["content"])) - - - if current_image_input: - expected_image_url = f"data:{current_image_input['mime_type']};base64,{current_image_input['base64_data']}" - image_part_found = any(p["type"] == "image_url" and p["image_url"]["url"] == expected_image_url for p in last_message_in_payload["content"]) - assert image_part_found - - # Further checks on history processing (e.g., image_history_mode effects) could be added here, - # but they become complex for property tests. Unit tests are better for those specifics. + /key3/i: value3 + """ + dict_file = tmp_path / "test_dict.md" + dict_file.write_text(dict_content) + + parsed = parse_user_dict_markdown_file(str(dict_file)) + assert parsed["key1"] == "value1" + assert parsed["key2"] == "This is a\nmultiline value." + assert parsed["/key3/i"] == "value3" + + def test_process_user_input_simple_replacement(self): + entries = [ChatDictionary(key="hello", content="GREETING")] + user_input = "I said hello to the world." + result = process_user_input(user_input, entries) + assert result == "I said GREETING to the world." + + def test_process_user_input_regex_replacement(self): + entries = [ChatDictionary(key=r"/h[aeiou]llo/i", content="GREETING")] + user_input = "I said hallo and heLlo." 
+ # It replaces only the first match + result = process_user_input(user_input, entries) + assert result == "I said GREETING and heLlo." + + def test_process_user_input_token_budget(self): + # Content is 4 tokens, budget is 3. Should not replace. + entries = [ChatDictionary(key="long", content="this is too long")] + user_input = "This is a long test." + result = process_user_input(user_input, entries, max_tokens=3) + assert result == "This is a long test." + + # Content is 3 tokens, budget is 3. Should replace. + entries = [ChatDictionary(key="short", content="this is fine")] + user_input = "This is a short test." + result = process_user_input(user_input, entries, max_tokens=3) + assert result == "This is a this is fine test." # # End of test_chat_functions.py diff --git a/Tests/Chat/test_prompt_template_manager.py b/Tests/Chat/test_prompt_template_manager.py index f5e71015..922411e0 100644 --- a/Tests/Chat/test_prompt_template_manager.py +++ b/Tests/Chat/test_prompt_template_manager.py @@ -1,142 +1,163 @@ -# Tests/Chat/test_prompt_template_manager.py +# test_prompt_template_manager.py + import pytest import json from pathlib import Path -from unittest.mock import patch, mock_open +# Local Imports from this project from tldw_chatbook.Chat.prompt_template_manager import ( PromptTemplate, load_template, apply_template_to_string, get_available_templates, - DEFAULT_RAW_PASSTHROUGH_TEMPLATE, - _loaded_templates # For clearing cache in tests + _loaded_templates, + DEFAULT_RAW_PASSTHROUGH_TEMPLATE ) -# Fixture to clear the template cache before each test +# --- Test Setup --- + @pytest.fixture(autouse=True) def clear_template_cache(): + """Fixture to clear the template cache before each test.""" + original_templates = _loaded_templates.copy() _loaded_templates.clear() - # Re-add the default passthrough because it's normally added at module load + # Ensure the default is always there for tests that might rely on it _loaded_templates["raw_passthrough"] = DEFAULT_RAW_PASSTHROUGH_TEMPLATE + yield + # Restore original cache state if needed, though clearing is usually sufficient + _loaded_templates.clear() + _loaded_templates.update(original_templates) @pytest.fixture -def mock_templates_dir(tmp_path: Path): - templates_dir = tmp_path / "prompt_templates_test" +def mock_templates_dir(tmp_path, monkeypatch): + """Creates a temporary directory for prompt templates and patches the module-level constant.""" + templates_dir = tmp_path / "prompt_templates" templates_dir.mkdir() - # Create a valid template file - valid_template_data = { - "name": "test_valid", - "description": "A valid test template", - "system_message_template": "System: {sys_var}", - "user_message_content_template": "User: {user_var} - {message_content}" + # Patch the PROMPT_TEMPLATES_DIR constant in the target module + monkeypatch.setattr('tldw_chatbook.Chat.prompt_template_manager.PROMPT_TEMPLATES_DIR', templates_dir) + + # Create some dummy template files + template1_data = { + "name": "test_template_1", + "description": "A simple test template.", + "user_message_content_template": "User said: {{message_content}}" } - with open(templates_dir / "test_valid.json", "w") as f: - json.dump(valid_template_data, f) + (templates_dir / "test_template_1.json").write_text(json.dumps(template1_data)) - # Create an invalid JSON template file - with open(templates_dir / "test_invalid_json.json", "w") as f: - f.write("{'name': 'invalid', 'description': 'bad json'") # Invalid JSON + template2_data = { + "name": "test_template_2", + 
"system_message_template": "System context: {{system_context}}", + "user_message_content_template": "{{message_content}}" + } + (templates_dir / "test_template_2.json").write_text(json.dumps(template2_data)) - # Create an empty template file (valid JSON but might be handled as error by Pydantic) - with open(templates_dir / "test_empty.json", "w") as f: - json.dump({}, f) + # Create a malformed JSON file + (templates_dir / "malformed.json").write_text("{'invalid': 'json'") return templates_dir -@pytest.mark.unit -def test_load_template_success(mock_templates_dir): - with patch("tldw_app.Chat.prompt_template_manager.PROMPT_TEMPLATES_DIR", mock_templates_dir): - template = load_template("test_valid") - assert template is not None - assert template.name == "test_valid" - assert template.system_message_template == "System: {sys_var}" - # Test caching - template_cached = load_template("test_valid") - assert template_cached is template # Should be the same object from cache +# --- Test Cases --- +class TestPromptTemplateManager: -@pytest.mark.unit -def test_load_template_not_found(mock_templates_dir): - with patch("tldw_app.Chat.prompt_template_manager.PROMPT_TEMPLATES_DIR", mock_templates_dir): + def test_load_template_success(self, mock_templates_dir): + """Test successfully loading a valid template file.""" + template = load_template("test_template_1") + assert template is not None + assert isinstance(template, PromptTemplate) + assert template.name == "test_template_1" + assert template.user_message_content_template == "User said: {{message_content}}" + + def test_load_template_not_found(self, mock_templates_dir): + """Test loading a template that does not exist.""" template = load_template("non_existent_template") assert template is None - -@pytest.mark.unit -def test_load_template_invalid_json(mock_templates_dir): - with patch("tldw_app.Chat.prompt_template_manager.PROMPT_TEMPLATES_DIR", mock_templates_dir): - template = load_template("test_invalid_json") - assert template is None # Should fail to parse - - -@pytest.mark.unit -def test_load_template_empty_json_fails_validation(mock_templates_dir): - with patch("tldw_app.Chat.prompt_template_manager.PROMPT_TEMPLATES_DIR", mock_templates_dir): - template = load_template("test_empty") - # Pydantic will raise validation error because 'name' is missing, - # load_template should catch this and return None. + def test_load_template_malformed_json(self, mock_templates_dir): + """Test loading a template from a file with invalid JSON.""" + template = load_template("malformed") assert template is None + def test_load_template_caching(self, mock_templates_dir): + """Test that a loaded template is cached and not re-read from disk.""" + template1 = load_template("test_template_1") + assert "test_template_1" in _loaded_templates -@pytest.mark.unit -# Inside test_apply_template_to_string(): -def test_apply_template_to_string(): - template_str_jinja = "Hello {{name}}, welcome to {{place}}." # Use Jinja - data_full = {"name": "Alice", "place": "Wonderland"} - assert apply_template_to_string(template_str_jinja, data_full) == "Hello Alice, welcome to Wonderland." + # Modify the file on disk + (mock_templates_dir / "test_template_1.json").write_text(json.dumps({"name": "modified"})) - template_partial_jinja = "Hello {{name}}." # Use Jinja - data_partial = {"name": "Bob"} - assert apply_template_to_string(template_partial_jinja, data_partial) == "Hello Bob." 
+ # Load again - should return the cached version + template2 = load_template("test_template_1") + assert template2 is not None + assert template2.name == "test_template_1" # Original name from cache + assert template2 == template1 - # Test with missing data - Jinja renders empty for missing by default if not strict - assert apply_template_to_string(template_partial_jinja, {}) == "Hello ." - - # Test with None template string - assert apply_template_to_string(None, data_full) == "" - - -@pytest.mark.unit -def test_get_available_templates(mock_templates_dir): - with patch("tldw_app.Chat.prompt_template_manager.PROMPT_TEMPLATES_DIR", mock_templates_dir): + def test_get_available_templates(self, mock_templates_dir): + """Test discovering available templates from the directory.""" available = get_available_templates() assert isinstance(available, list) - assert "test_valid" in available - assert "test_invalid_json" in available - assert "test_empty" in available - assert len(available) == 3 - - -@pytest.mark.unit -def test_get_available_templates_no_dir(): - with patch("tldw_app.Chat.prompt_template_manager.PROMPT_TEMPLATES_DIR", Path("/non/existent/dir")): - available = get_available_templates() - assert available == [] + assert set(available) == {"test_template_1", "test_template_2", "malformed"} + def test_get_available_templates_no_dir(self, tmp_path, monkeypatch): + """Test getting templates when the directory doesn't exist.""" + non_existent_dir = tmp_path / "non_existent_dir" + monkeypatch.setattr('tldw_chatbook.Chat.prompt_template_manager.PROMPT_TEMPLATES_DIR', non_existent_dir) + assert get_available_templates() == [] -@pytest.mark.unit -def test_default_raw_passthrough_template(): - assert DEFAULT_RAW_PASSTHROUGH_TEMPLATE is not None - assert DEFAULT_RAW_PASSTHROUGH_TEMPLATE.name == "raw_passthrough" - data = {"message_content": "test content", "original_system_message_from_request": "system content"} - - # User message template (is "{{message_content}}") - assert apply_template_to_string(DEFAULT_RAW_PASSTHROUGH_TEMPLATE.user_message_content_template, - data) == "test content" - # System message template (is "{{original_system_message_from_request}}") - assert apply_template_to_string(DEFAULT_RAW_PASSTHROUGH_TEMPLATE.system_message_template, - data) == "system content" - - data_empty_sys = {"original_system_message_from_request": ""} - assert apply_template_to_string(DEFAULT_RAW_PASSTHROUGH_TEMPLATE.system_message_template, - data_empty_sys) == "" - - data_missing_sys = {"message_content": "some_content"} # original_system_message_from_request is missing - assert apply_template_to_string(DEFAULT_RAW_PASSTHROUGH_TEMPLATE.system_message_template, - data_missing_sys) == "" # Jinja renders missing as empty - + def test_default_passthrough_template_is_available(self): + """Test that the default 'raw_passthrough' template is loaded.""" + template = load_template("raw_passthrough") + assert template is not None + assert template.name == "raw_passthrough" + assert template.user_message_content_template == "{{message_content}}" + + +class TestTemplateRendering: + + def test_apply_template_to_string_success(self): + """Test basic successful rendering.""" + template_str = "Hello, {{ name }}!" + data = {"name": "World"} + result = apply_template_to_string(template_str, data) + assert result == "Hello, World!" + + def test_apply_template_to_string_missing_placeholder(self): + """Test rendering when a placeholder in the template is not in the data.""" + template_str = "Hello, {{ name }}! 
Your age is {{ age }}." + data = {"name": "World"} # 'age' is missing + result = apply_template_to_string(template_str, data) + assert result == "Hello, World! Your age is ." # Jinja renders missing variables as empty strings + + def test_apply_template_with_none_input_string(self): + """Test that a None template string returns an empty string.""" + data = {"name": "World"} + result = apply_template_to_string(None, data) + assert result == "" + + def test_apply_template_with_complex_data(self): + """Test rendering with more complex data structures like lists and dicts.""" + template_str = "User: {{ user.name }}. Items: {% for item in items %}{{ item }}{% if not loop.last %}, {% endif %}{% endfor %}." + data = { + "user": {"name": "Alice"}, + "items": ["apple", "banana", "cherry"] + } + result = apply_template_to_string(template_str, data) + assert result == "User: Alice. Items: apple, banana, cherry." + + def test_safe_render_prevents_unsafe_operations(self): + """Test that the sandboxed environment prevents access to unsafe attributes.""" + # Attempt to access a private attribute or a method that could be unsafe + template_str = "Unsafe access: {{ my_obj.__class__ }}" + + class MyObj: pass + + data = {"my_obj": MyObj()} + + # In a sandboxed environment, this should raise a SecurityError, which our wrapper catches. + # The wrapper then returns the original string. + result = apply_template_to_string(template_str, data) + assert result == template_str \ No newline at end of file diff --git a/Tests/DB/test_sqlite_db.py b/Tests/DB/test_sqlite_db.py deleted file mode 100644 index 5eabfd8b..00000000 --- a/Tests/DB/test_sqlite_db.py +++ /dev/null @@ -1,553 +0,0 @@ -# tests/test_sqlite_db.py -# Description: Unit tests for SQLite database operations, including CRUD, transactions, and sync log management. -# -# Imports: -import json -import os -import pytest -import time -import sqlite3 -from datetime import datetime, timezone, timedelta -# -# 3rd-Party Imports: -# -# Local imports -from tldw_cli.tldw_app.DB.Client_Media_DB_v2 import Database, ConflictError -# Import from src using adjusted sys.path in conftest -# -####################################################################################################################### -# -# Functions: - -# Helper to get sync log entries for assertions -def get_log_count(db: Database, entity_uuid: str) -> int: - cursor = db.execute_query("SELECT COUNT(*) FROM sync_log WHERE entity_uuid = ?", (entity_uuid,)) - return cursor.fetchone()[0] - -def get_latest_log(db: Database, entity_uuid: str) -> dict | None: - cursor = db.execute_query( - "SELECT * FROM sync_log WHERE entity_uuid = ? 
ORDER BY change_id DESC LIMIT 1", - (entity_uuid,) - ) - row = cursor.fetchone() - return dict(row) if row else None - -def get_entity_version(db: Database, entity_table: str, uuid: str) -> int | None: - cursor = db.execute_query(f"SELECT version FROM {entity_table} WHERE uuid = ?", (uuid,)) - row = cursor.fetchone() - return row['version'] if row else None - -class TestDatabaseInitialization: - def test_memory_db_creation(self, memory_db_factory): - """Test creating an in-memory database.""" - db = memory_db_factory("client_mem") - assert db.is_memory_db - assert db.client_id == "client_mem" - # Check if a table exists (schema creation check) - cursor = db.execute_query("SELECT name FROM sqlite_master WHERE type='table' AND name='Media'") - assert cursor.fetchone() is not None - db.close_connection() - - def test_file_db_creation(self, file_db, temp_db_path): - """Test creating a file-based database.""" - assert not file_db.is_memory_db - assert file_db.client_id == "file_client" - assert os.path.exists(temp_db_path) - cursor = file_db.execute_query("SELECT name FROM sqlite_master WHERE type='table' AND name='Media'") - assert cursor.fetchone() is not None - # file_db fixture handles closure - - def test_missing_client_id(self): - """Test that ValueError is raised if client_id is missing.""" - with pytest.raises(ValueError, match="Client ID cannot be empty"): - Database(db_path=":memory:", client_id="") - with pytest.raises(ValueError, match="Client ID cannot be empty"): - Database(db_path=":memory:", client_id=None) - - -class TestDatabaseTransactions: - def test_transaction_commit(self, memory_db_factory): - db = memory_db_factory() - keyword = "commit_test" - with db.transaction(): - # Use internal method _add_keyword_internal or simplified version for test - kw_id, kw_uuid = db.add_keyword(keyword) # add_keyword uses transaction internally too, nested is ok - # Verify outside transaction - cursor = db.execute_query("SELECT keyword FROM Keywords WHERE id = ?", (kw_id,)) - assert cursor.fetchone()['keyword'] == keyword - - def test_transaction_rollback(self, memory_db_factory): - db = memory_db_factory() - keyword = "rollback_test" - initial_count_cursor = db.execute_query("SELECT COUNT(*) FROM Keywords") - initial_count = initial_count_cursor.fetchone()[0] - try: - with db.transaction(): - # Simplified insert for test clarity - new_uuid = db._generate_uuid() - db.execute_query( - "INSERT INTO Keywords (keyword, uuid, last_modified, version, client_id, deleted) VALUES (?, ?, ?, 1, ?, 0)", - (keyword, new_uuid, db._get_current_utc_timestamp_str(), db.client_id), - commit=False # Important: commit=False inside transaction block - ) - # Check *inside* transaction - cursor_inside = db.execute_query("SELECT COUNT(*) FROM Keywords") - assert cursor_inside.fetchone()[0] == initial_count + 1 - raise ValueError("Simulating error to trigger rollback") # Force rollback - except ValueError: - pass # Expected error - except Exception as e: - pytest.fail(f"Unexpected exception during rollback test: {e}") - - # Verify outside transaction (count should be back to initial) - final_count_cursor = db.execute_query("SELECT COUNT(*) FROM Keywords") - assert final_count_cursor.fetchone()[0] == initial_count - - -class TestDatabaseCRUDAndSync: - - @pytest.fixture - def db_instance(self, memory_db_factory): - """Provides a fresh in-memory DB for each test in this class.""" - return memory_db_factory("crud_client") - - def test_add_keyword(self, db_instance): - keyword = " test keyword " - expected_keyword = 
"test keyword" - kw_id, kw_uuid = db_instance.add_keyword(keyword) - - assert kw_id is not None - assert kw_uuid is not None - - # Verify DB state - cursor = db_instance.execute_query("SELECT * FROM Keywords WHERE id = ?", (kw_id,)) - row = cursor.fetchone() - assert row['keyword'] == expected_keyword - assert row['uuid'] == kw_uuid - assert row['version'] == 1 - assert row['client_id'] == db_instance.client_id - assert not row['deleted'] - - # Verify Sync Log - log_entry = get_latest_log(db_instance, kw_uuid) - assert log_entry['operation'] == 'create' - assert log_entry['entity'] == 'Keywords' - assert log_entry['version'] == 1 - assert log_entry['client_id'] == db_instance.client_id - payload = json.loads(log_entry['payload']) - assert payload['keyword'] == expected_keyword - assert payload['uuid'] == kw_uuid - - def test_add_existing_keyword(self, db_instance): - keyword = "existing" - kw_id1, kw_uuid1 = db_instance.add_keyword(keyword) - log_count1 = get_log_count(db_instance, kw_uuid1) - - kw_id2, kw_uuid2 = db_instance.add_keyword(keyword) # Add again - log_count2 = get_log_count(db_instance, kw_uuid1) - - assert kw_id1 == kw_id2 - assert kw_uuid1 == kw_uuid2 - assert log_count1 == log_count2 # No new log entry - - def test_soft_delete_keyword(self, db_instance): - keyword = "to_delete" - kw_id, kw_uuid = db_instance.add_keyword(keyword) - initial_version = get_entity_version(db_instance, "Keywords", kw_uuid) - - deleted = db_instance.soft_delete_keyword(keyword) - assert deleted is True - - # Verify DB state - cursor = db_instance.execute_query("SELECT deleted, version FROM Keywords WHERE id = ?", (kw_id,)) - row = cursor.fetchone() - assert row['deleted'] == 1 - assert row['version'] == initial_version + 1 - - # Verify Sync Log - log_entry = get_latest_log(db_instance, kw_uuid) - assert log_entry['operation'] == 'delete' - assert log_entry['entity'] == 'Keywords' - assert log_entry['version'] == initial_version + 1 - payload = json.loads(log_entry['payload']) - assert payload['uuid'] == kw_uuid # Delete payload is minimal - - def test_undelete_keyword(self, db_instance): - keyword = "to_undelete" - kw_id, kw_uuid = db_instance.add_keyword(keyword) - db_instance.soft_delete_keyword(keyword) # Delete it first - deleted_version = get_entity_version(db_instance, "Keywords", kw_uuid) - - # Adding it again should undelete it - undelete_id, undelete_uuid = db_instance.add_keyword(keyword) - - assert undelete_id == kw_id - assert undelete_uuid == kw_uuid - - # Verify DB state - cursor = db_instance.execute_query("SELECT deleted, version FROM Keywords WHERE id = ?", (kw_id,)) - row = cursor.fetchone() - assert row['deleted'] == 0 - assert row['version'] == deleted_version + 1 - - # Verify Sync Log - log_entry = get_latest_log(db_instance, kw_uuid) - # Undelete is logged as an 'update' - assert log_entry['operation'] == 'update' - assert log_entry['entity'] == 'Keywords' - assert log_entry['version'] == deleted_version + 1 - payload = json.loads(log_entry['payload']) - assert payload['uuid'] == kw_uuid - assert payload['deleted'] == 0 # Payload shows undeleted state - - def test_add_media_with_keywords_create(self, db_instance): - title = "Test Media Create" - content = "Some unique content for create." 
- keywords = ["create_kw1", "create_kw2"] - - media_id, media_uuid, msg = db_instance.add_media_with_keywords( - title=title, - media_type="article", - content=content, - keywords=keywords, - author="Tester" - ) - - assert media_id is not None - assert media_uuid is not None - assert f"Media '{title}' added." == msg # NEW (Exact match) - - # Verify Media DB state - cursor = db_instance.execute_query("SELECT * FROM Media WHERE id = ?", (media_id,)) - media_row = cursor.fetchone() - assert media_row['title'] == title - assert media_row['uuid'] == media_uuid - assert media_row['version'] == 1 # Initial version - assert not media_row['deleted'] - - # Verify Keywords exist - cursor = db_instance.execute_query("SELECT COUNT(*) FROM Keywords WHERE keyword IN (?, ?)", tuple(keywords)) - assert cursor.fetchone()[0] == 2 - - # Verify MediaKeywords links - cursor = db_instance.execute_query("SELECT COUNT(*) FROM MediaKeywords WHERE media_id = ?", (media_id,)) - assert cursor.fetchone()[0] == 2 - - # Verify DocumentVersion creation - cursor = db_instance.execute_query("SELECT version_number, content FROM DocumentVersions WHERE media_id = ? ORDER BY version_number DESC LIMIT 1", (media_id,)) - version_row = cursor.fetchone() - assert version_row['version_number'] == 1 - assert version_row['content'] == content - - # Verify Sync Log for Media - log_entry = get_latest_log(db_instance, media_uuid) - assert log_entry['operation'] == 'create' - assert log_entry['entity'] == 'Media' - # Note: MediaKeywords triggers might log *after* the media create trigger - - def test_add_media_with_keywords_update(self, db_instance): - title = "Test Media Update" - content1 = "Initial content." - content2 = "Updated content." - keywords1 = ["update_kw1"] - keywords2 = ["update_kw2", "update_kw3"] - - # Add initial version - media_id, media_uuid, _ = db_instance.add_media_with_keywords( - title=title, media_type="text", content=content1, keywords=keywords1 - ) - initial_version = get_entity_version(db_instance, "Media", media_uuid) - - # --- FIX: Fetch the hash AFTER creation --- - cursor_check_initial = db_instance.execute_query("SELECT content_hash FROM Media WHERE id = ?", (media_id,)) - initial_hash_row = cursor_check_initial.fetchone() - assert initial_hash_row is not None # Ensure fetch worked - initial_content_hash = initial_hash_row['content_hash'] - # --- End Fix --- - - # --- Attempt 1: Update using a generated URL with the initial hash --- - # (This test might be slightly less relevant if your primary update mechanism - # relies on UUID or finding via hash internally when URL is None) - generated_url = f"local://text/{initial_content_hash}" - media_id_up1, media_uuid_up1, msg1 = db_instance.add_media_with_keywords( - title=title + " Updated Via URL", - media_type="text", - content=content2, # Update content (changes hash) - keywords=["url_update_kw"], - overwrite=True, - url=generated_url # Use the generated URL - ) - - # Assertions for the first update attempt (if you keep it) - assert media_id_up1 == media_id - assert media_uuid_up1 == media_uuid - assert f"Media '{title + ' Updated Via URL'}' updated." 
== msg1 - # Check version incremented after first update - version_after_update1 = get_entity_version(db_instance, "Media", media_uuid) - assert version_after_update1 == initial_version + 1 - - # --- Attempt 2: Simulate finding by hash (URL=None) --- - # Update again, changing keywords - media_id_up2, media_uuid_up2, msg2 = db_instance.add_media_with_keywords( - title=title + " Updated Via Hash", # Change title again - media_type="text", - content=content2, # Keep content same as first update - keywords=keywords2, # Use the final keyword set - overwrite=True, - url=None # Force lookup by hash (which is now hash of content2) - ) - - # Assertions for the second update attempt - assert media_id_up2 == media_id - assert media_uuid_up2 == media_uuid - assert f"Media '{title + ' Updated Via Hash'}' updated." == msg2 - - # Verify Final Media DB state - cursor = db_instance.execute_query("SELECT title, content, version FROM Media WHERE id = ?", (media_id,)) - media_row = cursor.fetchone() # Now media_row is correctly defined for assertions - assert media_row['title'] == title + " Updated Via Hash" - assert media_row['content'] == content2 - # Version should have incremented again from the second update - assert media_row['version'] == version_after_update1 + 1 - - # Verify Keywords links updated to the final set - cursor = db_instance.execute_query(""" - SELECT k.keyword - FROM MediaKeywords mk - JOIN Keywords k ON mk.keyword_id = k.id - WHERE mk.media_id = ? - ORDER BY k.keyword - """, (media_id,)) - current_keywords = [r['keyword'] for r in cursor.fetchall()] - assert current_keywords == sorted(keywords2) - - # Verify latest DocumentVersion reflects the last content state (content2) - cursor = db_instance.execute_query( - "SELECT version_number, content FROM DocumentVersions WHERE media_id = ? ORDER BY version_number DESC LIMIT 1", - (media_id,)) - version_row = cursor.fetchone() - # There should be 3 versions now (initial create, update 1, update 2) - assert version_row['version_number'] == 3 - assert version_row['content'] == content2 - - # Verify Sync Log for the *last* Media update - log_entry = get_latest_log(db_instance, media_uuid) - assert log_entry['operation'] == 'update' - assert log_entry['entity'] == 'Media' - assert log_entry['version'] == version_after_update1 + 1 - - def test_soft_delete_media_cascade(self, db_instance): - # 1. 
Setup complex item - media_id, media_uuid, _ = db_instance.add_media_with_keywords( - title="Cascade Test", content="Cascade content", media_type="article", - keywords=["cascade1", "cascade2"], author="Cascade Author" - ) - # Add a transcript manually (assuming no direct add_transcript method) - t_uuid = db_instance._generate_uuid() - db_instance.execute_query( - """INSERT INTO Transcripts (media_id, whisper_model, transcription, uuid, last_modified, version, client_id, deleted) - VALUES (?, ?, ?, ?, ?, 1, ?, 0)""", - (media_id, "model_xyz", "Transcript text", t_uuid, db_instance._get_current_utc_timestamp_str(), db_instance.client_id), - commit=True - ) - # Add a chunk manually - c_uuid = db_instance._generate_uuid() - db_instance.execute_query( - """INSERT INTO MediaChunks (media_id, chunk_text, uuid, last_modified, version, client_id, deleted) - VALUES (?, ?, ?, ?, 1, ?, 0)""", - (media_id, "Chunk text", c_uuid, db_instance._get_current_utc_timestamp_str(), db_instance.client_id), - commit=True - ) - media_version = get_entity_version(db_instance, "Media", media_uuid) - transcript_version = get_entity_version(db_instance, "Transcripts", t_uuid) - chunk_version = get_entity_version(db_instance, "MediaChunks", c_uuid) - - - # 2. Perform soft delete with cascade - deleted = db_instance.soft_delete_media(media_id, cascade=True) - assert deleted is True - - # 3. Verify parent and children are marked deleted and versioned - cursor = db_instance.execute_query("SELECT deleted, version FROM Media WHERE id = ?", (media_id,)) - assert dict(cursor.fetchone()) == {'deleted': 1, 'version': media_version + 1} - - cursor = db_instance.execute_query("SELECT deleted, version FROM Transcripts WHERE uuid = ?", (t_uuid,)) - assert dict(cursor.fetchone()) == {'deleted': 1, 'version': transcript_version + 1} - - cursor = db_instance.execute_query("SELECT deleted, version FROM MediaChunks WHERE uuid = ?", (c_uuid,)) - assert dict(cursor.fetchone()) == {'deleted': 1, 'version': chunk_version + 1} - - # 4. Verify keywords are unlinked - cursor = db_instance.execute_query("SELECT COUNT(*) FROM MediaKeywords WHERE media_id = ?", (media_id,)) - assert cursor.fetchone()[0] == 0 - - # 5. Verify Sync Logs - media_log = get_latest_log(db_instance, media_uuid) - assert media_log['operation'] == 'delete' - assert media_log['version'] == media_version + 1 - - transcript_log = get_latest_log(db_instance, t_uuid) - assert transcript_log['operation'] == 'delete' - assert transcript_log['version'] == transcript_version + 1 - - chunk_log = get_latest_log(db_instance, c_uuid) - assert chunk_log['operation'] == 'delete' - assert chunk_log['version'] == chunk_version + 1 - - # Check MediaKeywords unlink logs (tricky to get exact UUIDs, check count) - cursor = db_instance.execute_query("SELECT COUNT(*) FROM sync_log WHERE entity = 'MediaKeywords' AND operation = 'unlink' AND payload LIKE ?", (f'%{media_uuid}%',)) - assert cursor.fetchone()[0] == 2 # Should be 2 unlink events - - def test_optimistic_locking_prevents_update_with_stale_version(self, db_instance): - """Test that an UPDATE with a stale version number fails (rowcount 0).""" - keyword = "conflict_test" - kw_id, kw_uuid = db_instance.add_keyword(keyword) - original_version = get_entity_version(db_instance, "Keywords", kw_uuid) # Should be 1 - assert original_version == 1, "Initial version should be 1" - - # Simulate external update incrementing version - db_instance.execute_query( - "UPDATE Keywords SET version = ?, client_id = ? 
WHERE id = ?", - (original_version + 1, "external_client", kw_id), - commit=True - ) - version_after_external_update = get_entity_version(db_instance, "Keywords", kw_uuid) # Should be 2 - assert version_after_external_update == original_version + 1, "Version after external update should be 2" - - # Now, manually attempt an update using the *original stale version* (version=1) - # This mimics what would happen if a process read version 1, then tried - # to update after the external process bumped it to version 2. - current_time = db_instance._get_current_utc_timestamp_str() - client_id = db_instance.client_id - cursor = db_instance.execute_query( - "UPDATE Keywords SET keyword='stale_update', last_modified=?, version=?, client_id=? WHERE id=? AND version=?", - (current_time, original_version + 1, client_id, kw_id, original_version), # <<< WHERE version = 1 (stale) - commit=True # Commit needed to actually perform the check - ) - - # Assert that the update failed because the WHERE clause (version=1) didn't match any rows - assert cursor.rowcount == 0, "Update with stale version should affect 0 rows" - - # Verify DB state is unchanged by the failed update (still shows external update's state) - cursor_check = db_instance.execute_query("SELECT keyword, version, client_id FROM Keywords WHERE id = ?", - (kw_id,)) - row = cursor_check.fetchone() - assert row is not None, "Keyword should still exist" - assert row['keyword'] == keyword, "Keyword text should not have changed to 'stale_update'" - assert row['version'] == original_version + 1, "Version should remain 2 from the external update" - assert row['client_id'] == "external_client", "Client ID should remain from the external update" - - def test_version_validation_trigger(self, db_instance): - """Test trigger preventing non-sequential version updates.""" - kw_id, kw_uuid = db_instance.add_keyword("validation_test") - current_version = get_entity_version(db_instance, "Keywords", kw_uuid) - - # Try to update version incorrectly (skipping a version) - with pytest.raises(sqlite3.IntegrityError, match="Sync Error \(Keywords\): Version must increment by exactly 1"): - db_instance.execute_query( - "UPDATE Keywords SET version = ?, keyword = ? WHERE id = ?", - (current_version + 2, "bad version", kw_id), - commit=True - ) - - # Try to update version incorrectly (same version) - with pytest.raises(sqlite3.IntegrityError, match="Sync Error \(Keywords\): Version must increment by exactly 1"): - db_instance.execute_query( - "UPDATE Keywords SET version = ?, keyword = ? 
WHERE id = ?", - (current_version, "same version", kw_id), - commit=True - ) - - def test_client_id_validation_trigger(self, db_instance): - """Test trigger preventing null/empty client_id on update.""" - kw_id, kw_uuid = db_instance.add_keyword("clientid_test") - current_version = get_entity_version(db_instance, "Keywords", kw_uuid) - - # Try to update with NULL client_id - with pytest.raises(sqlite3.IntegrityError, match="Sync Error \(Keywords\): Client ID cannot be NULL or empty"): - db_instance.execute_query( - "UPDATE Keywords SET version = ?, client_id = NULL WHERE id = ?", - (current_version + 1, kw_id), - commit=True - ) - - # Try to update with empty client_id - with pytest.raises(sqlite3.IntegrityError, match="Sync Error \(Keywords\): Client ID cannot be NULL or empty"): - db_instance.execute_query( - "UPDATE Keywords SET version = ?, client_id = '' WHERE id = ?", - (current_version + 1, kw_id), - commit=True - ) - - -class TestSyncLogManagement: - - @pytest.fixture - def db_instance(self, memory_db_factory): - db = memory_db_factory("log_client") - # Add some initial data to generate logs - db.add_keyword("log_kw_1") - time.sleep(0.01) # Ensure timestamp difference - db.add_keyword("log_kw_2") - time.sleep(0.01) - db.add_keyword("log_kw_3") - db.soft_delete_keyword("log_kw_2") - return db - - def test_get_sync_log_entries_all(self, db_instance): - logs = db_instance.get_sync_log_entries() - # Expect 3 creates + 1 delete = 4 entries - assert len(logs) == 4 - assert logs[0]['change_id'] == 1 - assert logs[-1]['change_id'] == 4 - - def test_get_sync_log_entries_since(self, db_instance): - logs = db_instance.get_sync_log_entries(since_change_id=2) # Get 3 and 4 - assert len(logs) == 2 - assert logs[0]['change_id'] == 3 - assert logs[1]['change_id'] == 4 - - def test_get_sync_log_entries_limit(self, db_instance): - logs = db_instance.get_sync_log_entries(limit=2) # Get 1 and 2 - assert len(logs) == 2 - assert logs[0]['change_id'] == 1 - assert logs[1]['change_id'] == 2 - - def test_get_sync_log_entries_since_and_limit(self, db_instance): - logs = db_instance.get_sync_log_entries(since_change_id=1, limit=2) # Get 2 and 3 - assert len(logs) == 2 - assert logs[0]['change_id'] == 2 - assert logs[1]['change_id'] == 3 - - def test_delete_sync_log_entries_specific(self, db_instance): - initial_logs = db_instance.get_sync_log_entries() - initial_count = len(initial_logs) # Should be 4 - ids_to_delete = [initial_logs[1]['change_id'], initial_logs[2]['change_id']] # Delete 2 and 3 - - deleted_count = db_instance.delete_sync_log_entries(ids_to_delete) - assert deleted_count == 2 - - remaining_logs = db_instance.get_sync_log_entries() - assert len(remaining_logs) == initial_count - 2 - remaining_ids = {log['change_id'] for log in remaining_logs} - assert remaining_ids == {initial_logs[0]['change_id'], initial_logs[3]['change_id']} # 1 and 4 should remain - - def test_delete_sync_log_entries_before(self, db_instance): - initial_logs = db_instance.get_sync_log_entries() - initial_count = len(initial_logs) # Should be 4 - threshold_id = initial_logs[2]['change_id'] # Delete up to and including ID 3 - - deleted_count = db_instance.delete_sync_log_entries_before(threshold_id) - assert deleted_count == 3 # Deleted 1, 2, 3 - - remaining_logs = db_instance.get_sync_log_entries() - assert len(remaining_logs) == 1 - assert remaining_logs[0]['change_id'] == initial_logs[3]['change_id'] # Only 4 remains - - def test_delete_sync_log_entries_empty(self, db_instance): - deleted_count = 
db_instance.delete_sync_log_entries([]) - assert deleted_count == 0 - - def test_delete_sync_log_entries_invalid_id(self, db_instance): - with pytest.raises(ValueError): - db_instance.delete_sync_log_entries([1, "two", 3]) \ No newline at end of file diff --git a/Tests/MediaDB2/__init__.py b/Tests/Event_Handlers/Chat_Events/__init__.py similarity index 100% rename from Tests/MediaDB2/__init__.py rename to Tests/Event_Handlers/Chat_Events/__init__.py diff --git a/Tests/Event_Handlers/Chat_Events/test_chat_events.py b/Tests/Event_Handlers/Chat_Events/test_chat_events.py new file mode 100644 index 00000000..9a2b3433 --- /dev/null +++ b/Tests/Event_Handlers/Chat_Events/test_chat_events.py @@ -0,0 +1,301 @@ +# /tests/Event_Handlers/Chat_Events/test_chat_events.py + +import pytest +from unittest.mock import AsyncMock, MagicMock, patch, call + +from rich.text import Text +# Mock Textual UI elements before they are imported by the module under test +from textual.widgets import ( + Button, Input, TextArea, Static, Select, Checkbox, ListView, ListItem, Label +) +from textual.containers import VerticalScroll +from textual.css.query import QueryError + +# Mock DB Errors +from tldw_chatbook.DB.ChaChaNotes_DB import ConflictError, CharactersRAGDBError, InputError + +# Functions to test +from tldw_chatbook.Event_Handlers.Chat_Events.chat_events import ( + handle_chat_send_button_pressed, + handle_chat_action_button_pressed, + handle_chat_new_conversation_button_pressed, + handle_chat_save_current_chat_button_pressed, + handle_chat_load_character_button_pressed, + handle_chat_clear_active_character_button_pressed, + # ... import other handlers as you write tests for them +) +from tldw_chatbook.Widgets.chat_message import ChatMessage + +pytestmark = pytest.mark.asyncio + + +# A very comprehensive mock app fixture is needed here +@pytest.fixture +def mock_app(): + app = AsyncMock() + + # Mock services and DBs + app.chachanotes_db = MagicMock() + app.notes_service = MagicMock() + app.notes_service._get_db.return_value = app.chachanotes_db + app.media_db = MagicMock() + + # Mock core app properties + app.API_IMPORTS_SUCCESSFUL = True + app.app_config = { + "api_settings": { + "openai": {"streaming": True, "api_key_env_var": "OPENAI_API_KEY"}, + "anthropic": {"streaming": False, "api_key": "xyz-key"} + }, + "chat_defaults": {"system_prompt": "Default system prompt."}, + "USERS_NAME": "Tester" + } + + # Mock app state + app.current_chat_conversation_id = None + app.current_chat_is_ephemeral = True + app.current_chat_active_character_data = None + app.current_ai_message_widget = None + + # Mock app methods + app.query_one = MagicMock() + app.notify = AsyncMock() + app.copy_to_clipboard = MagicMock() + app.set_timer = MagicMock() + app.run_worker = MagicMock() + app.chat_wrapper = AsyncMock() + + # Timers + app._conversation_search_timer = None + + # --- Set up mock widgets --- + # This is complex; a helper function simplifies it. 
+ def setup_mock_widgets(q_one_mock): + widgets = { + "#chat-input": MagicMock(spec=TextArea, text="User message", is_mounted=True), + "#chat-log": AsyncMock(spec=VerticalScroll, is_mounted=True), + "#chat-api-provider": MagicMock(spec=Select, value="OpenAI"), + "#chat-api-model": MagicMock(spec=Select, value="gpt-4"), + "#chat-system-prompt": MagicMock(spec=TextArea, text="UI system prompt"), + "#chat-temperature": MagicMock(spec=Input, value="0.7"), + "#chat-top-p": MagicMock(spec=Input, value="0.9"), + "#chat-min-p": MagicMock(spec=Input, value="0.1"), + "#chat-top-k": MagicMock(spec=Input, value="40"), + "#chat-llm-max-tokens": MagicMock(spec=Input, value="1024"), + "#chat-llm-seed": MagicMock(spec=Input, value=""), + "#chat-llm-stop": MagicMock(spec=Input, value=""), + "#chat-llm-response-format": MagicMock(spec=Select, value="text"), + "#chat-llm-n": MagicMock(spec=Input, value="1"), + "#chat-llm-user-identifier": MagicMock(spec=Input, value=""), + "#chat-llm-logprobs": MagicMock(spec=Checkbox, value=False), + "#chat-llm-top-logprobs": MagicMock(spec=Input, value=""), + "#chat-llm-logit-bias": MagicMock(spec=TextArea, text="{}"), + "#chat-llm-presence-penalty": MagicMock(spec=Input, value="0.0"), + "#chat-llm-frequency-penalty": MagicMock(spec=Input, value="0.0"), + "#chat-llm-tools": MagicMock(spec=TextArea, text="[]"), + "#chat-llm-tool-choice": MagicMock(spec=Input, value=""), + "#chat-llm-fixed-tokens-kobold": MagicMock(spec=Checkbox, value=False), + "#chat-strip-thinking-tags-checkbox": MagicMock(spec=Checkbox, value=True), + "#chat-character-search-results-list": AsyncMock(spec=ListView), + "#chat-character-name-edit": MagicMock(spec=Input), + "#chat-character-description-edit": MagicMock(spec=TextArea), + "#chat-character-personality-edit": MagicMock(spec=TextArea), + "#chat-character-scenario-edit": MagicMock(spec=TextArea), + "#chat-character-system-prompt-edit": MagicMock(spec=TextArea), + "#chat-character-first-message-edit": MagicMock(spec=TextArea), + "#chat-right-sidebar": MagicMock(), # Mock container + } + + def query_one_side_effect(selector, _type=None): + # Special case for querying within the sidebar + if isinstance(selector, MagicMock) and hasattr(selector, 'query_one'): + return selector.query_one(selector, _type) + + if selector in widgets: + return widgets[selector] + + # Allow querying for sub-widgets inside a container like the right sidebar + if widgets["#chat-right-sidebar"].query_one.call_args: + inner_selector = widgets["#chat-right-sidebar"].query_one.call_args[0][0] + if inner_selector in widgets: + return widgets[inner_selector] + + raise QueryError(f"Widget not found by mock: {selector}") + + q_one_mock.side_effect = query_one_side_effect + + # Make the sidebar mock also use the main query_one logic + widgets["#chat-right-sidebar"].query_one.side_effect = lambda sel, _type: widgets[sel] + + setup_mock_widgets(app.query_one) + + return app + + +# Mock external dependencies used in chat_events.py +@patch('tldw_chatbook.Event_Handlers.Chat_Events.chat_events.ccl') +@patch('tldw_chatbook.Event_Handlers.Chat_Events.chat_events.os') +@patch('tldw_chatbook.Event_Handlers.Chat_Events.chat_events.ChatMessage', new_callable=AsyncMock) +async def test_handle_chat_send_button_pressed_basic(mock_chat_message_class, mock_os, mock_ccl, mock_app): + """Test a basic message send operation.""" + mock_os.environ.get.return_value = "fake-key" + + await handle_chat_send_button_pressed(mock_app, MagicMock()) + + # Assert UI updates + 
mock_app.query_one("#chat-input").clear.assert_called_once() + mock_app.query_one("#chat-log").mount.assert_any_call(mock_chat_message_class.return_value) # Mounts user message + mock_app.query_one("#chat-log").mount.assert_any_call(mock_app.current_ai_message_widget) # Mounts AI placeholder + + # Assert worker is called + mock_app.run_worker.assert_called_once() + + # Assert chat_wrapper is called with correct parameters by the worker + worker_lambda = mock_app.run_worker.call_args[0][0] + worker_lambda() # Execute the lambda to trigger the call to chat_wrapper + + mock_app.chat_wrapper.assert_called_once() + wrapper_kwargs = mock_app.chat_wrapper.call_args.kwargs + assert wrapper_kwargs['message'] == "User message" + assert wrapper_kwargs['api_endpoint'] == "OpenAI" + assert wrapper_kwargs['api_key'] == "fake-key" + assert wrapper_kwargs['system_message'] == "UI system prompt" + assert wrapper_kwargs['streaming'] is True # From config + + +@patch('tldw_chatbook.Event_Handlers.Chat_Events.chat_events.ccl') +@patch('tldw_chatbook.Event_Handlers.Chat_Events.chat_events.os') +@patch('tldw_chatbook.Event_Handlers.Chat_Events.chat_events.ChatMessage', new_callable=AsyncMock) +async def test_handle_chat_send_with_active_character(mock_chat_message_class, mock_os, mock_ccl, mock_app): + """Test that an active character's system prompt overrides the UI.""" + mock_os.environ.get.return_value = "fake-key" + mock_app.current_chat_active_character_data = { + 'name': 'TestChar', + 'system_prompt': 'You are TestChar.' + } + + await handle_chat_send_button_pressed(mock_app, MagicMock()) + + worker_lambda = mock_app.run_worker.call_args[0][0] + worker_lambda() + + wrapper_kwargs = mock_app.chat_wrapper.call_args.kwargs + assert wrapper_kwargs['system_message'] == "You are TestChar." + + +async def test_handle_new_conversation_button_pressed(mock_app): + """Test that the new chat button clears state and UI.""" + # Set some state to ensure it's cleared + mock_app.current_chat_conversation_id = "conv_123" + mock_app.current_chat_is_ephemeral = False + mock_app.current_chat_active_character_data = {'name': 'char'} + + await handle_chat_new_conversation_button_pressed(mock_app, MagicMock()) + + mock_app.query_one("#chat-log").remove_children.assert_called_once() + assert mock_app.current_chat_conversation_id is None + assert mock_app.current_chat_is_ephemeral is True + assert mock_app.current_chat_active_character_data is None + # Check that a UI field was reset + assert mock_app.query_one("#chat-system-prompt").text == "Default system prompt." 
+ + +@patch('tldw_chatbook.Event_Handlers.Chat_Events.chat_events.ccl') +@patch('tldw_chatbook.Event_Handlers.Chat_Events.chat_events.display_conversation_in_chat_tab_ui', + new_callable=AsyncMock) +async def test_handle_save_current_chat_button_pressed(mock_display_conv, mock_ccl, mock_app): + """Test saving an ephemeral chat.""" + mock_app.current_chat_is_ephemeral = True + mock_app.current_chat_conversation_id = None + + # Setup mock messages in the chat log + mock_msg1 = MagicMock(spec=ChatMessage, role="User", message_text="Hello", generation_complete=True, + image_data=None, image_mime_type=None) + mock_msg2 = MagicMock(spec=ChatMessage, role="AI", message_text="Hi", generation_complete=True, image_data=None, + image_mime_type=None) + mock_app.query_one("#chat-log").query.return_value = [mock_msg1, mock_msg2] + + mock_ccl.create_conversation.return_value = "new_conv_id" + + await handle_chat_save_current_chat_button_pressed(mock_app, MagicMock()) + + mock_ccl.create_conversation.assert_called_once() + create_kwargs = mock_ccl.create_conversation.call_args.kwargs + assert create_kwargs['title'].startswith("Chat: Hello...") + assert len(create_kwargs['initial_messages']) == 2 + assert create_kwargs['initial_messages'][0]['content'] == "Hello" + + assert mock_app.current_chat_conversation_id == "new_conv_id" + assert mock_app.current_chat_is_ephemeral is False + mock_app.notify.assert_called_with("Chat saved successfully!", severity="information") + mock_display_conv.assert_called_once_with(mock_app, "new_conv_id") + + +@patch('tldw_chatbook.Event_Handlers.Chat_Events.chat_events.ccl') +async def test_handle_chat_action_button_pressed_edit_and_save(mock_ccl, mock_app): + """Test the edit->save workflow for a chat message.""" + mock_button = MagicMock(spec=Button, classes=["edit-button"]) + mock_action_widget = AsyncMock(spec=ChatMessage) + mock_action_widget.message_text = "Original text" + mock_action_widget.message_id_internal = "msg_123" + mock_action_widget.message_version_internal = 0 + mock_action_widget._editing = False # Start in non-editing mode + mock_static_text = mock_action_widget.query_one.return_value + + # --- 1. First press: Start editing --- + await handle_chat_action_button_pressed(mock_app, mock_button, mock_action_widget) + + mock_action_widget.mount.assert_called_once() # Mounts the TextArea + assert mock_action_widget._editing is True + assert "💾" in mock_button.label # Check for save emoji + + # --- 2. 
Second press: Save edit --- + mock_action_widget._editing = True # Simulate being in editing mode + mock_edit_area = MagicMock(spec=TextArea, text="New edited text") + mock_action_widget.query_one.return_value = mock_edit_area + mock_ccl.edit_message_content.return_value = True + + await handle_chat_action_button_pressed(mock_app, mock_button, mock_action_widget) + + mock_edit_area.remove.assert_called_once() + assert mock_action_widget.message_text == "New edited text" + assert isinstance(mock_static_text.update.call_args[0][0], Text) + assert mock_static_text.update.call_args[0][0].plain == "New edited text" + + mock_ccl.edit_message_content.assert_called_with( + mock_app.chachanotes_db, "msg_123", "New edited text", 0 + ) + assert mock_action_widget.message_version_internal == 1 # Version incremented + assert "✏️" in mock_button.label # Check for edit emoji + + +@patch('tldw_chatbook.Event_Handlers.Chat_Events.chat_events.load_character_and_image') +async def test_handle_chat_load_character_with_greeting(mock_load_char, mock_app): + """Test that loading a character into an empty, ephemeral chat posts a greeting.""" + mock_app.current_chat_is_ephemeral = True + mock_app.query_one("#chat-log").query.return_value = [] # Empty chat log + + char_data = { + 'id': 'char_abc', 'name': 'Greeter', 'first_message': 'Hello, adventurer!' + } + mock_load_char.return_value = (char_data, None, None) + + # Mock the list item from the character search list + mock_list_item = MagicMock(spec=ListItem) + mock_list_item.character_id = 'char_abc' + mock_app.query_one("#chat-character-search-results-list").highlighted_child = mock_list_item + + with patch('tldw_chatbook.Event_Handlers.Chat_Events.chat_events.ChatMessage', + new_callable=AsyncMock) as mock_chat_msg_class: + await handle_chat_load_character_button_pressed(mock_app, MagicMock()) + + # Assert character data is loaded + assert mock_app.current_chat_active_character_data == char_data + + # Assert greeting message was created and mounted + mock_chat_msg_class.assert_called_with( + message='Hello, adventurer!', + role='Greeter', + generation_complete=True + ) + mock_app.query_one("#chat-log").mount.assert_called_once_with(mock_chat_msg_class.return_value) \ No newline at end of file diff --git a/Tests/Event_Handlers/Chat_Events/test_chat_events_sidebar.py b/Tests/Event_Handlers/Chat_Events/test_chat_events_sidebar.py new file mode 100644 index 00000000..6cec77c1 --- /dev/null +++ b/Tests/Event_Handlers/Chat_Events/test_chat_events_sidebar.py @@ -0,0 +1,227 @@ +# /tests/Event_Handlers/Chat_Events/test_chat_events_sidebar.py + +import pytest +from unittest.mock import AsyncMock, MagicMock, patch + +from textual.widgets import Button, Input, ListView, TextArea, ListItem, Label +from textual.css.query import QueryError + +# Functions to test +from tldw_chatbook.Event_Handlers.Chat_Events.chat_events_sidebar import ( + _disable_media_copy_buttons, + perform_media_sidebar_search, + handle_chat_media_search_input_changed, + handle_chat_media_load_selected_button_pressed, + handle_chat_media_copy_title_button_pressed, + handle_chat_media_copy_content_button_pressed, + handle_chat_media_copy_author_button_pressed, + handle_chat_media_copy_url_button_pressed, +) + +pytestmark = pytest.mark.asyncio + + +@pytest.fixture +def mock_app(): + """Provides a comprehensive mock of the TldwCli app.""" + app = AsyncMock() + + # Mock UI components + app.query_one = MagicMock() + mock_results_list = AsyncMock(spec=ListView) + mock_review_display = AsyncMock(spec=TextArea) + 
mock_copy_title_btn = MagicMock(spec=Button) + mock_copy_content_btn = MagicMock(spec=Button) + mock_copy_author_btn = MagicMock(spec=Button) + mock_copy_url_btn = MagicMock(spec=Button) + mock_search_input = MagicMock(spec=Input) + + # Configure query_one to return the correct mock widget + def query_one_side_effect(selector, _type): + if selector == "#chat-media-search-results-listview": + return mock_results_list + if selector == "#chat-media-content-display": + return mock_review_display + if selector == "#chat-media-copy-title-button": + return mock_copy_title_btn + if selector == "#chat-media-copy-content-button": + return mock_copy_content_btn + if selector == "#chat-media-copy-author-button": + return mock_copy_author_btn + if selector == "#chat-media-copy-url-button": + return mock_copy_url_btn + if selector == "#chat-media-search-input": + return mock_search_input + raise QueryError(f"Widget not found: {selector}") + + app.query_one.side_effect = query_one_side_effect + + # Mock DB and state + app.media_db = MagicMock() + app.current_sidebar_media_item = None + + # Mock app methods + app.notify = AsyncMock() + app.copy_to_clipboard = MagicMock() + app.set_timer = MagicMock() + app.run_worker = MagicMock() + + # For debouncing timer + app._media_sidebar_search_timer = None + + return app + + +async def test_disable_media_copy_buttons(mock_app): + """Test that all copy buttons are disabled and the current item is cleared.""" + await _disable_media_copy_buttons(mock_app) + + assert mock_app.current_sidebar_media_item is None + assert mock_app.query_one("#chat-media-copy-title-button", Button).disabled is True + assert mock_app.query_one("#chat-media-copy-content-button", Button).disabled is True + assert mock_app.query_one("#chat-media-copy-author-button", Button).disabled is True + assert mock_app.query_one("#chat-media-copy-url-button", Button).disabled is True + + +async def test_perform_media_sidebar_search_with_results(mock_app): + """Test searching with a term that returns results.""" + mock_media_items = [ + {'title': 'Test Title 1', 'media_id': 'id12345678'}, + {'title': 'Test Title 2', 'media_id': 'id87654321'}, + ] + mock_app.media_db.search_media_db.return_value = mock_media_items + mock_results_list = mock_app.query_one("#chat-media-search-results-listview", ListView) + + with patch('tldw_chatbook.Event_Handlers.Chat_Events.chat_events_sidebar.ListItem', + side_effect=ListItem) as mock_list_item_class: + await perform_media_sidebar_search(mock_app, "test term") + + mock_results_list.clear.assert_called_once() + mock_app.query_one("#chat-media-content-display", TextArea).clear.assert_called_once() + mock_app.media_db.search_media_db.assert_called_once() + assert mock_results_list.append.call_count == 2 + + # Check that ListItem was called with a Label containing the correct text + first_call_args = mock_list_item_class.call_args_list[0].args + assert isinstance(first_call_args[0], Label) + assert "Test Title 1" in first_call_args[0].renderable + + +async def test_perform_media_sidebar_search_no_results(mock_app): + """Test searching with a term that returns no results.""" + mock_app.media_db.search_media_db.return_value = [] + mock_results_list = mock_app.query_one("#chat-media-search-results-listview", ListView) + + await perform_media_sidebar_search(mock_app, "no results term") + + mock_results_list.append.assert_called_once() + # The call argument is a ListItem, which contains a Label. We check the Label's content. 
+ call_arg = mock_results_list.append.call_args[0][0] + assert isinstance(call_arg, ListItem) + assert call_arg.children[0].renderable == "No media found." + + +async def test_perform_media_sidebar_search_empty_term(mock_app): + """Test that an empty search term clears results and does not search.""" + await perform_media_sidebar_search(mock_app, "") + mock_app.media_db.search_media_db.assert_not_called() + mock_app.query_one("#chat-media-search-results-listview", ListView).clear.assert_called_once() + + +async def test_handle_chat_media_search_input_changed_debouncing(mock_app): + """Test that input changes are debounced via set_timer.""" + mock_timer = MagicMock() + mock_app._media_sidebar_search_timer = mock_timer + mock_input_widget = MagicMock(spec=Input, value=" new search ") + + await handle_chat_media_search_input_changed(mock_app, mock_input_widget) + + mock_timer.stop.assert_called_once() + mock_app.set_timer.assert_called_once() + # Check that run_worker is part of the callback, which calls perform_media_sidebar_search + callback_lambda = mock_app.set_timer.call_args[0][1] + # We can't easily execute the lambda here, but we can verify it's set. + assert callable(callback_lambda) + + +async def test_handle_chat_media_load_selected_button_pressed(mock_app): + """Test loading a selected media item into the display.""" + media_data = { + 'title': 'Loaded Title', 'author': 'Author Name', 'media_type': 'Article', + 'url': 'http://example.com', 'content': 'This is the full content.' + } + mock_list_item = MagicMock(spec=ListItem) + mock_list_item.media_data = media_data + + mock_results_list = mock_app.query_one("#chat-media-search-results-listview", ListView) + mock_results_list.highlighted_child = mock_list_item + + await handle_chat_media_load_selected_button_pressed(mock_app, MagicMock()) + + assert mock_app.current_sidebar_media_item == media_data + mock_app.query_one("#chat-media-content-display", TextArea).load_text.assert_called_once() + loaded_text = mock_app.query_one("#chat-media-content-display", TextArea).load_text.call_args[0][0] + assert "Title: Loaded Title" in loaded_text + assert "Author: Author Name" in loaded_text + assert "This is the full content." 
in loaded_text + + assert mock_app.query_one("#chat-media-copy-title-button", Button).disabled is False + + +async def test_handle_chat_media_load_selected_no_selection(mock_app): + """Test load button when nothing is selected.""" + mock_results_list = mock_app.query_one("#chat-media-search-results-listview", ListView) + mock_results_list.highlighted_child = None + + await handle_chat_media_load_selected_button_pressed(mock_app, MagicMock()) + + mock_app.notify.assert_called_with("No media item selected.", severity="warning") + mock_app.query_one("#chat-media-content-display", TextArea).clear.assert_called_once() + assert mock_app.query_one("#chat-media-copy-title-button", Button).disabled is True + + +async def test_handle_copy_buttons_with_data(mock_app): + """Test all copy buttons when data is available.""" + media_data = {'title': 'Copy Title', 'content': 'Copy Content', 'author': 'Copy Author', 'url': 'http://copy.url'} + mock_app.current_sidebar_media_item = media_data + + # Test copy title + await handle_chat_media_copy_title_button_pressed(mock_app, MagicMock()) + mock_app.copy_to_clipboard.assert_called_with('Copy Title') + mock_app.notify.assert_called_with("Title copied to clipboard.") + + # Test copy content + await handle_chat_media_copy_content_button_pressed(mock_app, MagicMock()) + mock_app.copy_to_clipboard.assert_called_with('Copy Content') + mock_app.notify.assert_called_with("Content copied to clipboard.") + + # Test copy author + await handle_chat_media_copy_author_button_pressed(mock_app, MagicMock()) + mock_app.copy_to_clipboard.assert_called_with('Copy Author') + mock_app.notify.assert_called_with("Author copied to clipboard.") + + # Test copy URL + await handle_chat_media_copy_url_button_pressed(mock_app, MagicMock()) + mock_app.copy_to_clipboard.assert_called_with('http://copy.url') + mock_app.notify.assert_called_with("URL copied to clipboard.") + + +async def test_handle_copy_buttons_no_data(mock_app): + """Test copy buttons when data is not available.""" + mock_app.current_sidebar_media_item = None + + # Test copy title + await handle_chat_media_copy_title_button_pressed(mock_app, MagicMock()) + mock_app.notify.assert_called_with("No media title to copy.", severity="warning") + + # Test copy content + await handle_chat_media_copy_content_button_pressed(mock_app, MagicMock()) + mock_app.notify.assert_called_with("No media content to copy.", severity="warning") + + # Test copy author + await handle_chat_media_copy_author_button_pressed(mock_app, MagicMock()) + mock_app.notify.assert_called_with("No media author to copy.", severity="warning") + + # Test copy URL + await handle_chat_media_copy_url_button_pressed(mock_app, MagicMock()) + mock_app.notify.assert_called_with("No media URL to copy.", severity="warning") \ No newline at end of file diff --git a/Tests/Event_Handlers/Chat_Events/test_chat_streaming_events.py b/Tests/Event_Handlers/Chat_Events/test_chat_streaming_events.py new file mode 100644 index 00000000..39fe0f46 --- /dev/null +++ b/Tests/Event_Handlers/Chat_Events/test_chat_streaming_events.py @@ -0,0 +1,158 @@ +import pytest +from unittest.mock import AsyncMock, MagicMock, patch + +from rich.text import Text +from textual.containers import VerticalScroll +from textual.widgets import Static, TextArea + +from tldw_chatbook.Event_Handlers.worker_events import StreamingChunk, StreamDone +from tldw_chatbook.Constants import TAB_CHAT, TAB_CCP + +# Functions to test (they are methods on the app, so we test them as such) +from 
tldw_chatbook.Event_Handlers.Chat_Events.chat_streaming_events import ( + handle_streaming_chunk, + handle_stream_done +) + +pytestmark = pytest.mark.asyncio + + +@pytest.fixture +def mock_app(): + """Provides a mock app instance ('self' for the handlers).""" + app = AsyncMock() + + # Mock logger and config + app.loguru_logger = MagicMock() + app.app_config = {"chat_defaults": {"strip_thinking_tags": True}} + + # Mock UI state and components + app.current_tab = TAB_CHAT + mock_static_text = AsyncMock(spec=Static) + mock_chat_message_widget = AsyncMock() + mock_chat_message_widget.is_mounted = True + mock_chat_message_widget.message_text = "" + mock_chat_message_widget.query_one.return_value = mock_static_text + app.current_ai_message_widget = mock_chat_message_widget + + mock_chat_log = AsyncMock(spec=VerticalScroll) + mock_chat_input = AsyncMock(spec=TextArea) + + app.query_one = MagicMock(side_effect=lambda sel, type: mock_chat_log if sel == "#chat-log" else mock_chat_input) + + # Mock DB and state + app.chachanotes_db = MagicMock() + app.current_chat_conversation_id = "conv_123" + app.current_chat_is_ephemeral = False + + # Mock app methods + app.notify = AsyncMock() + + return app + + +async def test_handle_streaming_chunk_appends_text(mock_app): + """Test that a streaming chunk appends text and updates the widget.""" + event = StreamingChunk(text_chunk="Hello, ") + mock_app.current_ai_message_widget.message_text = "Initial." + + await handle_streaming_chunk(mock_app, event) + + assert mock_app.current_ai_message_widget.message_text == "Initial.Hello, " + + # Check that update is called with the full, escaped text + with patch('tldw_chatbook.Event_Handlers.Chat_Events.chat_streaming_events.escape_markup', + return_value="Escaped: Initial.Hello, ") as mock_escape: + await handle_streaming_chunk(mock_app, event) + mock_escape.assert_called_with("Initial.Hello, Hello, ") + mock_app.current_ai_message_widget.query_one().update.assert_called_with("Escaped: Initial.Hello, ") + + # Check that scroll_end is called + mock_app.query_one.assert_called_with("#chat-log", VerticalScroll) + mock_app.query_one().scroll_end.assert_called() + + +async def test_handle_stream_done_success_and_save(mock_app): + """Test successful stream completion and saving to DB.""" + event = StreamDone(full_text="This is the final response.", error=None) + mock_app.current_ai_message_widget.role = "AI" + + with patch('tldw_chatbook.Event_Handlers.Chat_Events.chat_streaming_events.ccl') as mock_ccl: + # Mock DB returns for getting the saved message details + mock_ccl.add_message_to_conversation.return_value = "msg_abc" + mock_app.chachanotes_db.get_message_by_id.return_value = {'id': 'msg_abc', 'version': 0} + + await handle_stream_done(mock_app, event) + + # Assert UI update + mock_app.current_ai_message_widget.query_one().update.assert_called_with("This is the final response.") + mock_app.current_ai_message_widget.mark_generation_complete.assert_called_once() + + # Assert DB call + mock_ccl.add_message_to_conversation.assert_called_once_with( + mock_app.chachanotes_db, "conv_123", "AI", "This is the final response." 
+        )
+        assert mock_app.current_ai_message_widget.message_id_internal == 'msg_abc'
+        assert mock_app.current_ai_message_widget.message_version_internal == 0
+
+    # Assert state reset
+    assert mock_app.current_ai_message_widget is None
+    mock_app.query_one().focus.assert_called_once()
+
+
+async def test_handle_stream_done_with_tag_stripping(mock_app):
+    """Test that <think> tags are stripped from the final text before saving."""
+    full_text = "<think>I should start.</think>This is the actual response.<think>I am done now.</think>"
+    expected_text = "This is the actual response."
+    event = StreamDone(full_text=full_text, error=None)
+    mock_app.app_config["chat_defaults"]["strip_thinking_tags"] = True
+
+    with patch('tldw_chatbook.Event_Handlers.Chat_Events.chat_streaming_events.ccl') as mock_ccl:
+        await handle_stream_done(mock_app, event)
+
+        # Check that the saved text is the stripped version
+        mock_ccl.add_message_to_conversation.assert_called_once()
+        saved_text = mock_ccl.add_message_to_conversation.call_args[0][3]
+        assert saved_text == expected_text
+
+        # Check that the displayed text is also the stripped version (escaped)
+        mock_app.current_ai_message_widget.query_one().update.assert_called_with(expected_text)
+
+
+async def test_handle_stream_done_with_error(mock_app):
+    """Test stream completion when an error occurred."""
+    event = StreamDone(full_text="Partial response.", error="API limit reached")
+
+    with patch('tldw_chatbook.Event_Handlers.Chat_Events.chat_streaming_events.ccl') as mock_ccl:
+        await handle_stream_done(mock_app, event)
+
+        # Assert UI is updated with error message
+        mock_static_widget = mock_app.current_ai_message_widget.query_one()
+        mock_static_widget.update.assert_called_once()
+        update_call_arg = mock_static_widget.update.call_args[0][0]
+        assert isinstance(update_call_arg, Text)
+        assert "Partial response." 
in update_call_arg.plain + assert "Stream Error" in update_call_arg.plain + assert "API limit reached" in update_call_arg.plain + + # Assert role is changed and DB is NOT called + assert mock_app.current_ai_message_widget.role == "System" + mock_ccl.add_message_to_conversation.assert_not_called() + + # Assert state reset + assert mock_app.current_ai_message_widget is None + + +async def test_handle_stream_done_no_widget(mock_app): + """Test graceful handling when the AI widget is missing.""" + mock_app.current_ai_message_widget = None + event = StreamDone(full_text="Some text", error="Some error") + + await handle_stream_done(mock_app, event) + + # Just ensure it doesn't crash and notifies about the error + mock_app.notify.assert_called_once_with( + "Stream error (display widget missing): Some error", + severity="error", + timeout=10 + ) \ No newline at end of file diff --git a/Tests/Sync/__init__.py b/Tests/Event_Handlers/__init__.py similarity index 100% rename from Tests/Sync/__init__.py rename to Tests/Event_Handlers/__init__.py diff --git a/Tests/LLM_Management/__init__.py b/Tests/LLM_Management/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/Tests/MediaDB2/conftest.py b/Tests/MediaDB2/conftest.py deleted file mode 100644 index f6e4c1dc..00000000 --- a/Tests/MediaDB2/conftest.py +++ /dev/null @@ -1,82 +0,0 @@ -# tests/conftest.py -import pytest -import tempfile -import os -import shutil -from pathlib import Path -import sys - -# Add src directory to sys.path to allow importing library/engine code -src_path = Path(__file__).parent.parent / "src" -sys.path.insert(0, str(src_path)) - -# Now import after modifying path -try: - from tldw_Server_API.app.core.DB_Management.Media_DB_v2 import MediaDatabase -except ImportError as e: - print(f"ERROR in conftest: Could not import Database from sqlite_db. 
Error: {e}") - # Define dummy class if import fails to avoid crashing pytest collection - class Database: - def __init__(self, *args, **kwargs): pass - def close_connection(self): pass - def get_sync_log_entries(self, *args, **kwargs): return [] - def execute_query(self, *args, **kwargs): - class MockCursor: - rowcount = 0 - def fetchone(self): return None - def fetchall(self): return [] - def execute(self, *a, **k): pass - return MockCursor() - def transaction(self): - class MockTransaction: - def __enter__(self): return None # Return a mock connection/cursor if needed - def __exit__(self, *args): pass - return MockTransaction() - - -# --- Database Fixtures --- - -@pytest.fixture(scope="function") -def temp_db_path(): - """Creates a temporary directory and returns a unique DB path within it.""" - temp_dir = tempfile.mkdtemp() - db_file = Path(temp_dir) / "test_db.sqlite" - yield str(db_file) # Provide the path to the test function - # Teardown: remove the directory after the test function finishes - shutil.rmtree(temp_dir, ignore_errors=True) - -@pytest.fixture(scope="function") -def memory_db_factory(): - """Factory fixture to create in-memory Database instances.""" - created_dbs = [] - def _create_db(client_id="test_client"): - db = MediaDatabase(db_path=":memory:", client_id=client_id) - created_dbs.append(db) - return db - yield _create_db - # Teardown: close connections for all created in-memory DBs - for db in created_dbs: - try: - db.close_connection() - except: # Ignore errors during cleanup - pass - -@pytest.fixture(scope="function") -def file_db(temp_db_path): - """Creates a file-based Database instance using a temporary path.""" - db = MediaDatabase(db_path=temp_db_path, client_id="file_client") - yield db - db.close_connection() # Ensure connection is closed - -# --- Sync Engine State Fixtures --- - -@pytest.fixture(scope="function") -def temp_state_file(): - """Provides a path to a temporary file for sync state.""" - # Use NamedTemporaryFile which handles deletion automatically - with tempfile.NamedTemporaryFile(delete=False, mode='w', suffix=".json") as tf: - state_path = tf.name - yield state_path # Provide the path - # Teardown: Ensure the file is deleted even if the test fails mid-write - if os.path.exists(state_path): - os.remove(state_path) \ No newline at end of file diff --git a/Tests/MediaDB2/test_sqlite_db.py b/Tests/MediaDB2/test_sqlite_db.py deleted file mode 100644 index 73049081..00000000 --- a/Tests/MediaDB2/test_sqlite_db.py +++ /dev/null @@ -1,668 +0,0 @@ -# tests/test_sqlite_db.py -# Description: Unit tests for SQLite database operations, including CRUD, transactions, and sync log management. 
-# -# Imports: -import json -import os -import shutil -import tempfile -from pathlib import Path - -import pytest -import time -import sqlite3 -from datetime import datetime, timezone, timedelta - -from tldw_Server_API.app.core.DB_Management.Media_DB_v2 import MediaDatabase, ConflictError, DatabaseError - - -# -# 3rd-Party Imports: -# -# Local imports -# Import from src using adjusted sys.path in conftest -# -####################################################################################################################### -# -# Functions: - -# Helper to get sync log entries for assertions -def get_log_count(db: MediaDatabase, entity_uuid: str) -> int: - cursor = db.execute_query("SELECT COUNT(*) FROM sync_log WHERE entity_uuid = ?", (entity_uuid,)) - return cursor.fetchone()[0] - -def get_latest_log(db: MediaDatabase, entity_uuid: str) -> dict | None: - cursor = db.execute_query( - "SELECT * FROM sync_log WHERE entity_uuid = ? ORDER BY change_id DESC LIMIT 1", - (entity_uuid,) - ) - row = cursor.fetchone() - return dict(row) if row else None - -def get_entity_version(db: MediaDatabase, entity_table: str, uuid: str) -> int | None: - cursor = db.execute_query(f"SELECT version FROM {entity_table} WHERE uuid = ?", (uuid,)) - row = cursor.fetchone() - return row['version'] if row else None - -class TestDatabaseInitialization: - def test_memory_db_creation(self, memory_db_factory): - """Test creating an in-memory database.""" - db = memory_db_factory("client_mem") - assert db.is_memory_db - assert db.client_id == "client_mem" - # Check if a table exists (schema creation check) - cursor = db.execute_query("SELECT name FROM sqlite_master WHERE type='table' AND name='Media'") - assert cursor.fetchone() is not None - db.close_connection() - - def test_file_db_creation(self, file_db, temp_db_path): - """Test creating a file-based database.""" - assert not file_db.is_memory_db - assert file_db.client_id == "file_client" - assert os.path.exists(temp_db_path) - cursor = file_db.execute_query("SELECT name FROM sqlite_master WHERE type='table' AND name='Media'") - assert cursor.fetchone() is not None - # file_db fixture handles closure - - def test_missing_client_id(self): - """Test that ValueError is raised if client_id is missing.""" - with pytest.raises(ValueError, match="Client ID cannot be empty"): - MediaDatabase(db_path=":memory:", client_id="") - with pytest.raises(ValueError, match="Client ID cannot be empty"): - MediaDatabase(db_path=":memory:", client_id=None) - -def test_schema_versioning_new_file_db(file_db): # Use the file_db fixture - """Test that a new file DB gets the correct schema version.""" - # Initialization happened in the fixture - cursor = file_db.execute_query("SELECT version FROM schema_version") - version = cursor.fetchone()['version'] - assert version == MediaDatabase._CURRENT_SCHEMA_VERSION - -class TestDatabaseTransactions: - def test_transaction_commit(self, memory_db_factory): - db = memory_db_factory() - keyword = "commit_test" - with db.transaction(): - # Use internal method _add_keyword_internal or simplified version for test - kw_id, kw_uuid = db.add_keyword(keyword) # add_keyword uses transaction internally too, nested is ok - # Verify outside transaction - cursor = db.execute_query("SELECT keyword FROM Keywords WHERE id = ?", (kw_id,)) - assert cursor.fetchone()['keyword'] == keyword - - def test_transaction_rollback(self, memory_db_factory): - db = memory_db_factory() - keyword = "rollback_test" - initial_count_cursor = db.execute_query("SELECT COUNT(*) FROM 
Keywords") - initial_count = initial_count_cursor.fetchone()[0] - try: - with db.transaction(): - # Simplified insert for test clarity - new_uuid = db._generate_uuid() - db.execute_query( - "INSERT INTO Keywords (keyword, uuid, last_modified, version, client_id, deleted) VALUES (?, ?, ?, 1, ?, 0)", - (keyword, new_uuid, db._get_current_utc_timestamp_str(), db.client_id), - commit=False # Important: commit=False inside transaction block - ) - # Check *inside* transaction - cursor_inside = db.execute_query("SELECT COUNT(*) FROM Keywords") - assert cursor_inside.fetchone()[0] == initial_count + 1 - raise ValueError("Simulating error to trigger rollback") # Force rollback - except ValueError: - pass # Expected error - except Exception as e: - pytest.fail(f"Unexpected exception during rollback test: {e}") - - # Verify outside transaction (count should be back to initial) - final_count_cursor = db.execute_query("SELECT COUNT(*) FROM Keywords") - assert final_count_cursor.fetchone()[0] == initial_count - - -class TestDatabaseCRUDAndSync: - - @pytest.fixture - def db_instance(self, memory_db_factory): - """Provides a fresh in-memory DB for each test in this class.""" - return memory_db_factory("crud_client") - - def test_add_keyword(self, db_instance): - keyword = " test keyword " - expected_keyword = "test keyword" - kw_id, kw_uuid = db_instance.add_keyword(keyword) - - assert kw_id is not None - assert kw_uuid is not None - - # Verify DB state - cursor = db_instance.execute_query("SELECT * FROM Keywords WHERE id = ?", (kw_id,)) - row = cursor.fetchone() - assert row['keyword'] == expected_keyword - assert row['uuid'] == kw_uuid - assert row['version'] == 1 - assert row['client_id'] == db_instance.client_id - assert not row['deleted'] - - # Verify Sync Log - log_entry = get_latest_log(db_instance, kw_uuid) - assert log_entry['operation'] == 'create' - assert log_entry['entity'] == 'Keywords' - assert log_entry['version'] == 1 - assert log_entry['client_id'] == db_instance.client_id - payload = json.loads(log_entry['payload']) - assert payload['keyword'] == expected_keyword - assert payload['uuid'] == kw_uuid - - def test_add_existing_keyword(self, db_instance): - keyword = "existing" - kw_id1, kw_uuid1 = db_instance.add_keyword(keyword) - log_count1 = get_log_count(db_instance, kw_uuid1) - - kw_id2, kw_uuid2 = db_instance.add_keyword(keyword) # Add again - log_count2 = get_log_count(db_instance, kw_uuid1) - - assert kw_id1 == kw_id2 - assert kw_uuid1 == kw_uuid2 - assert log_count1 == log_count2 # No new log entry - - def test_soft_delete_keyword(self, db_instance): - keyword = "to_delete" - kw_id, kw_uuid = db_instance.add_keyword(keyword) - initial_version = get_entity_version(db_instance, "Keywords", kw_uuid) - - deleted = db_instance.soft_delete_keyword(keyword) - assert deleted is True - - # Verify DB state - cursor = db_instance.execute_query("SELECT deleted, version FROM Keywords WHERE id = ?", (kw_id,)) - row = cursor.fetchone() - assert row['deleted'] == 1 - assert row['version'] == initial_version + 1 - - # Verify Sync Log - log_entry = get_latest_log(db_instance, kw_uuid) - assert log_entry['operation'] == 'delete' - assert log_entry['entity'] == 'Keywords' - assert log_entry['version'] == initial_version + 1 - payload = json.loads(log_entry['payload']) - assert payload['uuid'] == kw_uuid # Delete payload is minimal - - def test_undelete_keyword(self, db_instance): - keyword = "to_undelete" - kw_id, kw_uuid = db_instance.add_keyword(keyword) - db_instance.soft_delete_keyword(keyword) # 
Delete it first - deleted_version = get_entity_version(db_instance, "Keywords", kw_uuid) - - # Adding it again should undelete it - undelete_id, undelete_uuid = db_instance.add_keyword(keyword) - - assert undelete_id == kw_id - assert undelete_uuid == kw_uuid - - # Verify DB state - cursor = db_instance.execute_query("SELECT deleted, version FROM Keywords WHERE id = ?", (kw_id,)) - row = cursor.fetchone() - assert row['deleted'] == 0 - assert row['version'] == deleted_version + 1 - - # Verify Sync Log - log_entry = get_latest_log(db_instance, kw_uuid) - # Undelete is logged as an 'update' - assert log_entry['operation'] == 'update' - assert log_entry['entity'] == 'Keywords' - assert log_entry['version'] == deleted_version + 1 - payload = json.loads(log_entry['payload']) - assert payload['uuid'] == kw_uuid - assert payload['deleted'] == 0 # Payload shows undeleted state - - def test_add_media_with_keywords_create(self, db_instance): - title = "Test Media Create" - content = "Some unique content for create." - keywords = ["create_kw1", "create_kw2"] - - media_id, media_uuid, msg = db_instance.add_media_with_keywords( - title=title, - media_type="article", - content=content, - keywords=keywords, - author="Tester" - ) - - assert media_id is not None - assert media_uuid is not None - # FIX: Adjust assertion to match actual return message - assert msg == f"Media '{title}' added." - - # Verify DB state (unchanged) - cursor = db_instance.execute_query("SELECT * FROM Media WHERE id = ?", (media_id,)) - media_row = cursor.fetchone() - assert media_row['title'] == title - assert media_row['uuid'] == media_uuid - assert media_row['version'] == 1 # Initial version - assert not media_row['deleted'] - - # Verify Keywords exist (unchanged) - cursor = db_instance.execute_query("SELECT COUNT(*) FROM Keywords WHERE keyword IN (?, ?)", tuple(keywords)) - assert cursor.fetchone()[0] == 2 - - # Verify MediaKeywords links (unchanged) - cursor = db_instance.execute_query("SELECT COUNT(*) FROM MediaKeywords WHERE media_id = ?", (media_id,)) - assert cursor.fetchone()[0] == 2 - - # Verify DocumentVersion creation (unchanged) - cursor = db_instance.execute_query("SELECT version_number, content FROM DocumentVersions WHERE media_id = ? ORDER BY version_number DESC LIMIT 1", (media_id,)) - version_row = cursor.fetchone() - assert version_row['version_number'] == 1 - assert version_row['content'] == content - - # Verify Sync Log for Media (Now Python generated) - log_entry = get_latest_log(db_instance, media_uuid) - # The *last* log might be DocumentVersion or MediaKeywords link depending on order. - # Find the Media create log specifically. - cursor_log = db_instance.execute_query("SELECT * FROM sync_log WHERE entity_uuid = ? AND operation = 'create' AND entity = 'Media'", (media_uuid,)) - log_entry = dict(cursor_log.fetchone()) - - assert log_entry['operation'] == 'create' - assert log_entry['entity'] == 'Media' - assert log_entry['version'] == 1 # Check version - payload = json.loads(log_entry['payload']) - assert payload['uuid'] == media_uuid - assert payload['title'] == title - - - def test_add_media_with_keywords_update(self, db_instance): - title = "Test Media Update" - content1 = "Initial content." - content2 = "Updated content." 
- keywords1 = ["update_kw1"] - keywords2 = ["update_kw2", "update_kw3"] - - media_id, media_uuid, _ = db_instance.add_media_with_keywords( - title=title, media_type="text", content=content1, keywords=keywords1 - ) - initial_version = get_entity_version(db_instance, "Media", media_uuid) - cursor_check_initial = db_instance.execute_query("SELECT content_hash FROM Media WHERE id = ?", (media_id,)) - initial_hash_row = cursor_check_initial.fetchone() - assert initial_hash_row is not None - initial_content_hash = initial_hash_row['content_hash'] - - # Update 1: Using explicit URL (optional part of test) - generated_url = f"local://text/{initial_content_hash}" - media_id_up1, media_uuid_up1, msg1 = db_instance.add_media_with_keywords( - title=title + " Updated Via URL", media_type="text", content=content2, - keywords=["url_update_kw"], overwrite=True, url=generated_url - ) - assert media_id_up1 == media_id - assert media_uuid_up1 == media_uuid - # FIX: Adjust assertion - assert msg1 == f"Media '{title + ' Updated Via URL'}' updated." - version_after_update1 = get_entity_version(db_instance, "Media", media_uuid) - assert version_after_update1 == initial_version + 1 - - # Update 2: Simulate finding by hash (URL=None) - media_id_up2, media_uuid_up2, msg2 = db_instance.add_media_with_keywords( - title=title + " Updated Via Hash", media_type="text", content=content2, - keywords=keywords2, overwrite=True, url=None - ) - assert media_id_up2 == media_id - assert media_uuid_up2 == media_uuid - # FIX: Adjust assertion - assert msg2 == f"Media '{title + ' Updated Via Hash'}' updated." - - # Verify Final State (unchanged checks for DB content) - cursor = db_instance.execute_query("SELECT title, content, version FROM Media WHERE id = ?", (media_id,)) - media_row = cursor.fetchone() - assert media_row['title'] == title + " Updated Via Hash" - assert media_row['content'] == content2 - assert media_row['version'] == version_after_update1 + 1 - - # Verify Keywords links updated (unchanged) - cursor = db_instance.execute_query("SELECT k.keyword FROM MediaKeywords mk JOIN Keywords k ON mk.keyword_id = k.id WHERE mk.media_id = ? ORDER BY k.keyword", (media_id,)) - current_keywords = [r['keyword'] for r in cursor.fetchall()] - assert current_keywords == sorted(keywords2) - - # Verify latest DocumentVersion (unchanged) - cursor = db_instance.execute_query("SELECT version_number, content FROM DocumentVersions WHERE media_id = ? ORDER BY version_number DESC LIMIT 1", (media_id,)) - version_row = cursor.fetchone(); assert version_row['version_number'] == 3; assert version_row['content'] == content2 - - # Verify Sync Log for the *last* Media update - log_entry = get_latest_log(db_instance, media_uuid) # Should be the Media update - assert log_entry['operation'] == 'update' - assert log_entry['entity'] == 'Media' - assert log_entry['version'] == version_after_update1 + 1 - payload = json.loads(log_entry['payload']) - assert payload['title'] == title + " Updated Via Hash" - - def test_soft_delete_media_cascade(self, db_instance): - # 1. 
Setup complex item - media_id, media_uuid, _ = db_instance.add_media_with_keywords( - title="Cascade Test", content="Cascade content", media_type="article", - keywords=["cascade1", "cascade2"], author="Cascade Author" - ) - # Add a transcript manually (assuming no direct add_transcript method) - t_uuid = db_instance._generate_uuid() - db_instance.execute_query( - """INSERT INTO Transcripts (media_id, whisper_model, transcription, uuid, last_modified, version, client_id, deleted) - VALUES (?, ?, ?, ?, ?, 1, ?, 0)""", - (media_id, "model_xyz", "Transcript text", t_uuid, db_instance._get_current_utc_timestamp_str(), db_instance.client_id), - commit=True - ) - # Add a chunk manually - c_uuid = db_instance._generate_uuid() - db_instance.execute_query( - """INSERT INTO MediaChunks (media_id, chunk_text, uuid, last_modified, version, client_id, deleted) - VALUES (?, ?, ?, ?, 1, ?, 0)""", - (media_id, "Chunk text", c_uuid, db_instance._get_current_utc_timestamp_str(), db_instance.client_id), - commit=True - ) - media_version = get_entity_version(db_instance, "Media", media_uuid) - transcript_version = get_entity_version(db_instance, "Transcripts", t_uuid) - chunk_version = get_entity_version(db_instance, "MediaChunks", c_uuid) - - - # 2. Perform soft delete with cascade - deleted = db_instance.soft_delete_media(media_id, cascade=True) - assert deleted is True - - # 3. Verify parent and children are marked deleted and versioned - cursor = db_instance.execute_query("SELECT deleted, version FROM Media WHERE id = ?", (media_id,)) - assert dict(cursor.fetchone()) == {'deleted': 1, 'version': media_version + 1} - - cursor = db_instance.execute_query("SELECT deleted, version FROM Transcripts WHERE uuid = ?", (t_uuid,)) - assert dict(cursor.fetchone()) == {'deleted': 1, 'version': transcript_version + 1} - - cursor = db_instance.execute_query("SELECT deleted, version FROM MediaChunks WHERE uuid = ?", (c_uuid,)) - assert dict(cursor.fetchone()) == {'deleted': 1, 'version': chunk_version + 1} - - # 4. Verify keywords are unlinked - cursor = db_instance.execute_query("SELECT COUNT(*) FROM MediaKeywords WHERE media_id = ?", (media_id,)) - assert cursor.fetchone()[0] == 0 - - # 5. Verify Sync Logs - media_log = get_latest_log(db_instance, media_uuid) - assert media_log['operation'] == 'delete' - assert media_log['version'] == media_version + 1 - - transcript_log = get_latest_log(db_instance, t_uuid) - assert transcript_log['operation'] == 'delete' - assert transcript_log['version'] == transcript_version + 1 - - chunk_log = get_latest_log(db_instance, c_uuid) - assert chunk_log['operation'] == 'delete' - assert chunk_log['version'] == chunk_version + 1 - - # Check MediaKeywords unlink logs (tricky to get exact UUIDs, check count) - cursor = db_instance.execute_query("SELECT COUNT(*) FROM sync_log WHERE entity = 'MediaKeywords' AND operation = 'unlink' AND payload LIKE ?", (f'%{media_uuid}%',)) - assert cursor.fetchone()[0] == 2 # Should be 2 unlink events - - def test_optimistic_locking_prevents_update_with_stale_version(self, db_instance): - """Test that an UPDATE with a stale version number fails (rowcount 0).""" - keyword = "conflict_test" - kw_id, kw_uuid = db_instance.add_keyword(keyword) - original_version = get_entity_version(db_instance, "Keywords", kw_uuid) # Should be 1 - assert original_version == 1, "Initial version should be 1" - - # Simulate external update incrementing version - db_instance.execute_query( - "UPDATE Keywords SET version = ?, client_id = ? 
WHERE id = ?", - (original_version + 1, "external_client", kw_id), - commit=True - ) - version_after_external_update = get_entity_version(db_instance, "Keywords", kw_uuid) # Should be 2 - assert version_after_external_update == original_version + 1, "Version after external update should be 2" - - # Now, manually attempt an update using the *original stale version* (version=1) - # This mimics what would happen if a process read version 1, then tried - # to update after the external process bumped it to version 2. - current_time = db_instance._get_current_utc_timestamp_str() - client_id = db_instance.client_id - cursor = db_instance.execute_query( - "UPDATE Keywords SET keyword='stale_update', last_modified=?, version=?, client_id=? WHERE id=? AND version=?", - (current_time, original_version + 1, client_id, kw_id, original_version), # <<< WHERE version = 1 (stale) - commit=True # Commit needed to actually perform the check - ) - - # Assert that the update failed because the WHERE clause (version=1) didn't match any rows - assert cursor.rowcount == 0, "Update with stale version should affect 0 rows" - - # Verify DB state is unchanged by the failed update (still shows external update's state) - cursor_check = db_instance.execute_query("SELECT keyword, version, client_id FROM Keywords WHERE id = ?", - (kw_id,)) - row = cursor_check.fetchone() - assert row is not None, "Keyword should still exist" - assert row['keyword'] == keyword, "Keyword text should not have changed to 'stale_update'" - assert row['version'] == original_version + 1, "Version should remain 2 from the external update" - assert row['client_id'] == "external_client", "Client ID should remain from the external update" - - def test_version_validation_trigger(self, db_instance): - """Test trigger preventing non-sequential version updates.""" - kw_id, kw_uuid = db_instance.add_keyword("validation_test") - current_version = get_entity_version(db_instance, "Keywords", kw_uuid) - - # Try to update version incorrectly (skipping a version) - with pytest.raises(sqlite3.IntegrityError, - match=r"Sync Error \(Keywords\): Version must increment by exactly 1"): - # Provide client_id to prevent the *other* validation trigger firing - client_id = db_instance.client_id - db_instance.execute_query( - "UPDATE Keywords SET version = ?, keyword = ?, client_id = ? WHERE id = ?", - (current_version + 2, "bad version", client_id, kw_id), - commit=True - ) - - # Try to update version incorrectly (same version) - with pytest.raises(sqlite3.IntegrityError, match=r"Sync Error \(Keywords\): Version must increment by exactly 1"): - client_id = db_instance.client_id - db_instance.execute_query( - "UPDATE Keywords SET version = ?, keyword = ?, client_id = ? 
WHERE id = ?", - (current_version + 2, "bad version", client_id, kw_id), - commit=True - ) - - def test_client_id_validation_trigger(self, db_instance): - """Test trigger preventing null/empty client_id on update.""" - kw_id, kw_uuid = db_instance.add_keyword("clientid_test") - current_version = get_entity_version(db_instance, "Keywords", kw_uuid) - - # Test the EMPTY STRING case handled by the trigger - # Use raw string for regex match safety - with pytest.raises(sqlite3.IntegrityError, match=r"Sync Error \(Keywords\): Client ID cannot be NULL or empty"): - db_instance.execute_query( - "UPDATE Keywords SET version = ?, client_id = '' WHERE id = ?", - (current_version + 1, kw_id), - commit=True - ) - - # Optional: Test the NULL case separately, expecting the NOT NULL constraint error - # This confirms the underlying table constraint works, though not the trigger message. - with pytest.raises(sqlite3.IntegrityError, match=r"Sync Error \(Keywords\): Client ID cannot be NULL or empty"): - db_instance.execute_query( - "UPDATE Keywords SET version = ?, client_id = NULL WHERE id = ?", - (current_version + 1, kw_id), # Increment version correctly - commit=True - ) - - -class TestSyncLogManagement: - - @pytest.fixture - def db_instance(self, memory_db_factory): - db = memory_db_factory("log_client") - # Add some initial data to generate logs - db.add_keyword("log_kw_1") - time.sleep(0.01) # Ensure timestamp difference - db.add_keyword("log_kw_2") - time.sleep(0.01) - db.add_keyword("log_kw_3") - db.soft_delete_keyword("log_kw_2") - return db - - def test_get_sync_log_entries_all(self, db_instance): - logs = db_instance.get_sync_log_entries() - # Expect 3 creates + 1 delete = 4 entries - assert len(logs) == 4 - assert logs[0]['change_id'] == 1 - assert logs[-1]['change_id'] == 4 - - def test_get_sync_log_entries_since(self, db_instance): - logs = db_instance.get_sync_log_entries(since_change_id=2) # Get 3 and 4 - assert len(logs) == 2 - assert logs[0]['change_id'] == 3 - assert logs[1]['change_id'] == 4 - - def test_get_sync_log_entries_limit(self, db_instance): - logs = db_instance.get_sync_log_entries(limit=2) # Get 1 and 2 - assert len(logs) == 2 - assert logs[0]['change_id'] == 1 - assert logs[1]['change_id'] == 2 - - def test_get_sync_log_entries_since_and_limit(self, db_instance): - logs = db_instance.get_sync_log_entries(since_change_id=1, limit=2) # Get 2 and 3 - assert len(logs) == 2 - assert logs[0]['change_id'] == 2 - assert logs[1]['change_id'] == 3 - - def test_delete_sync_log_entries_specific(self, db_instance): - initial_logs = db_instance.get_sync_log_entries() - initial_count = len(initial_logs) # Should be 4 - ids_to_delete = [initial_logs[1]['change_id'], initial_logs[2]['change_id']] # Delete 2 and 3 - - deleted_count = db_instance.delete_sync_log_entries(ids_to_delete) - assert deleted_count == 2 - - remaining_logs = db_instance.get_sync_log_entries() - assert len(remaining_logs) == initial_count - 2 - remaining_ids = {log['change_id'] for log in remaining_logs} - assert remaining_ids == {initial_logs[0]['change_id'], initial_logs[3]['change_id']} # 1 and 4 should remain - - def test_delete_sync_log_entries_before(self, db_instance): - initial_logs = db_instance.get_sync_log_entries() - initial_count = len(initial_logs) # Should be 4 - threshold_id = initial_logs[2]['change_id'] # Delete up to and including ID 3 - - deleted_count = db_instance.delete_sync_log_entries_before(threshold_id) - assert deleted_count == 3 # Deleted 1, 2, 3 - - remaining_logs = 
db_instance.get_sync_log_entries() - assert len(remaining_logs) == 1 - assert remaining_logs[0]['change_id'] == initial_logs[3]['change_id'] # Only 4 remains - - def test_delete_sync_log_entries_empty(self, db_instance): - deleted_count = db_instance.delete_sync_log_entries([]) - assert deleted_count == 0 - - def test_delete_sync_log_entries_invalid_id(self, db_instance): - with pytest.raises(ValueError): - db_instance.delete_sync_log_entries([1, "two", 3]) - - -# Add FTS specific tests -class TestDatabaseFTS: - @pytest.fixture - def db_instance(self, memory_db_factory): - # Use file DB for FTS tests if memory DB proves unstable - # return memory_db_factory("fts_client") - temp_dir = tempfile.mkdtemp() - db_file = Path(temp_dir) / "fts_test_db.sqlite" - db = MediaDatabase(db_path=str(db_file), client_id="fts_client") - yield db - db.close_connection() - shutil.rmtree(temp_dir, ignore_errors=True) - - def test_fts_media_create_search(self, db_instance): - """Test searching media via FTS after creation.""" - title = "FTS Test Alpha" - content = "Unique content string omega gamma beta." - media_id, media_uuid, _ = db_instance.add_media_with_keywords(title=title, content=content, media_type="fts_test") - - # Search by title fragment - results, total = MediaDatabase.search_media_db(db_instance, search_query="Alpha", search_fields=["title"]) - assert total == 1 - assert len(results) == 1 - assert results[0]['id'] == media_id - - # Search by content fragment - results, total = MediaDatabase.search_media_db(db_instance, search_query="omega", search_fields=["content"]) - assert total == 1 - assert len(results) == 1 - assert results[0]['id'] == media_id - - # Search by content phrase - results, total = MediaDatabase.search_media_db(db_instance, search_query='"omega gamma"', search_fields=["content"]) - assert total == 1 - - # Search non-existent term - results, total = MediaDatabase.search_media_db(db_instance, search_query="nonexistent", search_fields=["content", "title"]) - assert total == 0 - - def test_fts_media_update_search(self, db_instance): - """Test searching media via FTS after update.""" - title1 = "FTS Update Initial" - content1 = "Original text epsilon." - title2 = "FTS Update Final Zeta" - content2 = "Replacement stuff delta." 
- - media_id, media_uuid, _ = db_instance.add_media_with_keywords(title=title1, content=content1, - media_type="fts_update") - - # Verify initial search works - results, total = MediaDatabase.search_media_db(db_instance, search_query="epsilon", search_fields=["content"]) - assert total == 1 - initial_url = results[0]['url'] # Get URL for update lookup - - # Update the media - db_instance.add_media_with_keywords(title=title2, content=content2, media_type="fts_update", overwrite=True, - url=initial_url) - - # Search for OLD content - REMOVE this assertion as immediate consistency isn't guaranteed - # results, total = search_media_db(db_instance, search_query="epsilon", search_fields=["content"]) - # assert total == 0 - - # Search for NEW content should work - results, total = MediaDatabase.search_media_db(db_instance, search_query="delta", search_fields=["content"]) - assert total == 1 - assert results[0]['id'] == media_id - - # Search for NEW title should work - results, total = MediaDatabase.search_media_db(db_instance, search_query="Zeta", search_fields=["title"]) - assert total == 1 - assert results[0]['id'] == media_id - - def test_fts_media_delete_search(self, db_instance): - """Test searching media via FTS after soft deletion.""" - title = "FTS To Delete" - content = "Content will vanish theta." - media_id, media_uuid, _ = db_instance.add_media_with_keywords(title=title, content=content, media_type="fts_delete") - - # Verify initial search works - results, total = MediaDatabase.search_media_db(db_instance, search_query="theta", search_fields=["content"]) - assert total == 1 - - # Soft delete the media - deleted = db_instance.soft_delete_media(media_id) - assert deleted is True - - # Search should now fail - results, total = MediaDatabase.search_media_db(db_instance, search_query="theta", search_fields=["content"]) - assert total == 0 - - def test_fts_keyword_search(self, db_instance): - """Test searching keywords via FTS.""" - kw1_id, kw1_uuid = db_instance.add_keyword("fts_keyword_apple") - kw2_id, kw2_uuid = db_instance.add_keyword("fts_keyword_banana") - - # Search keyword FTS directly (not typically done, but tests population) - # NOTE: search_media_db doesn't search keyword_fts directly, this is just to test population - cursor = db_instance.execute_query("SELECT rowid, keyword FROM keyword_fts WHERE keyword_fts MATCH ?", ("apple",)) - fts_results = cursor.fetchall() - assert len(fts_results) == 1 - assert fts_results[0]['rowid'] == kw1_id - - # Soft delete keyword 1 - db_instance.soft_delete_keyword("fts_keyword_apple") - - # Search should now fail for apple - cursor = db_instance.execute_query("SELECT rowid FROM keyword_fts WHERE keyword_fts MATCH ?", ("apple",)) - assert cursor.fetchone() is None - - # Search for banana should still work - cursor = db_instance.execute_query("SELECT rowid FROM keyword_fts WHERE keyword_fts MATCH ?", ("banana",)) - assert cursor.fetchone()['rowid'] == kw2_id \ No newline at end of file diff --git a/Tests/Media_DB/__init__.py b/Tests/Media_DB/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/Tests/DB/conftest.py b/Tests/Media_DB/conftest.py similarity index 100% rename from Tests/DB/conftest.py rename to Tests/Media_DB/conftest.py diff --git a/Tests/Media_DB/test_media_db_properties.py b/Tests/Media_DB/test_media_db_properties.py new file mode 100644 index 00000000..3fcc3c37 --- /dev/null +++ b/Tests/Media_DB/test_media_db_properties.py @@ -0,0 +1,348 @@ +# test_media_db_v2_properties.py +# +# Property-based tests for the 
Media_DB_v2 library using Hypothesis. +# These tests verify the logical correctness and invariants of the database +# operations across a wide range of generated data. +# +# Imports +from datetime import datetime, timezone, timedelta +from typing import Iterator, Callable, Any, Generator +import pytest +import uuid +from pathlib import Path +# +# Third-Party Imports +from hypothesis import given, strategies as st, settings, HealthCheck, assume +# +# Local Imports +# Adjust the import path based on your project structure +from tldw_chatbook.DB.Client_Media_DB_v2 import ( + MediaDatabase, + InputError, + DatabaseError, + ConflictError, fetch_keywords_for_media, empty_trash +) +# +####################################################################################################################### +# +# --- Hypothesis Settings --- + +# A custom profile for database-intensive tests. +# It increases the deadline and suppresses health checks that are common +# but expected in I/O-heavy testing scenarios. +settings.register_profile( + "db_test_suite", + deadline=2000, + suppress_health_check=[ + HealthCheck.too_slow, + HealthCheck.function_scoped_fixture, + HealthCheck.data_too_large, + ] +) +settings.load_profile("db_test_suite") + + +# --- Pytest Fixtures --- + + +@pytest.fixture +def db_factory(tmp_path: Path) -> Generator[Callable[[], MediaDatabase], Any, None]: + """ + A factory that creates fresh, isolated MediaDatabase instances on demand. + Manages cleanup of all created instances. + """ + created_dbs = [] + + def _create_db_instance() -> MediaDatabase: + db_file = tmp_path / f"prop_test_{uuid.uuid4().hex}.db" + client_id = f"client_{uuid.uuid4().hex[:8]}" + db = MediaDatabase(db_path=db_file, client_id=client_id) + created_dbs.append(db) + return db + + yield _create_db_instance + + # Teardown: close all connections that were created by the factory + for db in created_dbs: + db.close_connection() + +@pytest.fixture +def db_instance(db_factory: Callable[[], MediaDatabase]) -> MediaDatabase: + """ + Provides a single, fresh MediaDatabase instance for a test function. + This fixture uses the `db_factory` to create and manage the instance. + """ + return db_factory() + +# --- Hypothesis Strategies --- + +# Strategy for generating text that is guaranteed to have non-whitespace content. +st_required_text = st.text(min_size=1, max_size=50).map(lambda s: s.strip()).filter(lambda s: len(s) > 0) + +# Strategy for a single, clean keyword. +st_keyword_text = st.text( + alphabet=st.characters(whitelist_categories=["L", "N", "S", "P"]), + min_size=2, + max_size=20 +).map(lambda s: s.strip()).filter(lambda s: len(s) > 0) + +# Strategy for generating a list of unique, case-insensitive keywords. +st_keywords_list = st.lists( + st_keyword_text, + min_size=1, + max_size=5, + unique_by=lambda s: s.lower() +).filter(lambda l: len(l) > 0) # Ensure list is not empty after filtering + + +# A composite strategy to generate a valid dictionary of media data for creation. 
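+# For ad-hoc inspection the strategy can be sampled directly, e.g. `st_media_data().example()`;
+# Hypothesis discourages calling .example() inside tests, so this is meant only as an
+# interactive debugging aid, not as part of the suite.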
+@st.composite +def st_media_data(draw): + """Generates a dictionary of plausible data for a new media item.""" + return { + "title": draw(st_required_text), + "content": draw(st.text(min_size=10, max_size=500)), + "media_type": draw(st.sampled_from(['article', 'video', 'obsidian_note', 'pdf'])), + "author": draw(st.one_of(st.none(), st.text(min_size=3, max_size=30))), + "keywords": draw(st_keywords_list) + } + + +# --- Property Test Classes --- + +class TestMediaItemProperties: + """Property-based tests for the core Media item lifecycle.""" + + @given(media_data=st_media_data()) + def test_media_item_roundtrip(self, db_instance: MediaDatabase, media_data: dict): + """ + Property: A media item, once added, should be retrievable with the same data. + """ + media_data["content"] += f" {uuid.uuid4().hex}" + + media_id, media_uuid, msg = db_instance.add_media_with_keywords(**media_data) + + assert "added" in msg + assert media_id is not None + assert media_uuid is not None + + retrieved = db_instance.get_media_by_id(media_id) + assert retrieved is not None + + assert retrieved['title'] == media_data['title'] + assert retrieved['content'] == media_data['content'] + assert retrieved['type'] == media_data['media_type'] + assert retrieved['author'] == media_data['author'] + assert retrieved['version'] == 1 + assert not retrieved['deleted'] + + linked_keywords = {kw.lower().strip() for kw in fetch_keywords_for_media(media_id, db_instance)} + expected_keywords = {kw.lower().strip() for kw in media_data['keywords']} + assert linked_keywords == expected_keywords + + # FIX: The get_all_document_versions function defaults to NOT including content. + # We must explicitly request it for the assertion to work. + doc_versions = db_instance.get_all_document_versions(media_id, include_content=True) + assert len(doc_versions) == 1 + assert doc_versions[0]['version_number'] == 1 + assert doc_versions[0]['content'] == media_data['content'] + + # ... other tests in this class are correct ... 
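+    # Note on fixture scope: `db_instance` is function-scoped, so Hypothesis reuses the same
+    # database across the many examples generated for a single test run. That is why the
+    # profile above suppresses HealthCheck.function_scoped_fixture, and why these tests append
+    # a uuid4 hex suffix to `content` to keep content hashes unique between examples.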
+ @given(initial_media=st_media_data(), update_media=st_media_data()) + def test_update_increments_version_and_changes_data(self, db_instance: MediaDatabase, initial_media: dict, + update_media: dict): + initial_media["content"] += f" initial_{uuid.uuid4().hex}" + update_media["content"] += f" update_{uuid.uuid4().hex}" + media_id, media_uuid, _ = db_instance.add_media_with_keywords(**initial_media) + original = db_instance.get_media_by_id(media_id) + media_id_up, media_uuid_up, msg = db_instance.add_media_with_keywords( + url=original['url'], + overwrite=True, + **update_media + ) + assert media_id_up == media_id + assert media_uuid_up == media_uuid + assert "updated" in msg + updated = db_instance.get_media_by_id(media_id) + assert updated is not None + assert updated['version'] == original['version'] + 1 + assert updated['title'] == update_media['title'] + assert updated['content'] == update_media['content'] + doc_versions = db_instance.get_all_document_versions(media_id) + assert len(doc_versions) == 2 + + @given(media_data=st_media_data()) + def test_soft_delete_makes_item_unfindable_by_default(self, db_instance: MediaDatabase, media_data: dict): + unique_word = f"hypothesis_{uuid.uuid4().hex}" + media_data["content"] = f"{media_data['content']} {unique_word}" + media_id, _, _ = db_instance.add_media_with_keywords(**media_data) + original = db_instance.get_media_by_id(media_id) + assert original is not None + db_instance.soft_delete_media(media_id) + assert db_instance.get_media_by_id(media_id) is None + results, total = db_instance.search_media_db(search_query=unique_word) + assert total == 0 + raw_record = db_instance.get_media_by_id(media_id, include_deleted=True) + assert raw_record is not None + assert raw_record['deleted'] == 1 + assert raw_record['version'] == original['version'] + 1 + + +class TestSearchProperties: + @given(media_data=st_media_data()) + def test_search_finds_item_by_its_properties(self, db_instance: MediaDatabase, media_data: dict): + unique_word = f"hypothesis_{uuid.uuid4().hex}" + media_data["content"] = f"{media_data['content']} {unique_word}" + media_id, _, _ = db_instance.add_media_with_keywords(**media_data) + results, total = db_instance.search_media_db(search_query=unique_word, search_fields=['content']) + assert total == 1 + assert results[0]['id'] == media_id + keyword_to_find = media_data["keywords"][0] + results, total = db_instance.search_media_db(search_query=None, must_have_keywords=[keyword_to_find], + media_ids_filter=[media_id]) + assert total == 1 + assert results[0]['id'] == media_id + results, total = db_instance.search_media_db(search_query=None, media_types=[media_data['media_type']], + media_ids_filter=[media_id]) + assert total == 1 + assert results[0]['id'] == media_id + + @given(item1=st_media_data(), item2=st_media_data()) + def test_search_isolates_results_correctly(self, db_instance: MediaDatabase, item1: dict, item2: dict): + item1_kws = set(kw.lower() for kw in item1['keywords']) + item2_kws = set(kw.lower() for kw in item2['keywords']) + assume(item1_kws.isdisjoint(item2_kws)) + item1["content"] += f" item1_{uuid.uuid4().hex}" + item2["content"] += f" item2_{uuid.uuid4().hex}" + id1, _, _ = db_instance.add_media_with_keywords(**item1) + id2, _, _ = db_instance.add_media_with_keywords(**item2) + keyword_to_find = item1['keywords'][0] + results, total = db_instance.search_media_db(search_query=None, must_have_keywords=[keyword_to_find], + media_ids_filter=[id1, id2]) + assert total == 1 + assert results[0]['id'] == id1 + + 
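+    # `assume(item1_kws.isdisjoint(item2_kws))` above asks Hypothesis to discard, rather than
+    # fail, generated pairs whose keyword sets overlap, so the isolation property is only
+    # asserted on inputs where it can meaningfully hold.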
@given(media_data=st_media_data()) + def test_soft_deleted_item_is_not_in_fts_search(self, db_instance: MediaDatabase, media_data: dict): + unique_term = f"fts_{uuid.uuid4().hex}" + media_data['title'] = f"{media_data['title']} {unique_term}" + media_data['content'] += f" {uuid.uuid4().hex}" + media_id, _, _ = db_instance.add_media_with_keywords(**media_data) + results, total = db_instance.search_media_db(search_query=unique_term) + assert total == 1 + was_deleted = db_instance.soft_delete_media(media_id) + assert was_deleted is True + results, total = db_instance.search_media_db(search_query=unique_term) + assert total == 0 + + +class TestIdempotencyAndConstraints: + """Tests for idempotency of operations and enforcement of DB constraints.""" + + @settings(deadline=None) + @given(media_data=st_media_data()) + def test_mark_as_trash_is_idempotent(self, db_instance: MediaDatabase, media_data: dict): + """ + Property: Marking an item as trash multiple times has the same effect as + marking it once. The version should only increment on the first call. + """ + media_data["content"] += f" {uuid.uuid4().hex}" + media_id, _, _ = db_instance.add_media_with_keywords(**media_data) + + assert db_instance.mark_as_trash(media_id) is True + item_v2 = db_instance.get_media_by_id(media_id, include_trash=True) + assert item_v2['version'] == 2 + + assert db_instance.mark_as_trash(media_id) is False + item_still_v2 = db_instance.get_media_by_id(media_id, include_trash=True) + assert item_still_v2['version'] == 2 + + @given( + media1=st_media_data(), + media2=st_media_data(), + url_part1=st.uuids().map(str), + url_part2=st.uuids().map(str), + ) + def test_add_media_with_conflicting_hash_is_handled(self, + db_instance: MediaDatabase, + media1: dict, + media2: dict, + url_part1: str, + url_part2: str): + # Ensure URLs will be different, a highly unlikely edge case otherwise + assume(url_part1 != url_part2) + # Ensure titles are different to test a metadata-only update. + assume(media1['title'] != media2['title']) + + # Make content identical to trigger a content hash conflict. + media2['content'] = media1['content'] + + # Use the deterministic UUIDs from Hypothesis to build the URLs. + media1['url'] = f"http://example.com/{url_part1}" + media2['url'] = f"http://example.com/{url_part2}" + + id1, _, _ = db_instance.add_media_with_keywords(**media1) + + # 1. Test with overwrite=False. Should fail due to conflict. + id2, _, msg2 = db_instance.add_media_with_keywords(**media2, overwrite=False) + assert id2 is None + assert "already exists. Overwrite not enabled." in msg2 + + # 2. Test with overwrite=True. Should update the existing item's metadata. + id3, _, msg3 = db_instance.add_media_with_keywords(**media2, overwrite=True) + assert id3 == id1 + assert "updated" in msg3 + + # 3. Verify the metadata was actually updated in the database. + final_item = db_instance.get_media_by_id(id1) + assert final_item is not None + assert final_item['title'] == media2['title'] + + +class TestTimeBasedAndSearchQueries: + # ... other tests in this class are correct ... + + @given(days=st.integers(min_value=1, max_value=365)) + def test_empty_trash_respects_time_threshold(self, db_instance: MediaDatabase, days: int): + """ + Property: `empty_trash` should only soft-delete items whose `trash_date` + is older than the specified threshold. 
+ """ + media_id, _, _ = db_instance.add_media_with_keywords( + title="Trash Test", content=f"...{uuid.uuid4().hex}", media_type="article", keywords=["test"]) + + # This call handles versioning correctly, bumping version to 2 + db_instance.mark_as_trash(media_id) + item_v2 = db_instance.get_media_by_id(media_id, include_trash=True) + + past_date = datetime.now(timezone.utc) - timedelta(days=days + 1) + + # FIX: The manual update MUST comply with the database triggers. + # This means we have to increment the version and supply a client_id. + # This makes the test setup robust. + with db_instance.transaction(): + db_instance.execute_query( + "UPDATE Media SET trash_date = ?, version = ?, client_id = ?, last_modified = ? WHERE id = ?", + ( + past_date.strftime('%Y-%m-%dT%H:%M:%S.%f')[:-3] + 'Z', + item_v2['version'] + 1, # Manually increment version for this setup step + 'test_setup_client', + db_instance._get_current_utc_timestamp_str(), + media_id + ) + ) + + # Now the item is at version 3. + # `empty_trash` will find this item and call `soft_delete_media`, + # which will correctly read version 3 and update to version 4. + processed_count, _ = empty_trash(db_instance=db_instance, days_threshold=days) + assert processed_count == 1 + + final_item = db_instance.get_media_by_id(media_id, include_trash=True, include_deleted=True) + assert final_item['deleted'] == 1 + assert final_item['version'] == 4 # Initial: 1, Trash: 2, Manual Date Change: 3, Delete: 4 + + +# +# End of test_media_db_properties.py +####################################################################################################################### diff --git a/Tests/Media_DB/test_media_db_v2.py b/Tests/Media_DB/test_media_db_v2.py new file mode 100644 index 00000000..a5fc26ae --- /dev/null +++ b/Tests/Media_DB/test_media_db_v2.py @@ -0,0 +1,460 @@ +# tests/test_media_db_v2.py +# Description: Unit tests for SQLite database operations, including CRUD, transactions, and sync log management. +# This version is self-contained and does not require a conftest.py file. +# +# Standard Library Imports: +import json +import os +import pytest +import shutil +import sys +import time +import sqlite3 +from datetime import datetime, timezone, timedelta +from pathlib import Path +# +# --- Path Setup (Replaces conftest.py logic) --- +# Add the project root to the Python path to allow importing the library. +# This assumes the tests are in a 'tests' directory at the project root. +try: + project_root = Path(__file__).resolve().parent.parent + sys.path.insert(0, str(project_root)) + # If your source code is in a 'src' directory, you might need: + # sys.path.insert(0, str(project_root / "src")) +except (NameError, IndexError): + # Fallback for environments where __file__ is not defined + pass +# +# Local imports (from the main project) +from tldw_chatbook.DB.Client_Media_DB_v2 import MediaDatabase as Database, ConflictError\ +# +####################################################################################################################### +# +# Helper Functions (for use in tests) +# + +def get_log_count(db: Database, entity_uuid: str) -> int: + """Helper to get sync log entries for assertions.""" + cursor = db.execute_query("SELECT COUNT(*) FROM sync_log WHERE entity_uuid = ?", (entity_uuid,)) + return cursor.fetchone()[0] + + +def get_latest_log(db: Database, entity_uuid: str) -> dict | None: + """Helper to get the most recent sync log for an entity.""" + cursor = db.execute_query( + "SELECT * FROM sync_log WHERE entity_uuid = ? 
ORDER BY change_id DESC LIMIT 1", + (entity_uuid,) + ) + row = cursor.fetchone() + return dict(row) if row else None + + +def get_entity_version(db: Database, entity_table: str, uuid: str) -> int | None: + """Helper to get the current version of an entity.""" + cursor = db.execute_query(f"SELECT version FROM {entity_table} WHERE uuid = ?", (uuid,)) + row = cursor.fetchone() + return row['version'] if row else None + + +####################################################################################################################### +# +# Pytest Fixtures (Moved from conftest.py) +# + +@pytest.fixture(scope="function") +def memory_db_factory(): + """Factory fixture to create in-memory Database instances with automatic connection closing.""" + created_dbs = [] + + def _create_db(client_id="test_client"): + db = Database(db_path=":memory:", client_id=client_id) + created_dbs.append(db) + return db + + yield _create_db + # Teardown: close connections for all created in-memory DBs + for db in created_dbs: + try: + db.close_connection() + except Exception: # Ignore errors during cleanup + pass + + +@pytest.fixture(scope="function") +def temp_db_path(tmp_path: Path) -> str: + """Creates a temporary directory and returns a unique DB path string within it.""" + # The built-in tmp_path fixture handles directory creation and cleanup. + return str(tmp_path / "test_db.sqlite") + + +@pytest.fixture(scope="function") +def file_db(temp_db_path: str): + """Creates a file-based Database instance using a temporary path with automatic connection closing.""" + db = Database(db_path=temp_db_path, client_id="file_client") + yield db + db.close_connection() + + +@pytest.fixture(scope="function") +def db_instance(memory_db_factory): + """Provides a fresh, isolated in-memory DB for a single test.""" + return memory_db_factory("crud_client") + + +@pytest.fixture(scope="class") +def search_db(tmp_path_factory): + """Sets up a single DB with predictable data for all search tests in a class.""" + db_path = tmp_path_factory.mktemp("search_tests") / "search.db" + db = Database(db_path, "search_client") + + # Add a predictable set of media items + db.add_media_with_keywords( + title="Alpha One", content="Content about Python and programming.", media_type="article", + keywords=["python", "programming"], ingestion_date="2023-01-15T12:00:00Z" + ) # ID 1 + db.add_media_with_keywords( + title="Beta Two", content="A video about data science.", media_type="video", + keywords=["python", "data science"], ingestion_date="2023-02-20T12:00:00Z" + ) # ID 2 + db.add_media_with_keywords( + title="Gamma Three (TRASH)", content="Old news.", media_type="article", + keywords=["news"], ingestion_date="2023-03-10T12:00:00Z" + ) # ID 3 + db.mark_as_trash(3) + + yield db + db.close_connection() + + +####################################################################################################################### +# +# Test Classes +# + +class TestDatabaseInitialization: + def test_memory_db_creation(self, memory_db_factory): + """Test creating an in-memory database.""" + db = memory_db_factory("client_mem") + assert db.is_memory_db + assert db.client_id == "client_mem" + cursor = db.execute_query("SELECT name FROM sqlite_master WHERE type='table' AND name='Media'") + assert cursor.fetchone() is not None + db.close_connection() + + def test_file_db_creation(self, file_db, temp_db_path): + """Test creating a file-based database.""" + assert not file_db.is_memory_db + assert file_db.client_id == "file_client" + assert 
os.path.exists(temp_db_path) + cursor = file_db.execute_query("SELECT name FROM sqlite_master WHERE type='table' AND name='Media'") + assert cursor.fetchone() is not None + # file_db fixture handles closure + + def test_missing_client_id(self): + """Test that ValueError is raised if client_id is missing.""" + with pytest.raises(ValueError, match="Client ID cannot be empty"): + Database(db_path=":memory:", client_id="") + with pytest.raises(ValueError, match="Client ID cannot be empty"): + Database(db_path=":memory:", client_id=None) + + +class TestDatabaseTransactions: + def test_transaction_commit(self, memory_db_factory): + db = memory_db_factory() + keyword = "commit_test" + with db.transaction(): + db.add_keyword(keyword) + cursor = db.execute_query("SELECT keyword FROM Keywords WHERE keyword = ?", (keyword,)) + assert cursor.fetchone()['keyword'] == keyword + + def test_transaction_rollback(self, memory_db_factory): + db = memory_db_factory() + keyword = "rollback_test" + initial_count_cursor = db.execute_query("SELECT COUNT(*) FROM Keywords") + initial_count = initial_count_cursor.fetchone()[0] + try: + with db.transaction(): + new_uuid = db._generate_uuid() + db.execute_query( + "INSERT INTO Keywords (keyword, uuid, last_modified, version, client_id, deleted) VALUES (?, ?, ?, 1, ?, 0)", + (keyword, new_uuid, db._get_current_utc_timestamp_str(), db.client_id), + commit=False + ) + cursor_inside = db.execute_query("SELECT COUNT(*) FROM Keywords") + assert cursor_inside.fetchone()[0] == initial_count + 1 + raise ValueError("Simulating error to trigger rollback") + except ValueError: + pass # Expected error + except Exception as e: + pytest.fail(f"Unexpected exception during rollback test: {e}") + + final_count_cursor = db.execute_query("SELECT COUNT(*) FROM Keywords") + assert final_count_cursor.fetchone()[0] == initial_count + + +class TestSearchFunctionality: + # The 'search_db' fixture is now defined at the module level + # and provides a shared database for all tests in this class. + + def test_some_search_function(self, search_db): + """A placeholder test demonstrating usage of the search_db fixture.""" + # Example: Search for items with the keyword "python" + # results = search_db.search(keywords=["python"]) + # assert len(results) == 2 + pass # Add actual search tests here + + +class TestDatabaseCRUDAndSync: + # The 'db_instance' fixture is now defined at the module level + # and provides a fresh in-memory DB for each test in this class. 
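+    # Each test in this class therefore starts from an empty schema: the version numbers,
+    # sync_log change_ids and keyword counts asserted below all assume a fresh database.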
+ + def test_add_keyword(self, db_instance): + keyword = " test keyword " + expected_keyword = "test keyword" + kw_id, kw_uuid = db_instance.add_keyword(keyword) + + assert kw_id is not None + assert kw_uuid is not None + + cursor = db_instance.execute_query("SELECT * FROM Keywords WHERE id = ?", (kw_id,)) + row = cursor.fetchone() + assert row['keyword'] == expected_keyword + assert row['uuid'] == kw_uuid + + log_entry = get_latest_log(db_instance, kw_uuid) + assert log_entry['operation'] == 'create' + assert log_entry['entity'] == 'Keywords' + + def test_add_existing_keyword(self, db_instance): + keyword = "existing" + kw_id1, kw_uuid1 = db_instance.add_keyword(keyword) + log_count1 = get_log_count(db_instance, kw_uuid1) + kw_id2, kw_uuid2 = db_instance.add_keyword(keyword) + log_count2 = get_log_count(db_instance, kw_uuid1) + + assert kw_id1 == kw_id2 + assert kw_uuid1 == kw_uuid2 + assert log_count1 == log_count2 + + def test_soft_delete_keyword(self, db_instance): + keyword = "to_delete" + kw_id, kw_uuid = db_instance.add_keyword(keyword) + initial_version = get_entity_version(db_instance, "Keywords", kw_uuid) + + assert db_instance.soft_delete_keyword(keyword) is True + + cursor = db_instance.execute_query("SELECT deleted, version FROM Keywords WHERE id = ?", (kw_id,)) + row = cursor.fetchone() + assert row['deleted'] == 1 + assert row['version'] == initial_version + 1 + + log_entry = get_latest_log(db_instance, kw_uuid) + assert log_entry['operation'] == 'delete' + assert log_entry['version'] == initial_version + 1 + + def test_undelete_keyword(self, db_instance): + keyword = "to_undelete" + kw_id, kw_uuid = db_instance.add_keyword(keyword) + db_instance.soft_delete_keyword(keyword) + deleted_version = get_entity_version(db_instance, "Keywords", kw_uuid) + + undelete_id, undelete_uuid = db_instance.add_keyword(keyword) + + assert undelete_id == kw_id + cursor = db_instance.execute_query("SELECT deleted, version FROM Keywords WHERE id = ?", (kw_id,)) + row = cursor.fetchone() + assert row['deleted'] == 0 + assert row['version'] == deleted_version + 1 + + log_entry = get_latest_log(db_instance, kw_uuid) + assert log_entry['operation'] == 'update' + assert log_entry['version'] == deleted_version + 1 + + def test_add_media_with_keywords_create(self, db_instance): + title = "Test Media Create" + content = "Some unique content for create." + keywords = ["create_kw1", "create_kw2"] + + media_id, media_uuid, msg = db_instance.add_media_with_keywords( + title=title, media_type="article", content=content, keywords=keywords + ) + assert media_id is not None + assert f"Media '{title}' added." in msg + + cursor = db_instance.execute_query("SELECT uuid, version FROM Media WHERE id = ?", (media_id,)) + media_row = cursor.fetchone() + assert media_row['uuid'] == media_uuid + assert media_row['version'] == 1 + + log_entry = get_latest_log(db_instance, media_uuid) + assert log_entry['operation'] == 'create' + + def test_add_media_with_keywords_update(self, db_instance): + """Test updating a media item with new content, title, and keywords.""" + # Initial media setup + title1 = "Test Media Original" + title2 = "Test Media Updated" + content1 = "Initial content." + content2 = "Updated content." 
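+        # The two keyword sets defined next are disjoint: the overwrite below is expected to
+        # replace keywords1 entirely with keywords2, which the assertions at the end verify.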
+ keywords1 = ["update_kw1"] + keywords2 = ["update_kw2", "update_kw3"] + media_type = "article" # Use consistent media_type across tests + + # Create initial media item + media_id, media_uuid, msg1 = db_instance.add_media_with_keywords( + title=title1, media_type=media_type, content=content1, keywords=keywords1 + ) + assert "added" in msg1.lower(), f"Expected 'added' in message, got: {msg1}" + initial_version = get_entity_version(db_instance, "Media", media_uuid) + assert initial_version == 1, f"Expected initial version 1, got {initial_version}" + + # Fetch the created media to get its URL (stable identifier) + created_media = db_instance.get_media_by_id(media_id) + assert created_media is not None, "Failed to retrieve created media item" + url_to_update = created_media['url'] + + # Update the media item + updated_id, updated_uuid, msg2 = db_instance.add_media_with_keywords( + title=title2, + media_type=media_type, + content=content2, + keywords=keywords2, + overwrite=True, + url=url_to_update + ) + + # Verify update operation returned correct values + assert updated_id == media_id, "Update returned different media ID" + assert updated_uuid == media_uuid, "Update returned different UUID" + assert "updated" in msg2.lower(), f"Expected 'updated' in message, got: {msg2}" + + # Verify content was updated + cursor = db_instance.execute_query("SELECT content, title, version FROM Media WHERE id = ?", (media_id,)) + media_row = cursor.fetchone() + assert media_row['content'] == content2, "Content was not updated" + assert media_row['title'] == title2, "Title was not updated" + assert media_row['version'] == initial_version + 1, f"Version not incremented, expected {initial_version + 1}, got {media_row['version']}" + + # Verify keywords were updated + cursor = db_instance.execute_query( + """ + SELECT k.keyword FROM MediaKeywords mk + JOIN Keywords k ON mk.keyword_id = k.id + WHERE mk.media_id = ? 
+ """, + (media_id,) + ) + linked_keywords = [row['keyword'] for row in cursor.fetchall()] + assert set(kw.lower() for kw in linked_keywords) == set(kw.lower() for kw in keywords2), "Keywords were not updated correctly" + + # Verify sync log was created for the update + log_entry = get_latest_log(db_instance, media_uuid) + assert log_entry['operation'] == 'update', f"Expected 'update' operation, got {log_entry['operation']}" + assert log_entry['version'] == initial_version + 1, f"Log version mismatch: {log_entry['version']} vs {initial_version + 1}" + assert log_entry['entity'] == 'Media', f"Expected 'Media' entity, got {log_entry['entity']}" + + def test_soft_delete_media_cascade(self, db_instance): + media_id, media_uuid, _ = db_instance.add_media_with_keywords( + title="Cascade Test", content="Cascade content", keywords=["cascade1"], media_type="article" + ) + media_version = get_entity_version(db_instance, "Media", media_uuid) + + assert db_instance.soft_delete_media(media_id, cascade=True) is True + + cursor = db_instance.execute_query("SELECT deleted, version FROM Media WHERE id = ?", (media_id,)) + assert dict(cursor.fetchone()) == {'deleted': 1, 'version': media_version + 1} + + cursor = db_instance.execute_query("SELECT COUNT(*) FROM MediaKeywords WHERE media_id = ?", (media_id,)) + assert cursor.fetchone()[0] == 0 + + media_log = get_latest_log(db_instance, media_uuid) + assert media_log['operation'] == 'delete' + + def test_optimistic_locking_prevents_update_with_stale_version(self, db_instance): + kw_id, kw_uuid = db_instance.add_keyword("conflict_test") + original_version = 1 + + db_instance.execute_query( + "UPDATE Keywords SET version = ?, client_id = ? WHERE id = ?", + (original_version + 1, "external_client", kw_id), commit=True + ) + + cursor = db_instance.execute_query( + "UPDATE Keywords SET keyword='stale_update', version=?, client_id=? WHERE id=? AND version=?", + (original_version + 1, db_instance.client_id, kw_id, original_version), commit=True + ) + assert cursor.rowcount == 0 + + def test_version_validation_trigger(self, db_instance): + kw_id, kw_uuid = db_instance.add_keyword("validation_test") + current_version = get_entity_version(db_instance, "Keywords", kw_uuid) + + with pytest.raises(sqlite3.IntegrityError, match="Version must increment by exactly 1"): + db_instance.execute_query( + "UPDATE Keywords SET version = ? 
WHERE id = ?", + (current_version + 2, kw_id), commit=True + ) + + def test_client_id_validation_trigger(self, db_instance): + kw_id, kw_uuid = db_instance.add_keyword("clientid_test") + current_version = get_entity_version(db_instance, "Keywords", kw_uuid) + + with pytest.raises(sqlite3.IntegrityError, match="Client ID cannot be NULL or empty"): + db_instance.execute_query( + "UPDATE Keywords SET version = ?, client_id = NULL WHERE id = ?", + (current_version + 1, kw_id), commit=True + ) + + +class TestSyncLogManagement: + @pytest.fixture(autouse=True) + def setup_db(self, db_instance): + """Use autouse to provide the db_instance to every test in this class.""" + # Add some initial data to generate logs + db_instance.add_keyword("log_kw_1") + time.sleep(0.01) + db_instance.add_keyword("log_kw_2") + time.sleep(0.01) + db_instance.add_keyword("log_kw_3") + db_instance.soft_delete_keyword("log_kw_2") + self.db = db_instance + + def test_get_sync_log_entries_all(self): + logs = self.db.get_sync_log_entries() + assert len(logs) == 4 + assert logs[0]['change_id'] == 1 + + def test_get_sync_log_entries_since(self): + logs = self.db.get_sync_log_entries(since_change_id=2) + assert len(logs) == 2 + assert logs[0]['change_id'] == 3 + + def test_get_sync_log_entries_limit(self): + logs = self.db.get_sync_log_entries(limit=2) + assert len(logs) == 2 + assert logs[0]['change_id'] == 1 + assert logs[1]['change_id'] == 2 + + def test_delete_sync_log_entries_specific(self): + initial_logs = self.db.get_sync_log_entries() + ids_to_delete = [initial_logs[1]['change_id'], initial_logs[2]['change_id']] + deleted_count = self.db.delete_sync_log_entries(ids_to_delete) + assert deleted_count == 2 + remaining_ids = {log['change_id'] for log in self.db.get_sync_log_entries()} + assert remaining_ids == {1, 4} + + def test_delete_sync_log_entries_before(self): + deleted_count = self.db.delete_sync_log_entries_before(3) + assert deleted_count == 3 + remaining_logs = self.db.get_sync_log_entries() + assert len(remaining_logs) == 1 + assert remaining_logs[0]['change_id'] == 4 + + def test_delete_sync_log_entries_invalid_id(self): + with pytest.raises(ValueError): + self.db.delete_sync_log_entries([1, "two", 3]) + +# +# End of test_media_db_v2.py +######################################################################################################################## + diff --git a/Tests/DB/test_sync_client.py b/Tests/Media_DB/test_sync_client.py similarity index 99% rename from Tests/DB/test_sync_client.py rename to Tests/Media_DB/test_sync_client.py index 4a18a7e5..b0bf5c84 100644 --- a/Tests/DB/test_sync_client.py +++ b/Tests/Media_DB/test_sync_client.py @@ -12,8 +12,8 @@ import requests # # Local Imports -from .test_sqlite_db import get_entity_version -from tldw_cli.tldw_app.DB.Sync_Client import ClientSyncEngine +from .test_media_db_v2 import get_entity_version +from tldw_chatbook.DB.Sync_Client import ClientSyncEngine # ####################################################################################################################### # diff --git a/Tests/Prompts_DB/__init__.py b/Tests/Prompts_DB/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/Tests/Prompts_DB/tests_prompts_db.py b/Tests/Prompts_DB/tests_prompts_db.py new file mode 100644 index 00000000..d69f981d --- /dev/null +++ b/Tests/Prompts_DB/tests_prompts_db.py @@ -0,0 +1,1102 @@ +# test_Prompts_DB.py +import sqlite3 +import unittest +import uuid +import os +import shutil +import tempfile +import threading +from pathlib import 
Path + +# The module to be tested +from tldw_chatbook.DB.Prompts_DB import ( + PromptsDatabase, + DatabaseError, + SchemaError, + InputError, + ConflictError, + add_or_update_prompt, + load_prompt_details_for_ui, + export_prompt_keywords_to_csv, + view_prompt_keywords_markdown, + export_prompts_formatted, +) + + +# --- Test Case Base --- +class BaseTestCase(unittest.TestCase): + """Base class for tests, handles temporary DB setup and teardown.""" + + def setUp(self): + """Set up a new in-memory database for each test.""" + self.client_id = "test_client_1" + self.db = PromptsDatabase(':memory:', client_id=self.client_id) + # For tests requiring a file-based DB + self.temp_dir = tempfile.mkdtemp() + self.db_path = Path(self.temp_dir) / "test_prompts.db" + # ---- Add a list to track all created file DBs ---- + self.file_db_instances = [] + + def tearDown(self): + """Close connection and clean up resources.""" + if hasattr(self, 'db') and self.db: + self.db.close_connection() + # ---- Close all tracked file DB instances ---- + for instance in self.file_db_instances: + instance.close_connection() + # Now it's safe to remove the directory + shutil.rmtree(self.temp_dir) + + def _get_file_db(self): + """Helper to get a file-based database instance, ensuring it's tracked for cleanup.""" + # ---- Track the new instance ---- + instance = PromptsDatabase(self.db_path, client_id=f"test_client_file_{len(self.file_db_instances)}") + self.file_db_instances.append(instance) + return instance + + def _add_sample_data(self, db_instance): + """Helper to populate a given database instance with some initial data.""" + # Note: All arguments are provided to match the fixed library code + db_instance.add_prompt( + name="Recipe Generator", + author="ChefAI", + details="Creates recipes for cooking", + system_prompt="You are a chef.", + user_prompt="Give me a recipe for pasta.", + keywords=["food", "cooking"], + overwrite=True + ) + db_instance.add_prompt( + name="Code Explainer", + author="DevHelper", + details="Explains python code snippets", + system_prompt="You are a senior dev.", + user_prompt="Explain this python code.", + keywords=["code", "python"], + overwrite=True + ) + db_instance.add_prompt( + name="Poem Writer", + author="BardBot", + details="Writes poems about nature", + system_prompt="You are a poet.", + user_prompt="Write a poem about the sea.", + keywords=["writing", "poetry"], + overwrite=True + ) + # Add a deleted prompt for testing filters + pid, _, _ = db_instance.add_prompt( + name="Old Prompt", + author="Old", + details="Old details", + overwrite=True + ) + db_instance.soft_delete_prompt(pid) + + +# --- Test Suites --- +class TestDatabaseInitialization(BaseTestCase): + + def test_init_success_in_memory(self): + self.assertIsNotNone(self.db) + self.assertIsInstance(self.db, PromptsDatabase) + self.assertTrue(self.db.is_memory_db) + self.assertEqual(self.db.client_id, self.client_id) + + def test_init_success_file_based(self): + file_db = self._get_file_db() + self.assertTrue(self.db_path.exists()) + self.assertFalse(file_db.is_memory_db) + conn = file_db.get_connection() + # WAL mode is set for file-based dbs + cursor = conn.execute("PRAGMA journal_mode;") + self.assertEqual(cursor.fetchone()[0].lower(), 'wal') + + def test_init_failure_no_client_id(self): + with self.assertRaises(ValueError): + PromptsDatabase(':memory:', client_id=None) + with self.assertRaises(ValueError): + PromptsDatabase(':memory:', client_id="") + + def test_schema_version_check(self): + conn = self.db.get_connection() + 
version = conn.execute("SELECT version FROM schema_version").fetchone()['version'] + self.assertEqual(version, self.db._CURRENT_SCHEMA_VERSION) + + def test_fts_tables_created(self): + conn = self.db.get_connection() + try: + conn.execute("SELECT * FROM prompts_fts LIMIT 1") + conn.execute("SELECT * FROM prompt_keywords_fts LIMIT 1") + except Exception as e: + self.fail(f"FTS tables not created or queryable: {e}") + + def test_thread_safety_connections(self): + """Verify that different threads get different connection objects.""" + connections = {} + db_instance = self._get_file_db() + + def get_conn(thread_id): + conn = db_instance.get_connection() + connections[thread_id] = id(conn) + db_instance.close_connection() + + thread1 = threading.Thread(target=get_conn, args=(1,)) + thread2 = threading.Thread(target=get_conn, args=(2,)) + + thread1.start() + thread2.start() + thread1.join() + thread2.join() + + self.assertIn(1, connections) + self.assertIn(2, connections) + self.assertNotEqual(connections[1], connections[2], + "Connections for different threads should be different objects") + + +class TestCrudOperations(BaseTestCase): + + def test_add_keyword(self): + kw_id, kw_uuid = self.db.add_keyword(" Test Keyword 1 ") + self.assertIsNotNone(kw_id) + self.assertIsNotNone(kw_uuid) + + # Verify it was added correctly and normalized + kw_data = self.db.get_active_keyword_by_text("test keyword 1") + self.assertIsNotNone(kw_data) + self.assertEqual(kw_data['keyword'], "test keyword 1") + self.assertEqual(kw_data['id'], kw_id) + + # Verify sync log + sync_logs = self.db.get_sync_log_entries() + self.assertEqual(len(sync_logs), 1) + log_entry = sync_logs[0] + self.assertEqual(log_entry['entity'], 'PromptKeywordsTable') + self.assertEqual(log_entry['entity_uuid'], kw_uuid) + self.assertEqual(log_entry['operation'], 'create') + self.assertEqual(log_entry['version'], 1) + + # Verify FTS + res = self.db.execute_query( + "SELECT rowid FROM prompt_keywords_fts WHERE prompt_keywords_fts MATCH ?", + ("test",) + ).fetchone() + self.assertIsNotNone(res) + self.assertEqual(res['rowid'], kw_id) + + def test_add_existing_keyword(self): + kw_id1, kw_uuid1 = self.db.add_keyword("duplicate") + kw_id2, kw_uuid2 = self.db.add_keyword("duplicate") + self.assertEqual(kw_id1, kw_id2) + self.assertEqual(kw_uuid1, kw_uuid2) + sync_logs = self.db.get_sync_log_entries() + self.assertEqual(len(sync_logs), 1) # Should only log the creation once + + def test_add_prompt(self): + pid, puuid, msg = self.db.add_prompt("My Prompt", "Me", "Details here", keywords=["tag1", "tag2"]) + self.assertIsNotNone(pid) + self.assertIsNotNone(puuid) + self.assertIn("added", msg) + + prompt = self.db.get_prompt_by_id(pid) + self.assertEqual(prompt['name'], "My Prompt") + self.assertEqual(prompt['version'], 1) + + keywords = self.db.fetch_keywords_for_prompt(pid) + self.assertIn("tag1", keywords) + self.assertIn("tag2", keywords) + + # Check sync logs - 1 for prompt, 2 for keywords, 2 for links + sync_logs = self.db.get_sync_log_entries() + prompt_create_logs = [l for l in sync_logs if l['entity'] == 'Prompts' and l['operation'] == 'create'] + self.assertEqual(len(prompt_create_logs), 1) + prompt_create_log = prompt_create_logs[0] + self.assertEqual(prompt_create_log['entity_uuid'], puuid) + + # Verify FTS + res = self.db.execute_query( + "SELECT rowid FROM prompts_fts WHERE prompts_fts MATCH ?", + ("My Prompt",) + ).fetchone() + self.assertIsNotNone(res) + self.assertEqual(res['rowid'], pid) + + def test_add_prompt_conflict(self): + 
self.db.add_prompt("Conflict Prompt", "Author", "Details") + with self.assertRaises(ConflictError): + self.db.add_prompt("Conflict Prompt", "Author", "Details", overwrite=False) + + def test_add_prompt_overwrite(self): + pid1, _, _ = self.db.add_prompt("Overwrite Me", "Author1", "Details1") + pid2, _, msg = self.db.add_prompt("Overwrite Me", "Author2", "Details2", overwrite=True) + self.assertEqual(pid1, pid2) + self.assertIn("updated", msg) + + prompt = self.db.get_prompt_by_id(pid1) + self.assertEqual(prompt['author'], "Author2") + self.assertEqual(prompt['version'], 2) + + def test_soft_delete_prompt(self): + pid, puuid, _ = self.db.add_prompt("To Be Deleted", "Author", "Details", keywords=["temp"]) + + self.assertIsNotNone(self.db.get_prompt_by_id(pid)) + + success = self.db.soft_delete_prompt(pid) + self.assertTrue(success) + + self.assertIsNone(self.db.get_prompt_by_id(pid)) + + deleted_prompt = self.db.get_prompt_by_id(pid, include_deleted=True) + self.assertIsNotNone(deleted_prompt) + self.assertEqual(deleted_prompt['deleted'], 1) + self.assertEqual(deleted_prompt['version'], 2) + + res = self.db.execute_query( + "SELECT rowid FROM prompts_fts WHERE prompts_fts MATCH ?", + ("To Be Deleted",) + ).fetchone() + self.assertIsNone(res) + + link_exists = self.db.execute_query("SELECT 1 FROM PromptKeywordLinks WHERE prompt_id=?", (pid,)).fetchone() + self.assertIsNone(link_exists) + + sync_logs = self.db.get_sync_log_entries() # Check all logs + delete_log = next(l for l in sync_logs if l['entity'] == 'Prompts' and l['operation'] == 'delete') + unlink_log = next(l for l in sync_logs if l['entity'] == 'PromptKeywordLinks' and l['operation'] == 'unlink') + self.assertEqual(delete_log['entity_uuid'], puuid) + self.assertIn(puuid, unlink_log['entity_uuid']) + + def test_soft_delete_keyword(self): + kw_id, kw_uuid = self.db.add_keyword("ephemeral") + self.db.add_prompt("Test Prompt", "Author", "Some details", keywords=["ephemeral"]) + + success = self.db.soft_delete_keyword("ephemeral") + self.assertTrue(success) + + self.assertIsNone(self.db.get_active_keyword_by_text("ephemeral")) + + res = self.db.execute_query( + "SELECT rowid FROM prompt_keywords_fts WHERE prompt_keywords_fts MATCH ?", + ("ephemeral",) + ).fetchone() + self.assertIsNone(res) + + prompt = self.db.get_prompt_by_name("Test Prompt") + keywords = self.db.fetch_keywords_for_prompt(prompt['id']) + self.assertNotIn("ephemeral", keywords) + + def test_update_prompt_by_id(self): + pid, puuid, _ = self.db.add_prompt("Initial Name", "Author", "Details", keywords=["old_kw"]) + + update_data = {"name": "Updated Name", "details": "New details", "keywords": ["new_kw", "another_kw"]} + + updated_uuid, msg = self.db.update_prompt_by_id(pid, update_data) + self.assertEqual(updated_uuid, puuid) + self.assertIn("updated successfully", msg) + + prompt = self.db.get_prompt_by_id(pid) + self.assertEqual(prompt['name'], "Updated Name") + self.assertEqual(prompt['details'], "New details") + self.assertEqual(prompt['version'], 2) + + keywords = self.db.fetch_keywords_for_prompt(pid) + self.assertIn("new_kw", keywords) + self.assertIn("another_kw", keywords) + self.assertNotIn("old_kw", keywords) + + res = self.db.execute_query( + "SELECT rowid FROM prompts_fts WHERE prompts_fts MATCH ?", + ("Updated Name",) + ).fetchone() + self.assertIsNotNone(res) + self.assertEqual(res['rowid'], pid) + + +class TestQueryOperations(BaseTestCase): + + def setUp(self): + super().setUp() + self._add_sample_data(self.db) + + def test_list_prompts(self): + prompts, 
total_pages, page, total_items = self.db.list_prompts(page=1, per_page=2) + self.assertEqual(len(prompts), 2) + self.assertEqual(total_items, 3) # 3 active prompts + self.assertEqual(total_pages, 2) + self.assertEqual(page, 1) + + def test_list_prompts_include_deleted(self): + _, _, _, total_items = self.db.list_prompts(include_deleted=True) + self.assertEqual(total_items, 4) + + def test_fetch_prompt_details(self): + details = self.db.fetch_prompt_details("Recipe Generator") + self.assertIsNotNone(details) + self.assertEqual(details['name'], "Recipe Generator") + self.assertIn("food", details['keywords']) + self.assertIn("cooking", details['keywords']) + + def test_fetch_all_keywords(self): + keywords = self.db.fetch_all_keywords() + expected = sorted(["food", "cooking", "code", "python", "writing", "poetry"]) + self.assertEqual(sorted(keywords), expected) + + def test_search_prompts_by_name(self): + results, total = self.db.search_prompts("Recipe") + self.assertEqual(total, 1) + self.assertEqual(results[0]['name'], "Recipe Generator") + + def test_search_prompts_by_details(self): + results, total = self.db.search_prompts("python code", search_fields=['details', 'user_prompt']) + self.assertEqual(total, 1) + self.assertEqual(results[0]['name'], "Code Explainer") + + def test_search_prompts_by_keyword(self): + results, total = self.db.search_prompts("poetry", search_fields=['keywords']) + self.assertEqual(total, 1) + self.assertEqual(results[0]['name'], "Poem Writer") + + +class TestUtilitiesAndAdvancedFeatures(BaseTestCase): + + def test_backup_database(self): + file_db = self._get_file_db() + file_db.add_prompt("Backup Test", "Tester", "Details") + + backup_path = self.db_path.with_suffix('.backup.db') + + success = file_db.backup_database(str(backup_path)) + self.assertTrue(success) + self.assertTrue(backup_path.exists()) + + backup_db = PromptsDatabase(backup_path, client_id="backup_verifier") + prompt = backup_db.get_prompt_by_name("Backup Test") + self.assertIsNotNone(prompt) + self.assertEqual(prompt['author'], "Tester") + + backup_db.close_connection() + + def test_backup_database_same_file_fails(self): + file_db = self._get_file_db() + # The function catches this ValueError and returns False + success = file_db.backup_database(str(self.db_path)) + self.assertFalse(success) + + def test_transaction_rollback(self): + pid, _, _ = self.db.add_prompt("Initial", "Auth", "Det") + + try: + with self.db.transaction(): + self.db.execute_query("UPDATE Prompts SET name = ?, version = version + 1, client_id='t' WHERE id = ?", + ("Updated", pid), commit=False) + prompt_inside = self.db.get_prompt_by_id(pid) + self.assertEqual(prompt_inside['name'], "Updated") + raise ValueError("Intentional failure to trigger rollback") + except ValueError: + pass # Expected exception + + prompt_outside = self.db.get_prompt_by_id(pid) + self.assertEqual(prompt_outside['name'], "Initial") + + def test_delete_sync_log_entries(self): + self.db.add_keyword("kw1") + self.db.add_keyword("kw2") + + logs = self.db.get_sync_log_entries() + self.assertEqual(len(logs), 2) + log_ids_to_delete = [log['change_id'] for log in logs] + + deleted_count = self.db.delete_sync_log_entries(log_ids_to_delete) + self.assertEqual(deleted_count, 2) + + remaining_logs = self.db.get_sync_log_entries() + self.assertEqual(len(remaining_logs), 0) + + def test_update_prompt_version_conflict(self): + """ + Tests that the database trigger prevents updates that violate the versioning rule. + This is a direct test of the database integrity layer. 
+ """ + db = self.db + pid, _, _ = db.add_prompt("Trigger Test", "Author", "Details") # Prompt is now at version 1 + + db.update_prompt_by_id(pid, {'details': 'New Details'}) + prompt = db.get_prompt_by_id(pid) + self.assertEqual(prompt['version'], 2, "Version should be 2 after the first update.") + + conn = db.get_connection() + + with self.assertRaises(sqlite3.IntegrityError) as cm: + with conn: + conn.execute("UPDATE Prompts SET version = 2, client_id='raw' WHERE id = ?", (pid,)) + + self.assertIn("Version must increment by exactly 1", str(cm.exception)) + + final_prompt = db.get_prompt_by_id(pid) + self.assertEqual(final_prompt['version'], 2, "Version should remain 2 after the failed update.") + + +class TestStandaloneFunctions(BaseTestCase): + def setUp(self): + super().setUp() + self._add_sample_data(self.db) + + def test_add_or_update_prompt(self): + # Test update + pid, _, msg = add_or_update_prompt(self.db, "Recipe Generator", "New Chef", "New details", keywords=["italian"]) + self.assertIn("updated", msg) + prompt = self.db.get_prompt_by_id(pid) + self.assertEqual(prompt['author'], "New Chef") + keywords = self.db.fetch_keywords_for_prompt(pid) + self.assertIn("italian", keywords) + + # Test add + pid_new, _, msg_new = add_or_update_prompt(self.db, "New Standalone Prompt", "Tester", "Details") + self.assertIn("added", msg_new) + self.assertIsNotNone(self.db.get_prompt_by_id(pid_new)) + + def test_load_prompt_details_for_ui(self): + name, author, details, sys_p, user_p, kws = load_prompt_details_for_ui(self.db, "Code Explainer") + self.assertEqual(name, "Code Explainer") + self.assertEqual(author, "DevHelper") + self.assertEqual(sys_p, "You are a senior dev.") + self.assertEqual(kws, "code, python") + + def test_load_prompt_details_not_found(self): + result = load_prompt_details_for_ui(self.db, "Non Existent") + self.assertEqual(result, ("", "", "", "", "", "")) + + def test_view_prompt_keywords_markdown(self): + md_output = view_prompt_keywords_markdown(self.db) + self.assertIn("### Current Active Prompt Keywords:", md_output) + self.assertIn("- code (1 active prompts)", md_output) + self.assertIn("- cooking (1 active prompts)", md_output) + + def test_export_keywords_to_csv(self): + file_db = self._get_file_db() + add_or_update_prompt(file_db, "Prompt 1", "Auth", "Det", keywords=["a", "b"]) + add_or_update_prompt(file_db, "Prompt 2", "Auth", "Det", keywords=["b", "c"]) + + status, file_path = export_prompt_keywords_to_csv(file_db) + self.assertIn("Successfully exported", status) + self.assertTrue(os.path.exists(file_path)) + + with open(file_path, 'r', encoding='utf-8') as f: + content = f.read() + self.assertIn("Keyword,Associated Prompts", content) + self.assertIn("a,Prompt 1,1", content) + self.assertIn("c,Prompt 2,1", content) + # Handle potential ordering difference in GROUP_CONCAT + self.assertTrue("b,\"Prompt 1,Prompt 2\",2" in content or "b,\"Prompt 2,Prompt 1\",2" in content) + + os.remove(file_path) + + def test_export_prompts_formatted_csv(self): + file_db = self._get_file_db() + self._add_sample_data(file_db) + + status, file_path = export_prompts_formatted(file_db, export_format='csv') + self.assertIn("Successfully exported 3 prompts", status) + self.assertTrue(os.path.exists(file_path)) + + with open(file_path, 'r', encoding='utf-8') as f: + content = f.read() + self.assertIn("Name,UUID,Author,Details,System Prompt,User Prompt,Keywords", content) + self.assertIn("Recipe Generator", content) + # FIX: Keywords are sorted by the fetch method, so 'cooking' comes before 
'food'. + self.assertIn('"cooking, food"', content) + + os.remove(file_path) + + def test_export_prompts_formatted_markdown(self): + file_db = self._get_file_db() + self._add_sample_data(file_db) + + status, file_path = export_prompts_formatted(file_db, export_format='markdown') + self.assertIn("Successfully exported 3 prompts to Markdown", status) + self.assertTrue(os.path.exists(file_path)) + # A more thorough test could unzip and verify content, but file creation is a good start. + os.remove(file_path) + + +class TestDatabaseIntegrityAndSchema(BaseTestCase): + """ + Tests focused on low-level database integrity, schema rules, and triggers. + These tests may involve direct SQL execution to bypass library methods. + """ + + def test_database_version_too_high_raises_error(self): + """Ensure initializing a DB with a future schema version fails gracefully.""" + file_db = self._get_file_db() + conn = file_db.get_connection() + # Manually set the schema version to a higher number + conn.execute("UPDATE schema_version SET version = 99") + conn.commit() + # FIX: We must close the connection so the file handle is released before the next open attempt + file_db.close_connection() + # The instance is now closed, remove it from tracking to avoid double-closing + self.file_db_instances.remove(file_db) + + # Now, trying to create a new instance pointing to this DB should fail + # FIX: The __init__ method wraps SchemaError in a DatabaseError. + with self.assertRaisesRegex(DatabaseError, "newer than supported"): + PromptsDatabase(self.db_path, client_id="test_client_fail") + + def test_trigger_prevents_bad_version_update(self): + """Verify the SQL trigger prevents updates that don't increment version by 1.""" + pid, _, _ = self.db.add_prompt("Trigger Test", "T", "D") # Version is 1 + conn = self.db.get_connection() + + with self.assertRaises(sqlite3.IntegrityError) as cm: + conn.execute("UPDATE Prompts SET version = 3 WHERE id = ?", (pid,)) + self.assertIn("Version must increment by exactly 1", str(cm.exception)) + + # This should succeed + conn.execute("UPDATE Prompts SET version = 2, client_id='raw_sql' WHERE id = ?", (pid,)) + conn.commit() + prompt = self.db.get_prompt_by_id(pid) + self.assertEqual(prompt['version'], 2) + + def test_trigger_prevents_uuid_change(self): + """Verify the SQL trigger prevents changing a UUID on update.""" + pid, original_uuid, _ = self.db.add_prompt("UUID Lock", "T", "D") + conn = self.db.get_connection() + new_uuid = str(uuid.uuid4()) + + with self.assertRaises(sqlite3.IntegrityError) as cm: + # Try to update the UUID (and correctly increment version) + conn.execute("UPDATE Prompts SET uuid = ?, version = 2, client_id='raw' WHERE id = ?", (new_uuid, pid)) + self.assertIn("UUID cannot be changed", str(cm.exception)) + + # Verify UUID is unchanged + prompt = self.db.get_prompt_by_id(pid) + self.assertEqual(prompt['uuid'], original_uuid) + + +class TestAdvancedBehaviorsAndEdgeCases(BaseTestCase): + """ + Tests for more nuanced behaviors and specific edge cases not covered + in standard CRUD operations. 
+ """ + + def test_reopen_closed_connection(self): + """Test that the library can reopen a connection that was explicitly closed.""" + file_db = self._get_file_db() + + conn1 = file_db.get_connection() + self.assertIsNotNone(conn1) + conn1.close() + + conn2 = file_db.get_connection() + self.assertIsNotNone(conn2) + self.assertNotEqual(id(conn1), id(conn2)) + + file_db.add_keyword("test_after_reopen") + self.assertIsNotNone(file_db.get_active_keyword_by_text("test_after_reopen")) + + def test_undelete_keyword(self): + """Test that adding an already soft-deleted keyword undeletes it.""" + self.db.add_keyword("to be deleted and restored") + self.db.soft_delete_keyword("to be deleted and restored") + self.assertIsNone(self.db.get_active_keyword_by_text("to be deleted and restored")) + + # Now, add it again + kw_id, kw_uuid = self.db.add_keyword("to be deleted and restored") + + # Verify it's active again + restored_kw = self.db.get_active_keyword_by_text("to be deleted and restored") + self.assertIsNotNone(restored_kw) + self.assertEqual(restored_kw['id'], kw_id) + self.assertEqual(restored_kw['version'], 3) # 1: create, 2: delete, 3: undelete (update) + + # Check sync log for the 'update' operation + sync_logs = self.db.get_sync_log_entries() + undelete_log = next(log for log in sync_logs if log['version'] == 3) + self.assertEqual(undelete_log['operation'], 'update') + self.assertEqual(undelete_log['payload']['deleted'], 0) + + def test_update_prompt_name_to_existing_name_conflict(self): + """Ensure updating a prompt's name to another existing prompt's name fails.""" + self.db.add_prompt("Prompt A", "Author", "Details") + pid_b, _, _ = self.db.add_prompt("Prompt B", "Author", "Details") + + with self.assertRaises(ConflictError): + self.db.update_prompt_by_id(pid_b, {"name": "Prompt A"}) + + def test_nested_transactions(self): + pid, _, _ = self.db.add_prompt("Transaction Test", "T", "D") + + try: + with self.db.transaction(): # Outer transaction + self.db.execute_query("UPDATE Prompts SET name = 'Outer Update', version=2, client_id='t' WHERE id = ?", + (pid,)) + + with self.db.transaction(): # Inner transaction + self.db.execute_query( + "UPDATE Prompts SET author = 'Inner Update', version=3, client_id='t' WHERE id = ?", (pid,)) + + prompt_inside = self.db.get_prompt_by_id(pid) + self.assertEqual(prompt_inside['author'], 'Inner Update') + + raise ValueError("Force rollback of outer transaction") + + except ValueError: + pass + + prompt_outside = self.db.get_prompt_by_id(pid) + self.assertEqual(prompt_outside['name'], "Transaction Test") + self.assertEqual(prompt_outside['author'], "T") + self.assertEqual(prompt_outside['version'], 1) + + def test_soft_delete_prompt_by_name_and_uuid(self): + """Test soft deletion using name and UUID identifiers.""" + _, p1_uuid, _ = self.db.add_prompt("Deletable By Name", "A", "D") + p2_id, p2_uuid, _ = self.db.add_prompt("Deletable By UUID", "B", "E") + + # Delete by name + self.assertTrue(self.db.soft_delete_prompt("Deletable By Name")) + self.assertIsNone(self.db.get_prompt_by_name("Deletable By Name")) + + # Delete by UUID + self.assertTrue(self.db.soft_delete_prompt(p2_uuid)) + self.assertIsNone(self.db.get_prompt_by_id(p2_id)) + + +class TestSearchFunctionality(BaseTestCase): + """More detailed tests for the search_prompts function.""" + + def setUp(self): + super().setUp() + self._add_sample_data(self.db) + self.db.add_prompt("Shared Term Prompt", "Author", "This prompt contains python.", keywords=["generic"]) + self.db.add_prompt("Another Code Prompt", 
"DevHelper", "More code things.", keywords=["python"]) + + def test_search_with_no_query_returns_all_active(self): + """Searching with no query should act like listing all active prompts.""" + results, total = self.db.search_prompts(None) + # 3 from _add_sample_data + 2 added in this setUp = 5 active prompts + self.assertEqual(total, 5) + self.assertEqual(len(results), 5) + + def test_search_with_no_results(self): + """Ensure a search with no matches returns an empty list and zero total.""" + results, total = self.db.search_prompts("nonexistentxyz") + self.assertEqual(total, 0) + self.assertEqual(len(results), 0) + + def test_search_pagination(self): + """Test if pagination works correctly on search results.""" + results, total = self.db.search_prompts("python", search_fields=['details', 'keywords'], page=1, + results_per_page=2) + self.assertEqual(total, 3) # Code Explainer, Shared Term, Another Code + self.assertEqual(len(results), 2) + + results_p2, total_p2 = self.db.search_prompts("python", search_fields=['details', 'keywords'], page=2, + results_per_page=2) + self.assertEqual(total_p2, 3) + self.assertEqual(len(results_p2), 1) + + def test_search_across_multiple_fields(self): + """Test searching in both details and keywords simultaneously.""" + results, total = self.db.search_prompts("python", search_fields=['details', 'keywords']) + self.assertEqual(total, 3) + names = {r['name'] for r in results} + self.assertIn("Code Explainer", names) + self.assertIn("Shared Term Prompt", names) + self.assertIn("Another Code Prompt", names) + + def test_search_with_invalid_fts_syntax_raises_error(self): + """Verify that malformed FTS queries raise a DatabaseError.""" + # An unclosed quote is invalid syntax + with self.assertRaises(DatabaseError): + self.db.search_prompts('invalid "syntax', search_fields=['name']) + + +class TestStandaloneFunctionExports(BaseTestCase): + """Tests for variations in the standalone export functions.""" + + def setUp(self): + super().setUp() + self.file_db = self._get_file_db() + self._add_sample_data(self.file_db) + + def test_export_with_no_matching_prompts(self): + """Test export when the filter criteria yields no results.""" + status, file_path = export_prompts_formatted( + self.file_db, + filter_keywords=["nonexistent_keyword"] + ) + self.assertIn("No prompts found", status) + self.assertEqual(file_path, "None") + + def test_export_csv_minimal_columns(self): + """Test CSV export with most boolean flags turned off.""" + status, file_path = export_prompts_formatted( + self.file_db, + export_format='csv', + include_details=False, + include_system=False, + include_user=False, + include_associated_keywords=False + ) + self.assertIn("Successfully exported", status) + self.assertTrue(os.path.exists(file_path)) + + with open(file_path, 'r', encoding='utf-8') as f: + header = f.readline().strip() + self.assertEqual(header, "Name,UUID,Author") + + os.remove(file_path) + + def test_export_markdown_with_different_template(self): + """Test Markdown export using a non-default template.""" + status, file_path = export_prompts_formatted( + self.file_db, + export_format='markdown', + markdown_template_name="Detailed Template" + ) + self.assertIn("Successfully exported", status) + self.assertTrue(os.path.exists(file_path)) + os.remove(file_path) + + +class TestConcurrencyAndDataIntegrity(BaseTestCase): + """ + Tests for race conditions, data encoding, and integrity under stress. 
+ """ + + def test_concurrent_updates_to_same_prompt(self): + """ + Simulate a race condition where two threads update the same prompt. + One should succeed, the other should fail with a ConflictError. + """ + file_db = self._get_file_db() + pid, _, _ = file_db.add_prompt("Race Condition", "Initial", "Data") + + results = {} + barrier = threading.Barrier(2, timeout=5) + + def worker(user_id): + try: + db_instance = PromptsDatabase(self.db_path, client_id=f"worker_{user_id}") + barrier.wait() + db_instance.update_prompt_by_id(pid, {'details': f'Updated by {user_id}'}) + results[user_id] = "success" + except ConflictError: + results[user_id] = "conflict" + except DatabaseError as e: + if "locked" in str(e).lower() or "conflict" in str(e).lower(): + results[user_id] = "conflict" + else: + results[user_id] = e + except Exception as e: + results[user_id] = e + finally: + if 'db_instance' in locals(): + db_instance.close_connection() + + thread1 = threading.Thread(target=worker, args=(1,)) + thread2 = threading.Thread(target=worker, args=(2,)) + + thread1.start() + thread2.start() + thread1.join() + thread2.join() + + self.assertIn("success", results.values()) + self.assertIn("conflict", results.values()) + + final_prompt = file_db.get_prompt_by_id(pid) + self.assertEqual(final_prompt['version'], 2) + self.assertTrue(final_prompt['details'].startswith("Updated by")) + + def test_unicode_character_support(self): + """Ensure all text fields correctly handle Unicode characters.""" + unicode_name = "こんにちは世界" + unicode_author = "Александр" + unicode_details = "Testing emoji support 👍 and special characters ç, é, à." + unicode_keywords = ["你好", "世界", "prüfung"] + + pid, _, _ = self.db.add_prompt( + name=unicode_name, + author=unicode_author, + details=unicode_details, + keywords=unicode_keywords + ) + + prompt = self.db.fetch_prompt_details(pid) + self.assertEqual(prompt['name'], unicode_name) + self.assertEqual(prompt['author'], unicode_author) + self.assertEqual(prompt['details'], unicode_details) + self.assertEqual(sorted(prompt['keywords']), sorted(unicode_keywords)) + + results, total = self.db.search_prompts("你好", search_fields=['keywords']) + self.assertEqual(total, 1) + self.assertEqual(results[0]['name'], unicode_name) + + status, file_path = export_prompts_formatted(self.db, 'csv') + self.assertTrue(os.path.exists(file_path)) + with open(file_path, 'r', encoding='utf-8') as f: + content = f.read() + self.assertIn(unicode_name, content) + self.assertIn(unicode_author, content) + self.assertIn("prüfung", content) + os.remove(file_path) + + def test_fts_desynchronization_by_direct_sql(self): + """ + Demonstrates that direct SQL updates will de-sync the FTS table. 
+ """ + unique_term = "zzyzx" + pid, _, _ = self.db.add_prompt("FTS Sync Test", "Author", f"Details with {unique_term}") + + results, total = self.db.search_prompts(unique_term) + self.assertEqual(total, 1) + + conn = self.db.get_connection() + conn.execute( + "UPDATE Prompts SET details='Details are now different', version=2, client_id='raw' WHERE id=?", + (pid,) + ) + conn.commit() + + prompt = self.db.get_prompt_by_id(pid) + self.assertEqual(prompt['details'], "Details are now different") + + # FTS search for the *new* term will FAIL because FTS was not updated + results_new, total_new = self.db.search_prompts("different") + self.assertEqual(total_new, 0) + + # FTS search for the *old* term will SUCCEED because FTS is stale + results_old, total_old = self.db.search_prompts(unique_term) + self.assertEqual(total_old, 1) + + +# Additional tests from the failing set, now expected to pass +class TestAdvancedStateTransitions(BaseTestCase): + def test_update_soft_deleted_prompt_restores_it(self): + pid, _, _ = self.db.add_prompt("To be deleted", "Author", "Old Details") + self.db.soft_delete_prompt(pid) + self.assertIsNone(self.db.get_prompt_by_id(pid)) + + update_data = {"details": "New, Restored Details"} + self.db.update_prompt_by_id(pid, update_data) + + restored_prompt = self.db.get_prompt_by_id(pid) + self.assertIsNotNone(restored_prompt) + self.assertEqual(restored_prompt['deleted'], 0) + self.assertEqual(restored_prompt['details'], "New, Restored Details") + self.assertEqual(restored_prompt['version'], 3) + + def test_handling_of_corrupted_sync_log_payload(self): + self.db.add_keyword("good_payload") + log_entry = self.db.get_sync_log_entries()[0] + change_id = log_entry['change_id'] + + conn = self.db.get_connection() + conn.execute( + "UPDATE sync_log SET payload = ? WHERE change_id = ?", + ("{'bad_json': this_is_not_valid}", change_id) + ) + conn.commit() + + all_logs = self.db.get_sync_log_entries() + corrupted_log = next(log for log in all_logs if log['change_id'] == change_id) + self.assertIsNone(corrupted_log['payload']) + + def test_add_keyword_with_only_whitespace_fails(self): + """Test that adding a keyword that is only whitespace raises an InputError.""" + with self.assertRaises(InputError): + self.db.add_keyword(" ") + + # Verify no sync log was created + self.assertEqual(len(self.db.get_sync_log_entries()), 0) + + def test_add_prompt_with_only_whitespace_name_fails(self): + """Test that adding a prompt with a whitespace-only name raises an InputError.""" + with self.assertRaises(InputError): + self.db.add_prompt(" ", "Author", "Details") + + # Verify no sync log was created + self.assertEqual(len(self.db.get_sync_log_entries()), 0) + + +class TestSyncLogManagement(BaseTestCase): + """ + In-depth tests for the sync_log table management and access methods. 
+ """ + + def test_get_sync_log_with_since_change_id_and_limit(self): + """Verify that fetching logs with a starting ID and a limit works correctly.""" + # Add 5 items, which should generate at least 5 logs + self.db.add_keyword("kw1") + self.db.add_keyword("kw2") + self.db.add_keyword("kw3") + self.db.add_keyword("kw4") + self.db.add_keyword("kw5") + + all_logs = self.db.get_sync_log_entries() + self.assertEqual(len(all_logs), 5) + + # Get logs since change_id 2 (should get 3, 4, 5) + logs_since_2 = self.db.get_sync_log_entries(since_change_id=2) + self.assertEqual(len(logs_since_2), 3) + self.assertEqual(logs_since_2[0]['change_id'], 3) + + # Get logs since change_id 2 with a limit of 1 (should only get 3) + logs_limited = self.db.get_sync_log_entries(since_change_id=2, limit=1) + self.assertEqual(len(logs_limited), 1) + self.assertEqual(logs_limited[0]['change_id'], 3) + + def test_delete_sync_log_with_nonexistent_ids(self): + """Ensure deleting a mix of existing and non-existent log IDs works as expected.""" + self.db.add_keyword("kw1") # change_id 1 + self.db.add_keyword("kw2") # change_id 2 + + # Attempt to delete IDs 1 and 9999 (which doesn't exist) + deleted_count = self.db.delete_sync_log_entries([1, 9999]) + + self.assertEqual(deleted_count, 1) + + remaining_logs = self.db.get_sync_log_entries() + self.assertEqual(len(remaining_logs), 1) + self.assertEqual(remaining_logs[0]['change_id'], 2) + + +class TestComplexStateAndInputInteractions(BaseTestCase): + """ + Tests for nuanced interactions between methods and states. + """ + + def test_add_prompt_with_overwrite_false_on_deleted_prompt(self): + """ + Verify the specific behavior of add_prompt(overwrite=False) when a prompt + is soft-deleted. It should not restore it and should return a specific message. 
+ """ + self.db.add_prompt("Deleted but Exists", "Author", "Details") + self.db.soft_delete_prompt("Deleted but Exists") + + # Attempt to add it again without overwrite flag + pid, puuid, msg = self.db.add_prompt("Deleted but Exists", "New Author", "New Details", overwrite=False) + + self.assertIn("exists but is soft-deleted", msg) + + # Verify it was not restored or updated + prompt = self.db.get_prompt_by_name("Deleted but Exists", include_deleted=True) + self.assertEqual(prompt['author'], "Author") # Should be the original author + self.assertEqual(prompt['deleted'], 1) # Should remain deleted + + def test_soft_delete_nonexistent_item_returns_false(self): + """Ensure attempting to delete non-existent items returns False and doesn't error.""" + result_prompt = self.db.soft_delete_prompt("non-existent prompt") + self.assertFalse(result_prompt) + + result_keyword = self.db.soft_delete_keyword("non-existent keyword") + self.assertFalse(result_keyword) + + def test_update_keywords_for_prompt_with_empty_list_removes_all(self): + """Updating keywords with an empty list should remove all existing keywords.""" + pid, _, _ = self.db.add_prompt("Keyword Test", "A", "D", keywords=["kw1", "kw2"]) + self.assertEqual(len(self.db.fetch_keywords_for_prompt(pid)), 2) + + # Update with an empty list + self.db.update_keywords_for_prompt(pid, []) + + self.assertEqual(len(self.db.fetch_keywords_for_prompt(pid)), 0) + + # Verify unlink events were logged + unlink_logs = [l for l in self.db.get_sync_log_entries() if l['operation'] == 'unlink'] + self.assertEqual(len(unlink_logs), 2) + + def test_update_keywords_for_prompt_is_idempotent(self): + """Running update_keywords with the same list should result in no changes or new logs.""" + pid, _, _ = self.db.add_prompt("Idempotent Test", "A", "D", keywords=["kw1", "kw2"]) + + initial_log_count = len(self.db.get_sync_log_entries()) + + # Rerun with the same keywords + self.db.update_keywords_for_prompt(pid, ["kw1", "kw2"]) + + final_log_count = len(self.db.get_sync_log_entries()) + self.assertEqual(initial_log_count, final_log_count, + "No new sync logs should be created for an idempotent update") + + def test_update_keywords_for_prompt_handles_duplicates_and_whitespace(self): + """Ensure keyword lists are properly normalized before processing.""" + pid, _, _ = self.db.add_prompt("Normalization Test", "A", "D") + + messy_keywords = [" Tag A ", "tag b", "Tag A", " ", "tag c "] + self.db.update_keywords_for_prompt(pid, messy_keywords) + + final_keywords = self.db.fetch_keywords_for_prompt(pid) + self.assertEqual(sorted(final_keywords), ["tag a", "tag b", "tag c"]) + + +class TestBulkOperationsAndScale(BaseTestCase): + """ + Tests focusing on bulk methods and behavior with a larger number of records. 
+ """ + + def test_execute_many_success(self): + """Test successful bulk insertion with execute_many.""" + keywords_to_add = [ + (f"bulk_keyword_{i}", str(uuid.uuid4()), self.db._get_current_utc_timestamp_str(), 1, self.client_id, 0) + for i in range(50) + ] + sql = "INSERT INTO PromptKeywordsTable (keyword, uuid, last_modified, version, client_id, deleted) VALUES (?, ?, ?, ?, ?, ?)" + + with self.db.transaction(): + self.db.execute_many(sql, keywords_to_add) + + count = self.db.execute_query("SELECT COUNT(*) FROM PromptKeywordsTable").fetchone()[0] + self.assertEqual(count, 50) + + def test_execute_many_failure_with_integrity_error_rolls_back(self): + """Ensure a failing execute_many call within a transaction rolls back the entire batch.""" + keywords_to_add = [ + ("unique_1", str(uuid.uuid4()), self.db._get_current_utc_timestamp_str(), 1, self.client_id, 0), + ("not_unique", str(uuid.uuid4()), self.db._get_current_utc_timestamp_str(), 1, self.client_id, 0), + ("not_unique", str(uuid.uuid4()), self.db._get_current_utc_timestamp_str(), 1, self.client_id, 0), + # Fails here + ("unique_2", str(uuid.uuid4()), self.db._get_current_utc_timestamp_str(), 1, self.client_id, 0), + ] + sql = "INSERT INTO PromptKeywordsTable (keyword, uuid, last_modified, version, client_id, deleted) VALUES (?, ?, ?, ?, ?, ?)" + + with self.assertRaises(DatabaseError) as cm: + with self.db.transaction(): + self.db.execute_many(sql, keywords_to_add) + + self.assertIn("UNIQUE constraint failed", str(cm.exception)) + + # Verify rollback + count = self.db.execute_query("SELECT COUNT(*) FROM PromptKeywordsTable").fetchone()[0] + self.assertEqual(count, 0) + + def test_dependency_integrity_on_delete(self): + """ + Verify that deleting a prompt doesn't delete a keyword used by other prompts. + """ + # "common_kw" is used by both prompts + p1_id, _, _ = self.db.add_prompt("Prompt 1", "A", "D", keywords=["p1_kw", "common_kw"]) + self.db.add_prompt("Prompt 2", "B", "E", keywords=["p2_kw", "common_kw"]) + + # Soft delete Prompt 1 + self.db.soft_delete_prompt(p1_id) + + # Check that "common_kw" still exists and is active + kw_data = self.db.get_active_keyword_by_text("common_kw") + self.assertIsNotNone(kw_data) + + # Check that Prompt 2 still has its link to "common_kw" + p2_details = self.db.fetch_prompt_details("Prompt 2") + self.assertIn("common_kw", p2_details['keywords']) + + +if __name__ == '__main__': + unittest.main(argv=['first-arg-is-ignored'], exit=False) + + +# +# End of tests_prompts_db.py +####################################################################################################################### diff --git a/Tests/Prompts_DB/tests_prompts_db_properties.py b/Tests/Prompts_DB/tests_prompts_db_properties.py new file mode 100644 index 00000000..ae6ff3b7 --- /dev/null +++ b/Tests/Prompts_DB/tests_prompts_db_properties.py @@ -0,0 +1,574 @@ +# test_prompts_db_properties.py +# +# Property-based tests for the Prompts_DB_v2 library using Hypothesis. + +# Imports +import uuid +import pytest +import json +from pathlib import Path +import sqlite3 +import threading +import time + +# Third-Party Imports +from hypothesis import given, strategies as st, settings, HealthCheck +from hypothesis.stateful import RuleBasedStateMachine, rule, precondition, Bundle + +# Local Imports +# Assuming Prompts_DB_v2.py is in a location Python can find. +# For example, in the same directory or in a package. 
+from tldw_chatbook.DB.Prompts_DB import ( + PromptsDatabase, + InputError, + DatabaseError, + ConflictError +) + +######################################################################################################################## +# +# Hypothesis Setup: + +# A custom profile for DB tests to avoid timeouts on complex operations. +settings.register_profile( + "db_friendly", + deadline=1500, # Increased deadline for potentially slow DB I/O + suppress_health_check=[ + HealthCheck.too_slow, + HealthCheck.function_scoped_fixture + ] +) +settings.load_profile("db_friendly") + + +# --- Fixtures --- + +@pytest.fixture +def client_id(): + """Provides a consistent client ID for tests.""" + return "hypothesis_client" + + +@pytest.fixture +def db_path(tmp_path): + """Provides a temporary path for the database file for each test.""" + return tmp_path / "prop_test_prompts_db.sqlite" + + +@pytest.fixture(scope="function") +def db_instance(db_path, client_id): + """Creates a fresh PromptsDatabase instance for each test function.""" + current_db_path = Path(db_path) + # Ensure no leftover files from a failed previous run (important for WAL mode) + for suffix in ["", "-wal", "-shm"]: + p = Path(str(current_db_path) + suffix) + if p.exists(): + p.unlink(missing_ok=True) + + db = PromptsDatabase(current_db_path, client_id) + yield db + db.close_connection() + + +# --- Hypothesis Strategies --- + +# Strategy for text fields that cannot be empty or just whitespace. +st_required_text = st.text(min_size=1, max_size=100).filter(lambda s: s.strip()) + +# Strategy for optional text fields. +st_optional_text = st.one_of(st.none(), st.text(max_size=500)) + + +@st.composite +def st_prompt_data(draw): + """A composite strategy to generate a dictionary of prompt data.""" + # Generate keywords that are unique after normalization + keywords = draw(st.lists(st_required_text, max_size=5, unique_by=lambda s: s.strip().lower())) + + return { + "name": draw(st_required_text), + "author": draw(st_optional_text), + "details": draw(st_optional_text), + "system_prompt": draw(st_optional_text), + "user_prompt": draw(st_optional_text), + "keywords": keywords, + } + + +# A strategy for a non-one integer to test version validation triggers. +st_bad_version_offset = st.integers().filter(lambda x: x != 1) + + +# --- Test Classes --- + +class TestPromptProperties: + """Property-based tests for core Prompt operations.""" + + @given(prompt_data=st_prompt_data()) + def test_prompt_roundtrip(self, db_instance: PromptsDatabase, prompt_data: dict): + """ + Property: If we add a prompt, retrieving it should return the same data, + accounting for any normalization (e.g., stripping name, normalizing keywords). + """ + try: + # Use overwrite=False to test the create path; ConflictError is a valid outcome. + prompt_id, prompt_uuid, _ = db_instance.add_prompt(**prompt_data, overwrite=False) + except ConflictError: + return # Hypothesis generated a name collision, which is not a failure. 
+ + assert prompt_id is not None + assert prompt_uuid is not None + + retrieved_prompt = db_instance.fetch_prompt_details(prompt_id) + assert retrieved_prompt is not None + + # Compare basic fields + assert retrieved_prompt["name"] == prompt_data["name"].strip() + assert retrieved_prompt["author"] == prompt_data["author"] + assert retrieved_prompt["details"] == prompt_data["details"] + assert retrieved_prompt["system_prompt"] == prompt_data["system_prompt"] + assert retrieved_prompt["user_prompt"] == prompt_data["user_prompt"] + assert retrieved_prompt["version"] == 1 + assert not retrieved_prompt["deleted"] + + # Compare keywords, which are normalized by the database + expected_keywords = sorted([db_instance._normalize_keyword(k) for k in prompt_data["keywords"]]) + retrieved_keywords = sorted(retrieved_prompt["keywords"]) + assert retrieved_keywords == expected_keywords + + @given(initial_prompt=st_prompt_data(), update_payload=st_prompt_data()) + def test_update_increments_version_and_changes_data(self, db_instance: PromptsDatabase, initial_prompt: dict, + update_payload: dict): + """ + Property: A successful update must increment the version number by exactly 1 + and correctly apply the new data, including keywords. + """ + try: + prompt_id, _, _ = db_instance.add_prompt(**initial_prompt) + except ConflictError: + return # Skip if initial name conflicts + + original_prompt = db_instance.get_prompt_by_id(prompt_id) + + try: + # update_prompt_by_id handles fetching current version and incrementing it. + uuid, msg = db_instance.update_prompt_by_id(prompt_id, update_payload) + assert uuid is not None + except ConflictError as e: + # A legitimate failure if the new name is already taken by another prompt. + assert "already exists" in str(e) + return + + updated_prompt = db_instance.fetch_prompt_details(prompt_id) + assert updated_prompt is not None + assert updated_prompt['version'] == original_prompt['version'] + 1 + + # Verify the payload was applied + assert updated_prompt['name'] == update_payload['name'].strip() + assert updated_prompt['author'] == update_payload['author'] + expected_keywords = sorted([db_instance._normalize_keyword(k) for k in update_payload["keywords"]]) + assert sorted(updated_prompt['keywords']) == expected_keywords + + @given(prompt_data=st_prompt_data()) + def test_soft_delete_makes_item_unfindable(self, db_instance: PromptsDatabase, prompt_data: dict): + """ + Property: After soft-deleting a prompt, it should not be retrievable by + default methods, but should exist in the DB with deleted=1. 
+ """ + try: + prompt_id, _, _ = db_instance.add_prompt(**prompt_data) + except ConflictError: + return + + # Perform the soft delete + success = db_instance.soft_delete_prompt(prompt_id) + assert success is True + + # Assert it's no longer findable via public methods by default + assert db_instance.get_prompt_by_id(prompt_id) is None + assert db_instance.fetch_prompt_details(prompt_id) is None + + all_prompts, _, _, _ = db_instance.list_prompts() + assert prompt_id not in [p['id'] for p in all_prompts] + + # Assert it CAN be found when explicitly requested + deleted_record = db_instance.get_prompt_by_id(prompt_id, include_deleted=True) + assert deleted_record is not None + assert deleted_record['deleted'] == 1 + assert deleted_record['version'] == 2 # 1=create, 2=delete + + @given(initial_prompt=st_prompt_data(), update_name=st_required_text, version_offset=st_bad_version_offset) + def test_update_with_stale_version_fails_via_trigger(self, db_instance: PromptsDatabase, initial_prompt: dict, + update_name: str, version_offset: int): + """ + Property: Attempting a direct DB update with a version that does not increment + by exactly 1 must be rejected by the database trigger. + """ + try: + prompt_id, _, _ = db_instance.add_prompt(**initial_prompt) + except ConflictError: + return + + original_prompt = db_instance.get_prompt_by_id(prompt_id) + + # Attempt a direct DB update with a bad version number. + # This tests the 'prompts_validate_sync_update' trigger. + with pytest.raises(DatabaseError) as excinfo: + db_instance.execute_query( + "UPDATE Prompts SET name = ?, version = ? WHERE id = ?", + (update_name, original_prompt['version'] + version_offset, prompt_id), + commit=True + ) + assert "version must increment by exactly 1" in str(excinfo.value).lower() + + +class TestKeywordAndLinkingProperties: + """Property-based tests for Keywords and their linking to Prompts.""" + + @given(keyword_text=st_required_text) + def test_keyword_normalization_and_roundtrip(self, db_instance: PromptsDatabase, keyword_text: str): + """ + Property: Adding a keyword normalizes it (lowercase, stripped). + Retrieving it returns the normalized version. + """ + kw_id, kw_uuid = db_instance.add_keyword(keyword_text) + assert kw_id is not None + assert kw_uuid is not None + + retrieved_kw = db_instance.get_active_keyword_by_text(keyword_text) + assert retrieved_kw is not None + assert retrieved_kw['keyword'] == db_instance._normalize_keyword(keyword_text) + + @given(keyword=st_required_text) + def test_add_keyword_is_idempotent_on_undelete(self, db_instance: PromptsDatabase, keyword: str): + """ + Property: Adding a keyword that was previously soft-deleted should reactivate + it (not create a new one), and its version should be correctly incremented. + """ + # 1. Add for the first time + kw_id_v1, _ = db_instance.add_keyword(keyword) + assert db_instance.get_prompt_by_id(kw_id_v1) is not None # Using wrong get method in original code + kw_v1 = db_instance.get_active_keyword_by_text(keyword) + assert kw_v1['version'] == 1 + + # 2. Soft delete it + success = db_instance.soft_delete_keyword(keyword) + assert success is True + + # Check raw state + raw_kw = db_instance.execute_query("SELECT * FROM PromptKeywordsTable WHERE id=?", (kw_id_v1,)).fetchone() + assert raw_kw['deleted'] == 1 + assert raw_kw['version'] == 2 + + # 3. 
Add it again (should trigger undelete) + kw_id_v3, _ = db_instance.add_keyword(keyword) + + # Assert it's the same record + assert kw_id_v3 == kw_id_v1 + + kw_v3 = db_instance.get_active_keyword_by_text(keyword) + assert kw_v3 is not None + assert not db_instance.get_prompt_by_id(kw_v3['id'], include_deleted=True)['deleted'] + # The version should be 3 (1=create, 2=delete, 3=undelete/update) + assert kw_v3['version'] == 3 + + @given( + prompt_data=st_prompt_data(), + new_keywords=st.lists(st_required_text, max_size=5, unique_by=lambda s: s.strip().lower()) + ) + def test_update_keywords_for_prompt_links_and_unlinks(self, db_instance: PromptsDatabase, prompt_data: dict, + new_keywords: list): + """ + Property: Updating keywords for a prompt correctly adds new links, + removes old ones, and leaves unchanged ones alone. + """ + try: + prompt_id, _, _ = db_instance.add_prompt(**prompt_data) + except ConflictError: + return + + # Initial state check + initial_expected_kws = sorted([db_instance._normalize_keyword(k) for k in prompt_data['keywords']]) + assert sorted(db_instance.fetch_keywords_for_prompt(prompt_id)) == initial_expected_kws + + # Update the keywords + db_instance.update_keywords_for_prompt(prompt_id, new_keywords) + + # Final state check + final_expected_kws = sorted([db_instance._normalize_keyword(k) for k in new_keywords]) + assert sorted(db_instance.fetch_keywords_for_prompt(prompt_id)) == final_expected_kws + + +class TestAdvancedProperties: + """Tests for FTS, Sync Log, and other complex interactions.""" + + @given(prompt_data=st_prompt_data()) + def test_soft_deleted_item_is_not_in_fts(self, db_instance: PromptsDatabase, prompt_data: dict): + """ + Property: Once a prompt is soft-deleted, it must not appear in FTS search results. + """ + # Ensure the name has a unique, searchable term. + unique_term = str(uuid.uuid4()) + prompt_data['name'] = f"{prompt_data['name']} {unique_term}" + + try: + prompt_id, _, _ = db_instance.add_prompt(**prompt_data) + except ConflictError: + return + + # 1. Verify it IS searchable before deletion + results_before, total_before = db_instance.search_prompts(unique_term) + assert total_before == 1 + assert results_before[0]['id'] == prompt_id + + # 2. Soft-delete the prompt + db_instance.soft_delete_prompt(prompt_id) + + # 3. Verify it is NOT searchable after deletion + results_after, total_after = db_instance.search_prompts(unique_term) + assert total_after == 0 + + @given(prompt_data=st_prompt_data()) + def test_add_creates_correct_sync_log_entries(self, db_instance: PromptsDatabase, prompt_data: dict): + """ + Property: Adding a new prompt must create the correct 'create' and 'link' + operations in the sync_log. 
+ """ + latest_change_id_before = db_instance.get_sync_log_entries(limit=1)[-1][ + 'change_id'] if db_instance.get_sync_log_entries(limit=1) else 0 + + try: + prompt_id, prompt_uuid, _ = db_instance.add_prompt(**prompt_data) + except ConflictError: + return + + new_logs = db_instance.get_sync_log_entries(since_change_id=latest_change_id_before) + + # Verify 'create' log for the prompt itself + prompt_create_logs = [log for log in new_logs if log['entity'] == 'Prompts' and log['operation'] == 'create'] + assert len(prompt_create_logs) == 1 + assert prompt_create_logs[0]['entity_uuid'] == prompt_uuid + assert prompt_create_logs[0]['version'] == 1 + + # Verify 'link' logs for the keywords + normalized_keywords = {db_instance._normalize_keyword(k) for k in prompt_data['keywords']} + link_logs = [log for log in new_logs if log['entity'] == 'PromptKeywordLinks' and log['operation'] == 'link'] + assert len(link_logs) == len(normalized_keywords) + + # The payload should contain the composite UUID of prompt_uuid + keyword_uuid + for log in link_logs: + assert log['payload']['prompt_uuid'] == prompt_uuid + assert log['entity_uuid'].startswith(prompt_uuid) + + @given(prompt_data=st_prompt_data()) + def test_delete_creates_correct_sync_log_entries(self, db_instance: PromptsDatabase, prompt_data: dict): + """ + Property: Soft-deleting a prompt must create a 'delete' log for the prompt + and 'unlink' logs for all its keyword connections. + """ + try: + prompt_id, prompt_uuid, _ = db_instance.add_prompt(**prompt_data) + except ConflictError: + return + + num_keywords = len(prompt_data['keywords']) + latest_change_id_before = db_instance.get_sync_log_entries(limit=1)[-1]['change_id'] + + # Action: Soft delete + db_instance.soft_delete_prompt(prompt_id) + + new_logs = db_instance.get_sync_log_entries(since_change_id=latest_change_id_before) + + # Verify 'delete' log for the prompt + prompt_delete_logs = [log for log in new_logs if log['entity'] == 'Prompts' and log['operation'] == 'delete'] + assert len(prompt_delete_logs) == 1 + assert prompt_delete_logs[0]['entity_uuid'] == prompt_uuid + assert prompt_delete_logs[0]['version'] == 2 + + # Verify 'unlink' logs for the keywords + unlink_logs = [log for log in new_logs if + log['entity'] == 'PromptKeywordLinks' and log['operation'] == 'unlink'] + assert len(unlink_logs) == num_keywords + for log in unlink_logs: + assert log['payload']['prompt_uuid'] == prompt_uuid + + +class TestDataIntegrityAndConcurrency: + """Tests for database constraints and thread safety.""" + + def test_add_prompt_with_conflicting_name_fails(self, db_instance: PromptsDatabase): + """ + Property: Adding a prompt with a name that already exists (and overwrite=False) + must raise a ConflictError. + """ + prompt_data = {"name": "Unique Prompt Name", "author": "Tester"} + db_instance.add_prompt(**prompt_data) + + # Attempt to add again with the same name + with pytest.raises(ConflictError): + db_instance.add_prompt(**prompt_data, overwrite=False) + + def test_update_prompt_to_conflicting_name_fails(self, db_instance: PromptsDatabase): + """ + Property: Updating a prompt's name to a name that is already used by + another active prompt must raise a ConflictError. 
+ """ + p1_id, _, _ = db_instance.add_prompt(name="Prompt One", author="A") + db_instance.add_prompt(name="Prompt Two", author="B") # The conflicting name + + update_payload = {"name": "Prompt Two"} + with pytest.raises(ConflictError): + db_instance.update_prompt_by_id(p1_id, update_payload) + + def test_each_thread_gets_a_separate_connection(self, db_instance: PromptsDatabase): + """ + Property: The `get_connection` method must provide a unique + connection object for each thread, via threading.local. + """ + connection_ids = set() + lock = threading.Lock() + + def get_and_store_conn_id(): + conn = db_instance.get_connection() + with lock: + connection_ids.add(id(conn)) + + threads = [threading.Thread(target=get_and_store_conn_id) for _ in range(5)] + for t in threads: + t.start() + for t in threads: + t.join() + + # If threading.local is working, there should be 5 unique connection IDs. + assert len(connection_ids) == 5 + + def test_wal_mode_allows_concurrent_reads_during_write_transaction(self, db_instance: PromptsDatabase): + """ + Property: In WAL mode, one thread can read from the DB while another + thread has an open write transaction. + """ + prompt_id, _, _ = db_instance.add_prompt(name="Concurrent Read Test", details="Original") + + write_transaction_started = threading.Event() + read_result = [] + + def writer_thread(): + # The update method opens its own transaction + with db_instance.transaction(): + db_instance.execute_query("UPDATE Prompts SET details = 'Updated' WHERE id = ?", (prompt_id,)) + write_transaction_started.set() # Signal that the transaction is open + time.sleep(0.2) # Hold the transaction open + # Transaction commits here + + def reader_thread(): + write_transaction_started.wait() # Wait until the writer is in its transaction + # This read should succeed immediately and read the state BEFORE the commit. + prompt = db_instance.get_prompt_by_id(prompt_id) + read_result.append(prompt) + + w = threading.Thread(target=writer_thread) + r = threading.Thread(target=reader_thread) + + w.start() + r.start() + w.join() + r.join() + + # The reader thread should have completed successfully and read the *original* state. + assert len(read_result) == 1 + assert read_result[0] is not None + assert read_result[0]['details'] == "Original" # It read the state before the writer committed. + + +# --- State Machine Tests --- + +class PromptLifecycleMachine(RuleBasedStateMachine): + """ + Models the lifecycle of a single Prompt: create, update, delete. + Hypothesis will try to find sequences of these actions that break invariants. + """ + + def __init__(self): + super().__init__() + self.db = None # Injected by the test class fixture + # State for a single prompt's lifecycle + self.prompt_id = None + self.prompt_name = None + self.is_deleted = True + + prompts = Bundle('prompts') + + @rule(target=prompts, data=st_prompt_data()) + def create_prompt(self, data): + # We only want to test the lifecycle of one prompt per machine run for simplicity. + if self.prompt_id is not None: + return + + try: + new_id, _, _ = self.db.add_prompt(**data, overwrite=False) + except ConflictError: + # Hypothesis might generate a duplicate name. We treat this as "no action taken". 
+ return + + self.prompt_id = new_id + self.prompt_name = data['name'].strip() + self.is_deleted = False + + retrieved = self.db.get_prompt_by_id(self.prompt_id) + assert retrieved is not None + assert retrieved['name'] == self.prompt_name + assert retrieved['version'] == 1 + return self.prompt_id + + @rule(prompt_id=prompts, update_data=st_prompt_data()) + def update_prompt(self, prompt_id, update_data): + if self.prompt_id is None or self.is_deleted: + return + + original_version = self.db.get_prompt_by_id(prompt_id, include_deleted=True)['version'] + + try: + self.db.update_prompt_by_id(prompt_id, update_data) + # If successful, update our internal state + self.prompt_name = update_data['name'].strip() + except ConflictError as e: + # This is a valid outcome if the new name is taken. + assert "already exists" in str(e) + # The state of our prompt hasn't changed. + return + + retrieved = self.db.get_prompt_by_id(self.prompt_id) + assert retrieved is not None + assert retrieved['version'] == original_version + 1 + assert retrieved['name'] == self.prompt_name + + @rule(prompt_id=prompts) + def soft_delete_prompt(self, prompt_id): + if self.prompt_id is None or self.is_deleted: + return + + original_version = self.db.get_prompt_by_id(prompt_id, include_deleted=True)['version'] + + success = self.db.soft_delete_prompt(prompt_id) + assert success + self.is_deleted = True + + # Verify it's gone from standard lookups + assert self.db.get_prompt_by_id(self.prompt_id) is None + assert self.db.get_prompt_by_name(self.prompt_name) is None + + # Verify its deleted state + raw_record = self.db.get_prompt_by_id(self.prompt_id, include_deleted=True) + assert raw_record['deleted'] == 1 + assert raw_record['version'] == original_version + 1 + + +# This is the actual test class that pytest discovers and runs. +# It inherits the rules and provides the `db_instance` fixture. +@settings(max_examples=50, stateful_step_count=20) +class TestPromptLifecycleAsTest(PromptLifecycleMachine): + + @pytest.fixture(autouse=True) + def inject_db(self, db_instance): + """Injects the clean db_instance fixture into the state machine.""" + self.db = db_instance \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 6aa9b042..6f6152aa 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -54,6 +54,8 @@ transformers = ["transformers"] dev = [ # Example for development dependencies "pytest", "textual-dev", # For Textual development tools + "hypothesis", + "pytest_asyncio", ] diff --git a/tldw_chatbook/DB/ChaChaNotes_DB.py b/tldw_chatbook/DB/ChaChaNotes_DB.py index 6ba48738..f691864e 100644 --- a/tldw_chatbook/DB/ChaChaNotes_DB.py +++ b/tldw_chatbook/DB/ChaChaNotes_DB.py @@ -1901,6 +1901,7 @@ def search_character_cards(self, search_term: str, limit: int = 10) -> List[Dict Raises: CharactersRAGDBError: For database errors during the search. 
""" + safe_search_term = f'"{search_term}"' query = """ SELECT cc.* FROM character_cards_fts fts @@ -1914,7 +1915,7 @@ def search_character_cards(self, search_term: str, limit: int = 10) -> List[Dict rows = cursor.fetchall() return [self._deserialize_row_fields(row, self._CHARACTER_CARD_JSON_FIELDS) for row in rows if row] except CharactersRAGDBError as e: - logger.error(f"Error searching character cards for '{search_term}': {e}") + logger.error(f"Error searching character cards for '{safe_search_term}': {e}") raise # --- Conversation Methods --- @@ -2296,6 +2297,7 @@ def search_conversations_by_title(self, title_query: str, character_id: Optional if not title_query.strip(): logger.warning("Empty title_query provided for conversation search. Returning empty list.") return [] + safe_search_term = f'"{title_query}"' base_query = """ SELECT c.* FROM conversations_fts fts @@ -2315,7 +2317,7 @@ def search_conversations_by_title(self, title_query: str, character_id: Optional cursor = self.execute_query(base_query, tuple(params_list)) return [dict(row) for row in cursor.fetchall()] except CharactersRAGDBError as e: - logger.error(f"Error searching conversations for title '{title_query}': {e}") + logger.error(f"Error searching conversations for title '{safe_search_term}': {e}") raise # --- Message Methods --- @@ -2429,27 +2431,25 @@ def get_messages_for_conversation(self, conversation_id: str, limit: int = 100, order_by_timestamp: str = "ASC") -> List[Dict[str, Any]]: """ Lists messages for a specific conversation. - Returns non-deleted messages, ordered by `timestamp` according to `order_by_timestamp`. - Includes all fields, including `image_data` and `image_mime_type`. - - Args: - conversation_id: The UUID of the conversation. - limit: Maximum number of messages to return. Defaults to 100. - offset: Number of messages to skip. Defaults to 0. - order_by_timestamp: Sort order for 'timestamp' field ("ASC" or "DESC"). - Defaults to "ASC". - - Returns: - A list of message dictionaries. Can be empty. - - Raises: - InputError: If `order_by_timestamp` has an invalid value. - CharactersRAGDBError: For database errors. + Crucially, it also ensures the parent conversation is not soft-deleted. """ if order_by_timestamp.upper() not in ["ASC", "DESC"]: raise InputError("order_by_timestamp must be 'ASC' or 'DESC'.") - query = f"SELECT id, conversation_id, parent_message_id, sender, content, image_data, image_mime_type, timestamp, ranking, last_modified, version, client_id, deleted FROM messages WHERE conversation_id = ? AND deleted = 0 ORDER BY timestamp {order_by_timestamp} LIMIT ? OFFSET ?" # Explicitly list columns + + # The new query joins with conversations to check its 'deleted' status. + query = f""" + SELECT m.id, m.conversation_id, m.parent_message_id, m.sender, m.content, + m.image_data, m.image_mime_type, m.timestamp, m.ranking, + m.last_modified, m.version, m.client_id, m.deleted + FROM messages m + JOIN conversations c ON m.conversation_id = c.id + WHERE m.conversation_id = ? + AND m.deleted = 0 + AND c.deleted = 0 + ORDER BY m.timestamp {order_by_timestamp} + LIMIT ? OFFSET ? + """ try: cursor = self.execute_query(query, (conversation_id, limit, offset)) return [dict(row) for row in cursor.fetchall()] @@ -2667,6 +2667,7 @@ def search_messages_by_content(self, content_query: str, conversation_id: Option Raises: CharactersRAGDBError: For database search errors. 
""" + safe_search_term = f'"{content_query}"' base_query = """ SELECT m.* FROM messages_fts fts @@ -2686,7 +2687,7 @@ def search_messages_by_content(self, content_query: str, conversation_id: Option cursor = self.execute_query(base_query, tuple(params_list)) return [dict(row) for row in cursor.fetchall()] except CharactersRAGDBError as e: - logger.error(f"Error searching messages for content '{content_query}': {e}") + logger.error(f"Error searching messages for content '{safe_search_term}': {e}") raise # --- Keyword, KeywordCollection, Note Methods (CRUD + Search) --- @@ -3252,7 +3253,8 @@ def search_keywords(self, search_term: str, limit: int = 10) -> List[Dict[str, A Returns: A list of matching keyword dictionaries. """ - return self._search_generic_items_fts("keywords_fts", "keywords", "keyword", search_term, limit) + safe_search_term = f'"{search_term}"' + return self._search_generic_items_fts("keywords_fts", "keywords", "keyword", safe_search_term, limit) # Keyword Collections def add_keyword_collection(self, name: str, parent_id: Optional[int] = None) -> Optional[int]: @@ -3372,7 +3374,8 @@ def soft_delete_keyword_collection(self, collection_id: int, expected_version: i ) def search_keyword_collections(self, search_term: str, limit: int = 10) -> List[Dict[str, Any]]: - return self._search_generic_items_fts("keyword_collections_fts", "keyword_collections", "name", search_term, + safe_search_term = f'"{search_term}"' + return self._search_generic_items_fts("keyword_collections_fts", "keyword_collections", "name", safe_search_term, limit) # Notes (Now with UUID and specific methods) @@ -3536,19 +3539,21 @@ def soft_delete_note(self, note_id: str, expected_version: int) -> bool | None: def search_notes(self, search_term: str, limit: int = 10) -> List[Dict[str, Any]]: """Searches notes_fts (title and content). Corrected JOIN condition.""" - # notes_fts matches against title and content - # FTS table column group: notes_fts - # Content table: notes, content_rowid: rowid (maps to notes.rowid) + # FTS5 requires wrapping terms with special characters in double quotes + # to be treated as a literal phrase. + safe_search_term = f'"{search_term}"' + query = """ SELECT main.* FROM notes_fts fts - JOIN notes main ON fts.rowid = main.rowid -- Corrected Join condition - WHERE fts.notes_fts MATCH ? \ + JOIN notes main ON fts.rowid = main.rowid + WHERE fts.notes_fts MATCH ? AND main.deleted = 0 - ORDER BY rank LIMIT ? \ + ORDER BY rank LIMIT ? """ try: - cursor = self.execute_query(query, (search_term, limit)) + # Pass the quoted string as the parameter + cursor = self.execute_query(query, (safe_search_term, limit)) return [dict(row) for row in cursor.fetchall()] except CharactersRAGDBError as e: logger.error(f"Error searching notes for '{search_term}': {e}") @@ -3824,6 +3829,7 @@ def __exit__(self, exc_type, exc_val, exc_tb): logger.debug( f"Exception in nested transaction block on thread {threading.get_ident()}: {exc_type.__name__}. Outermost transaction will handle rollback if this exception propagates.") + # Return False to re-raise any exceptions that occurred within the `with` block, # allowing them to be handled by the caller or to propagate further up. # This is standard behavior for context managers. 
diff --git a/tldw_chatbook/DB/Client_Media_DB_v2.py b/tldw_chatbook/DB/Client_Media_DB_v2.py index e6fccf54..a93785ba 100644 --- a/tldw_chatbook/DB/Client_Media_DB_v2.py +++ b/tldw_chatbook/DB/Client_Media_DB_v2.py @@ -1703,320 +1703,257 @@ def soft_delete_media(self, media_id: int, cascade: bool = True) -> bool: logger.error(f"Unexpected error soft deleting media ID {media_id}: {e}", exc_info=True) raise DatabaseError(f"Unexpected error during soft delete: {e}") from e - def add_media_with_keywords(self, - *, - url: Optional[str] = None, - title: Optional[str], - media_type: Optional[str], - content: Optional[str], - keywords: Optional[List[str]] = None, - prompt: Optional[str] = None, - analysis_content: Optional[str] = None, - transcription_model: Optional[str] = None, - author: Optional[str] = None, - ingestion_date: Optional[str] = None, - overwrite: bool = False, - chunk_options: Optional[Dict] = None, - chunks: Optional[List[Dict[str, Any]]] = None - ) -> Tuple[Optional[int], Optional[str], str]: - """ - Adds a new media item or updates an existing one based on URL or content hash. - - Handles creation or update of the Media record, generates a content hash, - associates keywords (adding them if necessary), creates an initial - DocumentVersion, logs appropriate sync events ('create' or 'update' for - Media, plus events from keyword and document version handling), and - updates the `media_fts` table. - - If `chunks` are provided, they are saved as UnvectorizedMediaChunks. - If updating an existing media item and new chunks are provided, old chunks - for that media item are hard-deleted before inserting the new ones. - - If an existing item is found (by URL or content hash) and `overwrite` is False, - the operation is skipped. If `overwrite` is True, the existing item is updated. - - Args: - url (Optional[str]): The URL of the media (unique). Generated if not provided. - title (Optional[str]): Title of the media. Defaults to 'Untitled'. - media_type (Optional[str]): Type of media (e.g., 'article', 'video'). Defaults to 'unknown'. - content (Optional[str]): The main text content. Required. - keywords (Optional[List[str]]): List of keyword strings to associate. - prompt (Optional[str]): Optional prompt associated with this version. - analysis_content (Optional[str]): Optional analysis/summary content. - transcription_model (Optional[str]): Model used for transcription, if applicable. - author (Optional[str]): Author of the media. - ingestion_date (Optional[str]): ISO 8601 formatted UTC timestamp for ingestion. - Defaults to current time if None. - overwrite (bool): If True, update the media item if it already exists. - Defaults to False (skip if exists). - chunk_options (Optional[Dict]): Placeholder for chunking parameters. - chunks (Optional[List[Dict[str, Any]]]): A list of dictionaries, where each dictionary - represents a chunk of the media content. Expected keys - per dictionary: 'text' (str, required), and optional - 'start_char' (int), 'end_char' (int), - 'chunk_type' (str), 'metadata' (dict). - - Returns: - Tuple[Optional[int], Optional[str], str]: A tuple containing: - - media_id (Optional[int]): The ID of the added/updated media item. - - media_uuid (Optional[str]): The UUID of the added/updated media item. - - message (str): A status message indicating the action taken - ("added", "updated", "already_exists_skipped"). - - Raises: - InputError: If `content` is None or required chunk data is malformed. - ConflictError: If an update fails due to a version mismatch. 
- DatabaseError: For underlying database issues or errors during sync/FTS logging. - """ + def add_media_with_keywords( + self, + *, + url: Optional[str] = None, + title: Optional[str] = None, + media_type: Optional[str] = None, + content: Optional[str] = None, + keywords: Optional[List[str]] = None, + prompt: Optional[str] = None, + analysis_content: Optional[str] = None, + transcription_model: Optional[str] = None, + author: Optional[str] = None, + ingestion_date: Optional[str] = None, + overwrite: bool = False, + chunk_options: Optional[Dict] = None, + chunks: Optional[List[Dict[str, Any]]] = None, + ) -> Tuple[Optional[int], Optional[str], str]: + """Add or update a media record, handle keyword links, optional chunks and full-text sync.""" + + # --------------------------------------------------------------------- + # 1. Fast‑fail validation & normalisation + # --------------------------------------------------------------------- if content is None: raise InputError("Content cannot be None.") - title = title or 'Untitled' - media_type = media_type or 'unknown' - keywords_list = [k.strip().lower() for k in keywords if k and k.strip()] if keywords else [] - # Get current time and client ID - current_time = self._get_current_utc_timestamp_str() - client_id = self.client_id + title = title or "Untitled" + media_type = media_type or "unknown" + keywords_norm = [k.strip().lower() for k in keywords or [] if k and k.strip()] - # Handle ingestion_date: Use provided, else generate now. Use full timestamp. - ingestion_date_str = ingestion_date or current_time + now = self._get_current_utc_timestamp_str() + ingestion_date = ingestion_date or now + client_id = self.client_id content_hash = hashlib.sha256(content.encode()).hexdigest() - if not url: - url = f"local://{media_type}/{content_hash}" + url = url or f"local://{media_type}/{content_hash}" + + # Determine the final chunk status before any DB operation + final_chunk_status = "completed" if chunks is not None else "pending" + + logging.info("add_media_with_keywords: url=%s, title=%s, client=%s", url, title, client_id) + + # ------------------------------------------------------------------ + # Helper builders + # ------------------------------------------------------------------ + def _media_payload(uuid_: str, version_: int, *, chunk_status: str) -> Dict[str, Any]: + """Return a dict suitable for INSERT/UPDATE parameters and for sync logging.""" + return { + "url": url, + "title": title, + "type": media_type, + "content": content, + "author": author, + "ingestion_date": ingestion_date, + "transcription_model": transcription_model, + "content_hash": content_hash, + "is_trash": 0, + "trash_date": None, + "chunking_status": chunk_status, + "vector_processing": 0, + "uuid": uuid_, + "last_modified": now, + "version": version_, + "client_id": client_id, + "deleted": 0, + } - logging.info(f"Processing add/update for: URL='{url}', Title='{title}', Client='{client_id}'") + def _persist_chunks(cnx: sqlite3.Connection, media_id: int) -> None: + """Delete/insert un-vectorized chunks as requested. 
DOES NOT update parent Media.""" + if chunks is None: + return # caller did not touch chunks - media_id: Optional[int] = None - media_uuid: Optional[str] = None - action = "skipped" # Default action + if overwrite: + cnx.execute("DELETE FROM UnvectorizedMediaChunks WHERE media_id = ?", (media_id,)) + + if not chunks: # empty list means just wipe + return - # Determine initial chunking_status based on presence of 'chunks' argument - # This will be used when preparing insert_data or update_data for the Media table - initial_chunking_status = "pending" - if chunks is not None: # chunks is an empty list or a list with items - initial_chunking_status = "processing" + created = self._get_current_utc_timestamp_str() + for idx, ch in enumerate(chunks): + if not isinstance(ch, dict) or ch.get("text") is None: + logging.warning("Skipping invalid chunk index %s for media_id %s", idx, media_id) + continue + + chunk_uuid = self._generate_uuid() + cnx.execute( + """INSERT INTO UnvectorizedMediaChunks (media_id, chunk_text, chunk_index, start_char, end_char, + chunk_type, creation_date, last_modified_orig, is_processed, + metadata, uuid, last_modified, version, client_id, deleted, + prev_version, merge_parent_uuid) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""", + ( + media_id, ch["text"], idx, ch.get("start_char"), ch.get("end_char"), ch.get("chunk_type"), + created, created, False, + json.dumps(ch.get("metadata")) if isinstance(ch.get("metadata"), dict) else None, + chunk_uuid, created, 1, client_id, 0, None, None, + ), + ) + self._log_sync_event( + cnx, "UnvectorizedMediaChunks", chunk_uuid, "create", 1, + { + **ch, "media_id": media_id, "uuid": chunk_uuid, "chunk_index": idx, + "creation_date": created, "last_modified": created, "version": 1, + "client_id": client_id, "deleted": 0, + }, + ) + # ------------------------------------------------------------------ + # 2. Main transactional block + # ------------------------------------------------------------------ try: with self.transaction() as conn: - cursor = conn.cursor() - cursor.execute('SELECT id, uuid, version FROM Media WHERE (url = ? OR content_hash = ?) AND deleted = 0 LIMIT 1', (url, content_hash)) - existing_media = cursor.fetchone() - media_id, media_uuid, action = None, None, "skipped" + cur = conn.cursor() + + # Find existing record by URL or content_hash + cur.execute( + "SELECT id, uuid, version, url, content_hash FROM Media WHERE url = ? AND deleted = 0 LIMIT 1", + (url,), + ) + row = cur.fetchone() + + if not row: + cur.execute( + "SELECT id, uuid, version, url, content_hash FROM Media WHERE content_hash = ? AND deleted = 0 LIMIT 1", + (content_hash,), + ) + row = cur.fetchone() + + # --- Path A: Record exists, handle UPDATE, CANONICALIZATION, or SKIP --- + if row: + media_id = row["id"] + media_uuid = row["uuid"] + current_ver = row["version"] + existing_url = row["url"] + existing_hash = row["content_hash"] - if existing_media: - media_id, media_uuid, current_version = existing_media['id'], existing_media['uuid'], existing_media['version'] + # Case A.1: Overwrite is requested. 
if overwrite: - action = "updated" - new_version = current_version + 1 - logger.info(f"Updating existing media ID {media_id} (UUID: {media_uuid}) to version {new_version}.") - update_data = { # Prepare dict for easier payload generation - 'url': url, - 'title': title, - 'type': media_type, - 'content': content, - 'author': author, - 'ingestion_date': ingestion_date_str, - 'transcription_model': transcription_model, - 'content_hash': content_hash, - 'is_trash': 0, - 'trash_date': None, # Ensure trash_date is None here - 'chunking_status': "pending", - 'vector_processing': 0, - 'last_modified': current_time, # Set last_modified - 'version': new_version, - 'client_id': client_id, - 'deleted': 0, - 'uuid': media_uuid - } - cursor.execute( - """UPDATE Media SET url=?, title=?, type=?, content=?, author=?, ingestion_date=?, - transcription_model=?, content_hash=?, is_trash=?, trash_date=?, chunking_status=?, - vector_processing=?, last_modified=?, version=?, client_id=?, deleted=? - WHERE id=? AND version=?""", - (update_data['url'], update_data['title'], update_data['type'], update_data['content'], - update_data['author'], update_data['ingestion_date'], update_data['transcription_model'], - update_data['content_hash'], update_data['is_trash'], update_data['trash_date'], # Pass None for trash_date - update_data['chunking_status'], update_data['vector_processing'], - update_data['last_modified'], # Pass current_time - update_data['version'], update_data['client_id'], update_data['deleted'], - media_id, current_version) + # Case A.1.a: Content is identical. No version bump needed for main content. + if content_hash == existing_hash: + logging.info(f"Media content for ID {media_id} is identical. Updating metadata/chunks only.") + + # Update keywords and chunks without changing the main Media record yet. + self.update_keywords_for_media(media_id, keywords_norm) + _persist_chunks(conn, media_id) + + # If new chunks were provided, the media's chunking status has changed, + # which justifies a version bump on the parent Media record. + if chunks is not None: + logging.info(f"Chunks provided for identical media; updating media chunk_status and version for ID {media_id}.") + new_ver = current_ver + 1 + cur.execute( + """UPDATE Media SET chunking_status = 'completed', version = ?, last_modified = ? + WHERE id = ? AND version = ?""", + (new_ver, now, media_id, current_ver) + ) + if cur.rowcount == 0: + raise ConflictError(f"Media (updating chunk status for identical content id={media_id})", media_id) + + self._log_sync_event(conn, "Media", media_uuid, "update", new_ver, {"chunking_status": "completed", "last_modified": now}) + + return media_id, media_uuid, f"Media '{title}' is already up-to-date." + + # Case A.1.b: Content is different. Proceed with a full versioned update. 
+ new_ver = current_ver + 1 + payload = _media_payload(media_uuid, new_ver, chunk_status=final_chunk_status) + cur.execute( + """UPDATE Media + SET url=:url, title=:title, type=:type, content=:content, author=:author, + ingestion_date=:ingestion_date, transcription_model=:transcription_model, + content_hash=:content_hash, is_trash=:is_trash, trash_date=:trash_date, + chunking_status=:chunking_status, vector_processing=:vector_processing, + last_modified=:last_modified, version=:version, client_id=:client_id, deleted=:deleted + WHERE id = :id AND version = :ver""", + {**payload, "id": media_id, "ver": current_ver}, ) - if cursor.rowcount == 0: - raise ConflictError("Media", media_id) + if cur.rowcount == 0: + raise ConflictError(f"Media (full update id={media_id})", media_id) - # Use the update_data dict directly for the payload - self._log_sync_event(conn, 'Media', media_uuid, 'update', new_version, update_data) - self._update_fts_media(conn, media_id, update_data['title'], update_data['content']) - - # Consolidate keyword and version creation here for "updated" - self.update_keywords_for_media(media_id, keywords_list) # Manages its own logs - # Create a new document version representing this update + self._log_sync_event(conn, "Media", media_uuid, "update", new_ver, payload) + self._update_fts_media(conn, media_id, payload["title"], payload["content"]) + self.update_keywords_for_media(media_id, keywords_norm) self.create_document_version( - media_id=media_id, - content=content, # Use the new content for the version - prompt=prompt, - analysis_content=analysis_content + media_id=media_id, content=content, prompt=prompt, analysis_content=analysis_content ) + _persist_chunks(conn, media_id) + return media_id, media_uuid, f"Media '{title}' updated to new version." + + # Case A.2: Overwrite is FALSE. else: - action = "already_exists_skipped" - else: # Not existing_media - action = "added" + is_canonicalisation = ( + existing_url.startswith("local://") + and not url.startswith("local://") + and content_hash == existing_hash + ) + if is_canonicalisation: + logging.info(f"Canonicalizing URL for media_id {media_id} to {url}") + new_ver = current_ver + 1 + cur.execute( + "UPDATE Media SET url = ?, last_modified = ?, version = ?, client_id = ? WHERE id = ? AND version = ?", + (url, now, new_ver, client_id, media_id, current_ver), + ) + if cur.rowcount == 0: + raise ConflictError(f"Media (canonicalization id={media_id})", media_id) + + self._log_sync_event( + conn, "Media", media_uuid, "update", new_ver, {"url": url, "last_modified": now} + ) + return media_id, media_uuid, f"Media '{title}' URL canonicalized." + + return None, None, f"Media '{title}' already exists. Overwrite not enabled." 
+ + # --- Path B: Record does not exist, perform INSERT --- + else: media_uuid = self._generate_uuid() - new_version = 1 - logger.info(f"Inserting new media '{title}' with UUID {media_uuid}.") - insert_data = { # Prepare dict for easier payload generation - 'url': url, - 'title': title, - 'type': media_type, - 'content': content, - 'author': author, - 'ingestion_date': ingestion_date_str, # Use generated/passed ingestion_date - 'transcription_model': transcription_model, - 'content_hash': content_hash, - 'is_trash': 0, - 'trash_date': None, # trash_date is NULL on creation - 'chunking_status': initial_chunking_status, - 'vector_processing': 0, - 'uuid': media_uuid, - 'last_modified': current_time, # Set last_modified - 'version': new_version, - 'client_id': client_id, - 'deleted': 0 - } - cursor.execute( - """INSERT INTO Media (url, title, type, content, author, ingestion_date, transcription_model, - content_hash, is_trash, trash_date, chunking_status, vector_processing, uuid, - last_modified, version, client_id, deleted) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""", - (insert_data['url'], insert_data['title'], insert_data['type'], insert_data['content'], - insert_data['author'], insert_data['ingestion_date'], # Pass ingestion_date_str - insert_data['transcription_model'], insert_data['content_hash'], insert_data['is_trash'], - insert_data['trash_date'], # Pass None for trash_date - insert_data['chunking_status'], insert_data['vector_processing'], insert_data['uuid'], - insert_data['last_modified'], # Pass current_time - insert_data['version'], insert_data['client_id'], insert_data['deleted']) + payload = _media_payload(media_uuid, 1, chunk_status=final_chunk_status) + + cur.execute( + """INSERT INTO Media (url, title, type, content, author, ingestion_date, + transcription_model, content_hash, is_trash, trash_date, + chunking_status, vector_processing, uuid, last_modified, + version, client_id, deleted) + VALUES (:url, :title, :type, :content, :author, :ingestion_date, + :transcription_model, :content_hash, :is_trash, :trash_date, + :chunking_status, :vector_processing, :uuid, :last_modified, + :version, :client_id, :deleted)""", + payload, ) - media_id = cursor.lastrowid + media_id = cur.lastrowid if not media_id: - raise DatabaseError("Failed to get last row ID for new media.") - - # Use the insert_data dict directly for the payload - self._log_sync_event(conn, 'Media', media_uuid, 'create', new_version, insert_data) - self._update_fts_media(conn, media_id, insert_data['title'], insert_data['content']) + raise DatabaseError("Failed to obtain new media ID.") - # Consolidate keyword and version creation here for "added" - self.update_keywords_for_media(media_id, keywords_list) + self._log_sync_event(conn, "Media", media_uuid, "create", 1, payload) + self._update_fts_media(conn, media_id, payload["title"], payload["content"]) + self.update_keywords_for_media(media_id, keywords_norm) self.create_document_version( - media_id=media_id, - content=content, - prompt=prompt, - analysis_content=analysis_content + media_id=media_id, content=content, prompt=prompt, analysis_content=analysis_content ) - - # --- Handle Unvectorized Chunks --- - if chunks is not None: # chunks argument was provided (could be empty or list of dicts) - if action == "updated": - # If overwriting and new chunks are provided, clear old ones. - # If `chunks` is an empty list, it also means clear old ones. 
- if overwrite: # Only delete if overwrite is true - logging.debug( - f"Hard deleting existing UnvectorizedMediaChunks for updated media_id {media_id} due to overwrite and new chunks being provided.") - conn.execute("DELETE FROM UnvectorizedMediaChunks WHERE media_id = ?", (media_id,)) - - num_chunks_saved = 0 - if chunks: # If chunks list is not empty - chunk_creation_time = self._get_current_utc_timestamp_str() - for i, chunk_data in enumerate(chunks): - if not isinstance(chunk_data, dict) or 'text' not in chunk_data or chunk_data['text'] is None: - logging.warning( - f"Skipping invalid chunk data at index {i} for media_id {media_id} " - f"(missing 'text' or not a dict): {str(chunk_data)[:100]}" - ) - continue - - chunk_item_uuid = self._generate_uuid() - metadata_payload = chunk_data.get('metadata') - - chunk_record_tuple = ( - media_id, - chunk_data['text'], - i, # chunk_index - chunk_data.get('start_char'), - chunk_data.get('end_char'), - chunk_data.get('chunk_type'), - chunk_creation_time, # creation_date - chunk_creation_time, # last_modified_orig - False, # is_processed - json.dumps(metadata_payload) if isinstance(metadata_payload, dict) else None, # metadata - chunk_item_uuid, - chunk_creation_time, # last_modified - 1, # version - self.client_id, - 0, # deleted - None, # prev_version - None # merge_parent_uuid - ) - try: - conn.execute(""" - INSERT INTO UnvectorizedMediaChunks ( - media_id, chunk_text, chunk_index, start_char, end_char, chunk_type, - creation_date, last_modified_orig, is_processed, metadata, uuid, - last_modified, version, client_id, deleted, prev_version, merge_parent_uuid - ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) - """, chunk_record_tuple) - - # Construct payload for sync log (SQLite row is not directly usable as dict here) - sync_payload_chunk = { - 'media_id': media_id, 'chunk_text': chunk_data['text'], 'chunk_index': i, - 'start_char': chunk_data.get('start_char'), 'end_char': chunk_data.get('end_char'), - 'chunk_type': chunk_data.get('chunk_type'), 'creation_date': chunk_creation_time, - 'last_modified_orig': chunk_creation_time, 'is_processed': False, - 'metadata': metadata_payload, # Store original dict, not JSON string, in sync log if possible - 'uuid': chunk_item_uuid, 'last_modified': chunk_creation_time, - 'version': 1, 'client_id': self.client_id, 'deleted': 0 - } - self._log_sync_event(conn, 'UnvectorizedMediaChunks', chunk_item_uuid, 'create', 1, sync_payload_chunk) - num_chunks_saved += 1 - except sqlite3.IntegrityError as e: - logging.error(f"Integrity error saving chunk {i} (type: {chunk_data.get('chunk_type')}) for media_id {media_id}: {e}. Data: {chunk_data['text'][:50]}...") - raise DatabaseError(f"Failed to save chunk {i} due to integrity constraint: {e}") from e - logging.info(f"Saved {num_chunks_saved} unvectorized chunks for media_id {media_id}.") - - # Update Media chunking_status - # If chunks were provided (even an empty list, meaning "clear existing and add these (none)"), - # then chunking is considered 'completed' from the perspective of this operation. - # If `chunks` was None (meaning "don't touch existing chunks"), status remains as is or 'pending'. - final_chunking_status_for_media = 'completed' # if chunks is not None - # If the main `perform_chunking` flag (from request, not DB field) was false, - # then perhaps status should be different. For now, if chunks data is passed, it's 'completed'. - # This might need more nuanced logic based on the `perform_chunking` flag from the original request. 
- conn.execute("UPDATE Media SET chunking_status = ? WHERE id = ?", - (final_chunking_status_for_media, media_id,)) - logging.debug( - f"Updated Media chunking_status to '{final_chunking_status_for_media}' for media_id {media_id} after chunk processing.") - - # Original chunk_options placeholder log + _persist_chunks(conn, media_id) if chunk_options: - logging.info(f"Chunking logic placeholder (using chunk_options) for media {media_id}") + logging.info("chunk_options ignored (placeholder): %s", chunk_options) + + return media_id, media_uuid, f"Media '{title}' added." + + except (InputError, ConflictError, sqlite3.IntegrityError) as e: + # Catch the specific IntegrityError from the trigger and re-raise as a more descriptive error if you want + logging.error(f"Transaction failed, rolling back: {type(e).__name__} - {e}") + raise # Re-raise the original exception + except Exception as exc: + logging.error(f"Unexpected error in transaction: {type(exc).__name__} - {exc}") + raise DatabaseError(f"Unexpected error processing media: {exc}") from exc - # Determine message based on action (outside transaction) - if action == "updated": - message = f"Media '{title}' updated." - elif action == "added": - message = f"Media '{title}' added." - else: - message = f"Media '{title}' exists, not overwritten." - return media_id, media_uuid, message - except (InputError, ConflictError, DatabaseError, sqlite3.Error) as e: - logger.error(f"Error processing media (URL: {url}): {e}", exc_info=isinstance(e, (DatabaseError, sqlite3.Error))) - if isinstance(e, (InputError, ConflictError, DatabaseError)): - raise e - else: - raise DatabaseError(f"Failed to process media: {e}") from e - except Exception as e: - logger.error(f"Unexpected error processing media (URL: {url}): {e}", exc_info=True) - raise DatabaseError(f"Unexpected error processing media: {e}") from e def create_document_version(self, media_id: int, content: str, prompt: Optional[str] = None, analysis_content: Optional[str] = None) -> Dict[str, Any]: """ @@ -3855,7 +3792,7 @@ def check_database_integrity(db_path): # Standalone check is fine db_path (str): The path to the SQLite database file. Returns: - bool: True if the integrity check returns 'ok', False otherwise or if + bool: True if the integrity check returns 'ok', False otherwise, or if an error occurs during the check. """ logger.info(f"Checking integrity of database: {db_path}") diff --git a/tldw_chatbook/DB/Prompts_DB.py b/tldw_chatbook/DB/Prompts_DB.py index d46cf9ab..a91a15e8 100644 --- a/tldw_chatbook/DB/Prompts_DB.py +++ b/tldw_chatbook/DB/Prompts_DB.py @@ -1121,7 +1121,8 @@ def soft_delete_keyword(self, keyword_text: str) -> bool: def get_prompt_by_id(self, prompt_id: int, include_deleted: bool = False) -> Optional[Dict]: query = "SELECT * FROM Prompts WHERE id = ?" params = [prompt_id] - if not include_deleted: query += " AND deleted = 0" + if not include_deleted: + query += " AND deleted = 0" try: cursor = self.execute_query(query, tuple(params)) result = cursor.fetchone() @@ -1133,7 +1134,8 @@ def get_prompt_by_id(self, prompt_id: int, include_deleted: bool = False) -> Opt def get_prompt_by_uuid(self, prompt_uuid: str, include_deleted: bool = False) -> Optional[Dict]: query = "SELECT * FROM Prompts WHERE uuid = ?" 
params = [prompt_uuid] - if not include_deleted: query += " AND deleted = 0" + if not include_deleted: + query += " AND deleted = 0" try: cursor = self.execute_query(query, tuple(params)) result = cursor.fetchone() @@ -1145,7 +1147,8 @@ def get_prompt_by_uuid(self, prompt_uuid: str, include_deleted: bool = False) -> def get_prompt_by_name(self, name: str, include_deleted: bool = False) -> Optional[Dict]: query = "SELECT * FROM Prompts WHERE name = ?" params = [name] - if not include_deleted: query += " AND deleted = 0" + if not include_deleted: + query += " AND deleted = 0" try: cursor = self.execute_query(query, tuple(params)) result = cursor.fetchone() @@ -1231,7 +1234,7 @@ def fetch_keywords_for_prompt(self, prompt_id: int, include_deleted: bool = Fals def search_prompts(self, search_query: Optional[str], - search_fields: Optional[List[str]] = None, # e.g. ['name', 'details', 'keywords'] + search_fields: Optional[List[str]] = None, # e.g. ['name', 'details', 'keywords'] page: int = 1, results_per_page: int = 20, include_deleted: bool = False @@ -1240,96 +1243,86 @@ def search_prompts(self, if results_per_page < 1: raise ValueError("Results per page must be >= 1") if search_query and not search_fields: - search_fields = ["name", "details", "system_prompt", "user_prompt", "author"] # Default FTS fields + search_fields = ["name", "details", "system_prompt", "user_prompt", "author"] elif not search_fields: search_fields = [] offset = (page - 1) * results_per_page - base_select_parts = ["p.id", "p.uuid", "p.name", "p.author", "p.details", - "p.system_prompt", "p.user_prompt", "p.last_modified", "p.version", "p.deleted"] - count_select = "COUNT(DISTINCT p.id)" - base_from = "FROM Prompts p" - joins = [] + base_select = "SELECT p.*" + count_select = "SELECT COUNT(p.id)" + from_clause = "FROM Prompts p" conditions = [] params = [] if not include_deleted: conditions.append("p.deleted = 0") - fts_search_active = False - if search_query: - fts_query_parts = [] - if "name" in search_fields: fts_query_parts.append("name") - if "author" in search_fields: fts_query_parts.append("author") - if "details" in search_fields: fts_query_parts.append("details") - if "system_prompt" in search_fields: fts_query_parts.append("system_prompt") - if "user_prompt" in search_fields: fts_query_parts.append("user_prompt") - - # FTS on prompt fields - if fts_query_parts: - fts_search_active = True - if not any("prompts_fts fts_p" in j_item for j_item in joins): - joins.append("JOIN prompts_fts fts_p ON fts_p.rowid = p.id") - # Build FTS query: field1:query OR field2:query ... - # For simple matching, just use the query directly if FTS table covers all these. - # The FTS table definition needs to match these fields. 
- # Assuming prompts_fts has 'name', 'author', 'details', 'system_prompt', 'user_prompt' - conditions.append("fts_p.prompts_fts MATCH ?") - params.append(search_query) # User provides FTS syntax or simple terms - - # FTS on keywords (if specified in search_fields) + # --- Robust FTS search using subqueries --- + if search_query and search_fields: + matching_prompt_ids = set() + text_search_fields = {"name", "author", "details", "system_prompt", "user_prompt"} + + # Search in prompt text fields + if any(field in text_search_fields for field in search_fields): + try: + cursor = self.execute_query("SELECT rowid FROM prompts_fts WHERE prompts_fts MATCH ?", (search_query,)) + matching_prompt_ids.update(row['rowid'] for row in cursor.fetchall()) + except sqlite3.Error as e: + logging.error(f"FTS search on prompts failed: {e}", exc_info=True) + raise DatabaseError(f"FTS search on prompts failed: {e}") from e + + + # Search in keywords if "keywords" in search_fields: - fts_search_active = True - # Join for keywords - if not any("PromptKeywordLinks pkl" in j_item for j_item in joins): - joins.append("JOIN PromptKeywordLinks pkl ON p.id = pkl.prompt_id") - if not any("PromptKeywordsTable pkw" in j_item for j_item in joins): - joins.append("JOIN PromptKeywordsTable pkw ON pkl.keyword_id = pkw.id AND pkw.deleted = 0") - if not any("prompt_keywords_fts fts_k" in j_item for j_item in joins): - joins.append("JOIN prompt_keywords_fts fts_k ON fts_k.rowid = pkw.id") - - conditions.append("fts_k.prompt_keywords_fts MATCH ?") - params.append(search_query) # Match against keywords - - order_by_clause_str = "ORDER BY p.last_modified DESC, p.id DESC" - if fts_search_active: - # FTS results are naturally sorted by relevance (rank) by SQLite. - # We can select rank if needed for explicit sorting or display. - if "fts_p.rank AS relevance_score" not in " ".join(base_select_parts) and "fts_p" in " ".join(joins) : - base_select_parts.append("fts_p.rank AS relevance_score") # Add if fts_p is used - elif "fts_k.rank AS relevance_score_kw" not in " ".join(base_select_parts) and "fts_k" in " ".join(joins): - base_select_parts.append("fts_k.rank AS relevance_score_kw") # Add if fts_k is used - # A more complex ranking might be needed if both prompt and keyword FTS are active. - # For now, default sort or rely on SQLite's combined FTS rank if multiple MATCH clauses are used. - order_by_clause_str = "ORDER BY p.last_modified DESC, p.id DESC" # Fallback, FTS rank is implicit - - final_select_stmt = f"SELECT DISTINCT {', '.join(base_select_parts)}" - join_clause = " ".join(list(dict.fromkeys(joins))) # Unique joins - where_clause = "WHERE " + " AND ".join(conditions) if conditions else "" + try: + # 1. Find keyword IDs matching the query + kw_cursor = self.execute_query("SELECT rowid FROM prompt_keywords_fts WHERE prompt_keywords_fts MATCH ?", (search_query,)) + matching_keyword_ids = {row['rowid'] for row in kw_cursor.fetchall()} + + # 2. Find prompt IDs linked to those keywords + if matching_keyword_ids: + placeholders = ','.join('?' 
* len(matching_keyword_ids)) + link_cursor = self.execute_query( + f"SELECT DISTINCT prompt_id FROM PromptKeywordLinks WHERE keyword_id IN ({placeholders})", + tuple(matching_keyword_ids) + ) + matching_prompt_ids.update(row['prompt_id'] for row in link_cursor.fetchall()) + except sqlite3.Error as e: + logging.error(f"FTS search on keywords failed: {e}", exc_info=True) + raise DatabaseError(f"FTS search on keywords failed: {e}") from e + + if not matching_prompt_ids: + return [], 0 # No matches found, short-circuit + + # Add the final ID list to the main query conditions + id_placeholders = ','.join('?' * len(matching_prompt_ids)) + conditions.append(f"p.id IN ({id_placeholders})") + params.extend(list(matching_prompt_ids)) + + # --- Build and Execute Final Query --- + where_clause = f"WHERE {' AND '.join(conditions)}" if conditions else "" + order_by_clause = "ORDER BY p.last_modified DESC, p.id DESC" try: - count_sql = f"SELECT {count_select} {base_from} {join_clause} {where_clause}" - count_cursor = self.execute_query(count_sql, tuple(params)) - total_matches = count_cursor.fetchone()[0] + # Get total count + count_sql = f"{count_select} {from_clause} {where_clause}" + total_matches = self.execute_query(count_sql, tuple(params)).fetchone()[0] results_list = [] - if total_matches > 0 and offset < total_matches: - results_sql = f"{final_select_stmt} {base_from} {join_clause} {where_clause} {order_by_clause_str} LIMIT ? OFFSET ?" + if total_matches > 0: + # Get paginated results + results_sql = f"{base_select} {from_clause} {where_clause} {order_by_clause} LIMIT ? OFFSET ?" paginated_params = tuple(params + [results_per_page, offset]) results_cursor = self.execute_query(results_sql, paginated_params) results_list = [dict(row) for row in results_cursor.fetchall()] - # If keywords need to be attached to each result + # Attach keywords to each result for res_dict in results_list: res_dict['keywords'] = self.fetch_keywords_for_prompt(res_dict['id'], include_deleted=False) return results_list, total_matches - except sqlite3.Error as e: - if "no such table: prompts_fts" in str(e).lower() or "no such table: prompt_keywords_fts" in str(e).lower(): - logging.error(f"FTS table missing in {self.db_path_str}. Search may fail or be incomplete.") - # Fallback to LIKE search or raise error - # For now, let it fail and be caught by generic error. - logging.error(f"DB error during prompt search in '{self.db_path_str}': {e}", exc_info=True) + except (DatabaseError, sqlite3.Error) as e: + logging.error(f"DB error during prompt search: {e}", exc_info=True) raise DatabaseError(f"Failed to search prompts: {e}") from e # --- Sync Log Access Methods --- @@ -1355,6 +1348,7 @@ def get_sync_log_entries(self, since_change_id: int = 0, limit: Optional[int] = logger.error(f"Error fetching sync_log entries: {e}") raise DatabaseError("Failed to fetch sync_log entries") from e + def delete_sync_log_entries(self, change_ids: List[int]) -> int: if not change_ids: return 0 if not all(isinstance(cid, int) for cid in change_ids): diff --git a/tldw_chatbook/DB/Sync_Client.py b/tldw_chatbook/DB/Sync_Client.py index 3ef19b52..d13f0add 100644 --- a/tldw_chatbook/DB/Sync_Client.py +++ b/tldw_chatbook/DB/Sync_Client.py @@ -13,10 +13,7 @@ # Third-Party Imports # # Local Imports -try: - from tldw_cli.tldw_app.DB.Media_DB import Database, ConflictError, DatabaseError, InputError -except ImportError: - logger.error("ERROR: Could not import the 'Media_DB' library. 
Make sure Media_DB.py is accessible.") +from tldw_chatbook.DB.Client_Media_DB_v2 import MediaDatabase as Database, ConflictError, DatabaseError, InputError # ####################################################################################################################### # diff --git a/tldw_chatbook/Event_Handlers/LLM_Management_Events/llm_management_events_ollama.py b/tldw_chatbook/Event_Handlers/LLM_Management_Events/llm_management_events_ollama.py index c60cf202..4ff3c099 100644 --- a/tldw_chatbook/Event_Handlers/LLM_Management_Events/llm_management_events_ollama.py +++ b/tldw_chatbook/Event_Handlers/LLM_Management_Events/llm_management_events_ollama.py @@ -16,7 +16,7 @@ from pathlib import Path from typing import TYPE_CHECKING, Optional -from cv2 import data +#from cv2 import data from textual.containers import Container from textual.css.query import QueryError from textual.widgets import Input, TextArea, RichLog @@ -33,7 +33,8 @@ __all__ = [ # ─── Ollama ─────────────────────────────────────────────────────────────── "handle_ollama_nav_button_pressed", - "handle_ollama_list_models_button_pressed", + # FIXME + #"handle_ollama_list_models_button_pressed", "handle_ollama_show_model_button_pressed", "handle_ollama_delete_model_button_pressed", "handle_ollama_copy_model_button_pressed", @@ -106,57 +107,58 @@ async def handle_ollama_nav_button_pressed(app: "TldwCli") -> None: app.notify("An unexpected error occurred while switching to Ollama view.", severity="error") -async def handle_ollama_list_models_button_pressed(app: "TldwCli") -> None: - """Handles the 'List Models' button press for Ollama.""" - logger = getattr(app, "loguru_logger", logging.getLogger(__name__)) - logger.debug("Ollama 'List Models' button pressed.") - try: - base_url_input = app.query_one("#ollama-server-url", Input) - log_output_widget = app.query_one("#ollama-combined-output", RichLog) - - base_url = base_url_input.value.strip() - if not base_url: - app.notify("Ollama Server URL is required.", severity="error") - base_url_input.focus() - return - - log_output_widget.clear() - _update_ollama_combined_output(app, f"Attempting to list models from: {base_url}...") - - app.run_worker( - _worker_ollama_list_models, - base_url, - thread=True, - name=f"ollama_list_models_{time.monotonic()}", - group="ollama_api", - description="Listing Ollama local models", - on_success=partial(_on_list_models_success, app), - on_error=partial(_on_ollama_worker_error, app, "list_models") - ) - if logging.error: - log_output_widget.write(f"Error listing models: {logging.error}") - - if logging.error: # This is the original error check, the one above is newly added by script - log_output_widget.write(f"Error listing models: {logging.error}") - app.notify("Error listing Ollama models.", severity="error") - elif data and data.get('models'): - try: - # Assuming 'data' is the JSON response, and 'models' is a list within it. 
- formatted_models = json.dumps(data['models'], indent=2) - log_output_widget.write(formatted_models) - app.notify(f"Successfully listed {len(data['models'])} Ollama models.") - except (TypeError, KeyError, json.JSONDecodeError) as e: - log_output_widget.write(f"Error processing model list response: {e}\nRaw data: {data}") - app.notify("Error processing model list from Ollama.", severity="error") - else: - log_output_widget.write("No models found or unexpected response.") - app.notify("No Ollama models found or unexpected response.", severity="warning") - except QueryError as e: # pragma: no cover - logger.error(f"QueryError in handle_ollama_list_models_button_pressed: {e}", exc_info=True) - app.notify("Error accessing Ollama UI elements for listing models.", severity="error") - except Exception as e: # pragma: no cover - logger.error(f"Unexpected error in handle_ollama_list_models_button_pressed: {e}", exc_info=True) - app.notify("An unexpected error occurred while listing Ollama models.", severity="error") +# FIXME +# async def handle_ollama_list_models_button_pressed(app: "TldwCli") -> None: +# """Handles the 'List Models' button press for Ollama.""" +# logger = getattr(app, "loguru_logger", logging.getLogger(__name__)) +# logger.debug("Ollama 'List Models' button pressed.") +# try: +# base_url_input = app.query_one("#ollama-server-url", Input) +# log_output_widget = app.query_one("#ollama-combined-output", RichLog) +# +# base_url = base_url_input.value.strip() +# if not base_url: +# app.notify("Ollama Server URL is required.", severity="error") +# base_url_input.focus() +# return +# +# log_output_widget.clear() +# _update_ollama_combined_output(app, f"Attempting to list models from: {base_url}...") +# +# app.run_worker( +# _worker_ollama_list_models, +# base_url, +# thread=True, +# name=f"ollama_list_models_{time.monotonic()}", +# group="ollama_api", +# description="Listing Ollama local models", +# on_success=partial(_on_list_models_success, app), +# on_error=partial(_on_ollama_worker_error, app, "list_models") +# ) +# if logging.error: +# log_output_widget.write(f"Error listing models: {logging.error}") +# +# if logging.error: # This is the original error check, the one above is newly added by script +# log_output_widget.write(f"Error listing models: {logging.error}") +# app.notify("Error listing Ollama models.", severity="error") +# elif data and data.get('models'): +# try: +# # Assuming 'data' is the JSON response, and 'models' is a list within it. 
+# formatted_models = json.dumps(data['models'], indent=2) +# log_output_widget.write(formatted_models) +# app.notify(f"Successfully listed {len(data['models'])} Ollama models.") +# except (TypeError, KeyError, json.JSONDecodeError) as e: +# log_output_widget.write(f"Error processing model list response: {e}\nRaw data: {data}") +# app.notify("Error processing model list from Ollama.", severity="error") +# else: +# log_output_widget.write("No models found or unexpected response.") +# app.notify("No Ollama models found or unexpected response.", severity="warning") +# except QueryError as e: # pragma: no cover +# logger.error(f"QueryError in handle_ollama_list_models_button_pressed: {e}", exc_info=True) +# app.notify("Error accessing Ollama UI elements for listing models.", severity="error") +# except Exception as e: # pragma: no cover +# logger.error(f"Unexpected error in handle_ollama_list_models_button_pressed: {e}", exc_info=True) +# app.notify("An unexpected error occurred while listing Ollama models.", severity="error") async def handle_ollama_show_model_button_pressed(app: "TldwCli") -> None: diff --git a/tldw_chatbook/Event_Handlers/ingest_events.py b/tldw_chatbook/Event_Handlers/ingest_events.py index eca057aa..b05d3375 100644 --- a/tldw_chatbook/Event_Handlers/ingest_events.py +++ b/tldw_chatbook/Event_Handlers/ingest_events.py @@ -3,6 +3,7 @@ # # Imports import json +from os import getenv from pathlib import Path from typing import TYPE_CHECKING, Optional, List, Any, Dict, Callable, Union # @@ -12,6 +13,7 @@ ListView, Collapsible, LoadingIndicator, Button from textual.css.query import QueryError from textual.containers import Container, VerticalScroll +from textual.worker import Worker from ..Constants import ALL_TLDW_API_OPTION_CONTAINERS # @@ -20,6 +22,7 @@ import tldw_chatbook.Event_Handlers.conv_char_events as ccp_handlers from .Chat_Events import chat_events as chat_handlers from tldw_chatbook.Event_Handlers.Chat_Events.chat_events import populate_chat_conversation_character_filter_select +from ..config import get_cli_setting from ..tldw_api import ( TLDWAPIClient, ProcessVideoRequest, ProcessAudioRequest, APIConnectionError, APIRequestError, APIResponseError, AuthenticationError, @@ -1121,6 +1124,12 @@ async def handle_tldw_api_submit_button_pressed(app: 'TldwCli', event: Button.Pr submit_button.disabled = True # app.notify is already called at the start of the function + def _reset_ui(): + """Return the widgets to their idle state after a hard failure.""" + loading_indicator.display = False + submit_button.disabled = False + status_area.load_text("Submission halted.") + # --- Get Auth Token (after basic validations pass) --- auth_token: Optional[str] = None try: @@ -1135,14 +1144,23 @@ async def handle_tldw_api_submit_button_pressed(app: 'TldwCli', event: Button.Pr submit_button.disabled = False status_area.load_text("Custom token required. Submission halted.") return + elif auth_method == "config_token": - auth_token = app.app_config.get("tldw_api", {}).get("auth_token_config") + # 1. Look in the active config, then in the environment. + auth_token = ( + get_cli_setting("tldw_api", "auth_token") # ~/.config/tldw_cli/config.toml + or getenv("TDLW_AUTH_TOKEN") # optional override + ) + + # 2. Abort early if we still have nothing. if not auth_token: - app.notify("Auth Token not found in tldw_api.auth_token_config. 
Please configure or use custom.", severity="error") - # Revert UI loading state - loading_indicator.display = False - submit_button.disabled = False - status_area.load_text("Config token missing. Submission halted.") + msg = ( + "Auth token not found — add it to the [tldw_api] section as " + "`auth_token = \"\"` or export TDLW_AUTH_TOKEN." + ) + logger.error(msg) + app.notify(msg, severity="error") + _reset_ui() return except QueryError as e: logger.error(f"UI component not found for TLDW API auth token for {selected_media_type}: {e}") @@ -1486,12 +1504,18 @@ def on_worker_failure(error: Exception): app.notify(brief_notify_message, severity="error", timeout=8) + # STORE THE CONTEXT + app._last_tldw_api_request_context = { + "request_model": request_model, + "overwrite_db": overwrite_db, + } app.run_worker( process_media_worker, name=f"tldw_api_processing_{selected_media_type}", # Unique worker name per tab group="api_calls", - description=f"Processing {selected_media_type} media via TLDW API" + description=f"Processing {selected_media_type} media via TLDW API", + exit_on_error=False ) @@ -1815,6 +1839,244 @@ def on_import_failure_notes(error: Exception): description="Importing selected note files." ) +async def handle_tldw_api_worker_failure(app: 'TldwCli', event: 'Worker.StateChanged'): + """Handles the failure of a TLDW API worker and updates the UI.""" + worker_name = event.worker.name or "" + media_type = worker_name.replace("tldw_api_processing_", "") + error = event.worker.error + + logger.error(f"TLDW API request worker failed for {media_type}: {error}", exc_info=True) + + try: + loading_indicator = app.query_one(f"#tldw-api-loading-indicator-{media_type}", LoadingIndicator) + submit_button = app.query_one(f"#tldw-api-submit-{media_type}", Button) + status_area = app.query_one(f"#tldw-api-status-area-{media_type}", TextArea) + + loading_indicator.display = False + submit_button.disabled = False + except QueryError as e_ui: + logger.error(f"UI component not found in on_worker_failure for {media_type}: {e_ui}") + return + + error_message_parts = [f"## API Request Failed! ({media_type.title()})\n\n"] + brief_notify_message = f"{media_type.title()} API Request Failed." + + # This logic is copied from your original local on_worker_failure function + if isinstance(error, APIConnectionError): + error_type = "Connection Error" + error_message_parts.append(f"**Type:** {error_type}\n") + error_message_parts.append(f"**Message:** `{str(error)}`\n") + brief_notify_message = f"Connection Error: {str(error)[:100]}" + elif isinstance(error, APIResponseError): + error_type = "API Error" + error_message_parts.append(f"**Type:** API Error\n**Status Code:** {error.status_code}\n**Message:** `{str(error)}`\n") + brief_notify_message = f"API Error {error.status_code}: {str(error)[:100]}" + # ... add other specific error types from your original function if needed ... 
+ else: + error_type = "General Error" + error_message_parts.append(f"**Type:** {type(error).__name__}\n") + error_message_parts.append(f"**Message:** `{str(error)}`\n") + brief_notify_message = f"Processing failed: {str(error)[:100]}" + + status_area.clear() + status_area.load_text("".join(error_message_parts)) + status_area.display = True + app.notify(brief_notify_message, severity="error", timeout=8) + + +# async def handle_tldw_api_worker_success(app: 'TldwCli', event: 'Worker.StateChanged'): +# """Handles the success of a TLDW API worker and ingests the results.""" +# # This function would contain the logic from your original 'on_worker_success' +# # It needs to be made async and re-query UI elements. +# # The logic is complex, so for brevity, I'll show the skeleton. +# # The key is that you have access to `app` and `event` (which has the result). +# +# worker_name = event.worker.name or "" +# media_type = worker_name.replace("tldw_api_processing_", "") +# response_data = event.worker.result +# +# logger.info(f"TLDW API worker for {media_type} succeeded. Processing results.") +# +# try: +# # Reset UI state (disable loading, enable button) +# app.query_one(f"#tldw-api-loading-indicator-{media_type}", LoadingIndicator).display = False +# app.query_one(f"#tldw-api-submit-{media_type}", Button).disabled = False +# status_area = app.query_one(f"#tldw-api-status-area-{media_type}", TextArea) +# status_area.clear() +# +# # The rest of your success logic goes here. For example: +# if not app.media_db: +# # ... handle missing db ... +# return +# +# # ... process response_data, ingest to db, build summary_parts ... +# # (This logic is already well-written in your original `on_worker_success`) +# # You will need to retrieve `overwrite_db` and `request_model` if they +# # are needed for the success logic, potentially by storing them on the app +# # instance temporarily before starting the worker. +# status_area.load_text("## Success!\n\n(Full success logic from original function would go here)") +# status_area.display = True +# app.notify(f"{media_type.title()} processing complete.", severity="information") +# +# except QueryError as e_ui: +# logger.error(f"UI component not found in on_worker_success for {media_type}: {e_ui}") +# except Exception as e: +# logger.error(f"Error handling TLDW API worker success for {media_type}: {e}", exc_info=True) +# app.notify("Error processing successful API response.", severity="error") +async def handle_tldw_api_worker_success(app: 'TldwCli', event: 'Worker.StateChanged'): + """Handles the success of a TLDW API worker and ingests the results.""" + worker_name = event.worker.name or "" + media_type = worker_name.replace("tldw_api_processing_", "") + response_data = event.worker.result + + logger.info(f"TLDW API worker for {media_type} succeeded. Processing results.") + + try: + # Reset UI state (disable loading, enable button) + app.query_one(f"#tldw-api-loading-indicator-{media_type}", LoadingIndicator).display = False + app.query_one(f"#tldw-api-submit-{media_type}", Button).disabled = False + status_area = app.query_one(f"#tldw-api-status-area-{media_type}", TextArea) + status_area.clear() + + except QueryError as e_ui: + logger.error(f"UI component not found in on_worker_success for {media_type}: {e_ui}") + return + except Exception as e: + logger.error(f"Error resetting UI state in worker success handler: {e}", exc_info=True) + return + + # --- Pre-flight Checks and Context Retrieval --- + if not app.media_db: + logger.error("Media_DB_v2 not initialized. 
Cannot ingest API results.") + app.notify("Error: Local media database not available.", severity="error") + status_area.load_text("## Error\n\nLocal media database is not available. Cannot save results.") + return + + # Retrieve the context we saved before starting the worker + request_context = getattr(app, "_last_tldw_api_request_context", {}) + request_model = request_context.get("request_model") + overwrite_db = request_context.get("overwrite_db", False) + + if not request_model: + logger.error("Could not retrieve request_model from app context. Cannot properly ingest results.") + status_area.load_text("## Internal Error\n\nCould not retrieve original request context. Ingestion aborted.") + return + + # --- Data Processing and Ingestion --- + processed_count = 0 + error_count = 0 + successful_ingestions_details = [] + results_to_ingest: List[MediaItemProcessResult] = [] + + # Normalize different response types into a single list of MediaItemProcessResult + if isinstance(response_data, BatchMediaProcessResponse): + results_to_ingest = response_data.results + elif isinstance(response_data, list) and all(isinstance(item, ProcessedMediaWikiPage) for item in response_data): + for mw_page in response_data: + if mw_page.status == "Error": + error_count += 1 + logger.error(f"MediaWiki page '{mw_page.title}' processing error: {mw_page.error_message}") + continue + # Adapt ProcessedMediaWikiPage to the common result structure + results_to_ingest.append(MediaItemProcessResult( + status="Success", + input_ref=mw_page.input_ref or mw_page.title, + processing_source=mw_page.title, + media_type="mediawiki_page", + metadata={"title": mw_page.title, "page_id": mw_page.page_id, "namespace": mw_page.namespace}, + content=mw_page.content, + chunks=[{"text": chunk.get("text", ""), "metadata": chunk.get("metadata", {})} for chunk in mw_page.chunks] if mw_page.chunks else None, + )) + elif isinstance(response_data, BatchProcessXMLResponse): + for xml_item in response_data.results: + if xml_item.status == "Error": + error_count +=1; continue + results_to_ingest.append(MediaItemProcessResult( + status="Success", input_ref=xml_item.input_ref, media_type="xml", + metadata={"title": xml_item.title, "author": xml_item.author, "keywords": xml_item.keywords}, + content=xml_item.content, analysis=xml_item.summary, + )) + else: + logger.error(f"Unexpected TLDW API response data type for {media_type}: {type(response_data)}.") + status_area.load_text(f"## API Request Processed\n\nUnexpected response format. 
Raw response logged.") + app.notify("Error: Received unexpected data format from API.", severity="error") + return + + # --- Ingestion Loop --- + for item_result in results_to_ingest: + if item_result.status == "Success": + try: + # Prepare chunks for database insertion + unvectorized_chunks_to_save = [] + if item_result.chunks: + for chunk_item in item_result.chunks: + if isinstance(chunk_item, dict) and "text" in chunk_item: + unvectorized_chunks_to_save.append({ + "text": chunk_item.get("text"), "metadata": chunk_item.get("metadata", {}) + }) + elif isinstance(chunk_item, str): + unvectorized_chunks_to_save.append({"text": chunk_item, "metadata": {}}) + + # Call the DB function with data from both the API response and original request + media_id, _, msg = app.media_db.add_media_with_keywords( + url=item_result.input_ref, + title=item_result.metadata.get("title", item_result.input_ref), + media_type=item_result.media_type, + content=item_result.content or item_result.transcript, + keywords=item_result.metadata.get("keywords", []) or request_model.keywords, + prompt=request_model.custom_prompt, + analysis_content=item_result.analysis or item_result.summary, + author=item_result.metadata.get("author") or request_model.author, + overwrite=overwrite_db, + chunks=unvectorized_chunks_to_save + ) + + if media_id: + logger.info(f"Successfully ingested '{item_result.input_ref}' into local DB. Media ID: {media_id}. Msg: {msg}") + processed_count += 1 + successful_ingestions_details.append({ + "input_ref": item_result.input_ref, + "title": item_result.metadata.get("title", "N/A"), + "media_type": item_result.media_type, + "db_id": media_id + }) + else: + logger.error(f"Failed to ingest '{item_result.input_ref}' into local DB. Message: {msg}") + error_count += 1 + + except Exception as e_ingest: + logger.error(f"Error ingesting item '{item_result.input_ref}' into local DB: {e_ingest}", exc_info=True) + error_count += 1 + else: + logger.error(f"API processing error for '{item_result.input_ref}': {item_result.error}") + error_count += 1 + + # --- Build and Display Summary --- + summary_parts = [f"## TLDW API Request Successful ({media_type.title()})\n\n"] + if not results_to_ingest and error_count == 0: + summary_parts.append("API request successful, but no items were provided or found for processing.\n") + else: + summary_parts.append(f"- **Successfully Processed & Ingested:** {processed_count}\n") + summary_parts.append(f"- **Errors (API or DB):** {error_count}\n\n") + + if error_count > 0: + summary_parts.append("**Please check the application logs for details on any errors.**\n\n") + + if successful_ingestions_details: + summary_parts.append("### Ingested Items:\n") + for detail in successful_ingestions_details[:10]: # Show max 10 details + title_str = f" (Title: `{detail['title']}`)" if detail['title'] != 'N/A' else "" + summary_parts.append(f"- **Input:** `{detail['input_ref']}`{title_str}\n") + summary_parts.append(f" - **Type:** {detail['media_type']}, **DB ID:** {detail['db_id']}\n") + if len(successful_ingestions_details) > 10: + summary_parts.append(f"\n...and {len(successful_ingestions_details) - 10} more items.") + + status_area.load_text("".join(summary_parts)) + status_area.display = True + status_area.scroll_home(animate=False) + + notify_msg = f"{media_type.title()} Ingestion: {processed_count} done, {error_count} errors." 
+ app.notify(notify_msg, severity="information" if error_count == 0 else "warning", timeout=7) # --- Button Handler Map --- INGEST_BUTTON_HANDLERS = { diff --git a/tldw_chatbook/UI/Ingest_Window.py b/tldw_chatbook/UI/Ingest_Window.py index 1cb0a68a..b17215df 100644 --- a/tldw_chatbook/UI/Ingest_Window.py +++ b/tldw_chatbook/UI/Ingest_Window.py @@ -57,7 +57,7 @@ def compose_tldw_api_form(self, media_type: str) -> ComposeResult: if not analysis_provider_options: analysis_provider_options = [("No Providers Configured", Select.BLANK)] - with VerticalScroll(classes="ingest-form-scrollable"): # TODO: Consider if this scrollable itself needs a unique ID if we have nested ones. For now, assuming not. + with VerticalScroll(classes="ingest-form-scrollable"): # FIXME/TODO: Needs unique Header ID since this is temlplate for whatever media type is selected yield Static("TLDW API Configuration", classes="sidebar-title") yield Label("API Endpoint URL:") yield Input(default_api_url, id=f"tldw-api-endpoint-url-{media_type}", placeholder="http://localhost:8000") diff --git a/tldw_chatbook/app.py b/tldw_chatbook/app.py index 47dd7574..b7f87026 100644 --- a/tldw_chatbook/app.py +++ b/tldw_chatbook/app.py @@ -334,6 +334,9 @@ class TldwCli(App[None]): # Specify return type for run() if needed, None is co current_chat_note_id: Optional[str] = None current_chat_note_version: Optional[int] = None + # Shared state for tldw API requests + _last_tldw_api_request_context: Dict[str, Any] = {} + def __init__(self): super().__init__() self.MediaDatabase = MediaDatabase @@ -2001,6 +2004,14 @@ async def on_button_pressed(self, event: Button.Pressed) -> None: self.loguru_logger.debug(f"Button pressed: ID='{button_id}' on Tab='{self.current_tab}'") + if button_id.startswith("tldw-api-browse-local-files-button-"): + try: + ingest_window = self.query_one(IngestWindow) + await ingest_window.on_button_pressed(event) + return # Event handled, stop further processing + except QueryError: + self.loguru_logger.error("Could not find IngestWindow to delegate browse button press.") + # 1. 
Handle global tab switching first if button_id.startswith("tab-"): await tab_events.handle_tab_button_pressed(self, event) @@ -2070,19 +2081,16 @@ def _update_mlx_log(self, message: str) -> None: async def on_input_changed(self, event: Input.Changed) -> None: input_id = event.input.id current_active_tab = self.current_tab - # --- Chat Sidebar Prompt Search --- - if input_id == "chat-prompt-search-input" and current_active_tab == TAB_CHAT: - await chat_handlers.handle_chat_sidebar_prompt_search_input_changed(self, event.value) # --- Notes Search --- - elif input_id == "notes-search-input" and current_active_tab == TAB_NOTES: + if input_id == "notes-search-input" and current_active_tab == TAB_NOTES: # Changed from elif to if await notes_handlers.handle_notes_search_input_changed(self, event.value) # --- Chat Sidebar Conversation Search --- elif input_id == "chat-conversation-search-bar" and current_active_tab == TAB_CHAT: await chat_handlers.handle_chat_conversation_search_bar_changed(self, event.value) elif input_id == "conv-char-search-input" and current_active_tab == TAB_CCP: - await ccp_handlers.handle_ccp_conversation_search_input_changed(self, event.value) + await ccp_handlers.handle_ccp_conversation_search_input_changed(self, event) elif input_id == "ccp-prompt-search-input" and current_active_tab == TAB_CCP: - await ccp_handlers.handle_ccp_prompt_search_input_changed(self, event.value) + await ccp_handlers.handle_ccp_prompt_search_input_changed(self, event) elif input_id == "chat-prompt-search-input" and current_active_tab == TAB_CHAT: # New condition if self._chat_sidebar_prompt_search_timer: # Use the new timer variable self._chat_sidebar_prompt_search_timer.stop() @@ -2268,6 +2276,15 @@ async def on_worker_state_changed(self, event: Worker.StateChanged) -> None: else: self.loguru_logger.debug(f"Chat-related worker '{worker_name_attr}' in other state: {worker_state}") + ####################################################################### + # --- Handle tldw server API Calls Worker (tldw API Ingestion) --- + ####################################################################### + elif worker_group == "api_calls": + self.loguru_logger.info(f"TLDW API worker '{event.worker.name}' finished with state {event.state}.") + if worker_state == WorkerState.SUCCESS: + await ingest_events.handle_tldw_api_worker_success(self, event) + elif worker_state == WorkerState.ERROR: + await ingest_events.handle_tldw_api_worker_failure(self, event) ####################################################################### # --- Handle Ollama API Worker --- diff --git a/tldw_chatbook/config.py b/tldw_chatbook/config.py index 480e6717..8826e5f1 100644 --- a/tldw_chatbook/config.py +++ b/tldw_chatbook/config.py @@ -867,7 +867,7 @@ def get_api_key(toml_key: str, env_var: str, section: Dict = api_section_legacy) [tldw_api] base_url = "http://127.0.0.1:8000" # Or your actual default remote endpoint # Default auth token can be stored here, or leave empty if user must always provide -# auth_token = "your_secret_token_if_you_have_a_default" +auth_token = "default-secret-key-for-single-user" [logging] # Log file will be placed in the same directory as the chachanotes_db_path below. 
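Note on the `[tldw_api]` hunk above: the config template now ships a non-empty default `auth_token`, and the submit handler earlier in this diff errors out when no token can be found. A minimal sketch of one way the token could be resolved follows; the `_resolve_auth_token` helper and the form-field-first ordering are illustrative assumptions and not part of this diff, while `get_cli_setting` is the existing helper in `config.py` and `TDLW_AUTH_TOKEN` is the variable named in that handler's error message.

```python
import os
from typing import Optional

# Existing helper defined in tldw_chatbook/config.py; import path assumed from the file layout.
from tldw_chatbook.config import get_cli_setting


def _resolve_auth_token(form_value: Optional[str]) -> Optional[str]:
    """Hypothetical resolution order: form field, then config file, then environment."""
    if form_value:
        return form_value.strip()
    config_token = get_cli_setting("tldw_api", "auth_token", default=None)
    if config_token:
        return config_token
    # Environment variable spelled as in the submit handler's error message.
    return os.environ.get("TDLW_AUTH_TOKEN")
```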
@@ -894,6 +894,11 @@ def get_api_key(toml_key: str, env_var: str, section: Dict = api_section_legacy) vLLM = "http://localhost:8000" # Check if your API provider uses this address Custom = "http://localhost:1234/v1" Custom_2 = "http://localhost:5678/v1" +Custom_3 = "http://localhost:5678/v1" +Custom_4 = "http://localhost:5678/v1" +Custom_5 = "http://localhost:5678/v1" +Custom_6 = "http://localhost:5678/v1" + # Add other local URLs if needed [providers] @@ -1397,6 +1402,84 @@ def load_cli_config_and_ensure_existence(force_reload: bool = False) -> Dict[str return _CONFIG_CACHE +def save_setting_to_cli_config(section: str, key: str, value: Any) -> bool: + """ + Saves a specific setting to the user's CLI TOML configuration file. + + This function reads the current config, updates a specific key within a + section (handling nested sections like 'api_settings.openai'), and writes + the entire configuration back to the file. It then forces a reload of the + config cache. + + Args: + section: The name of the TOML section (e.g., "general", "api_settings.openai"). + key: The key within the section to update. + value: The new value for the key. + + Returns: + True if the setting was saved successfully, False otherwise. + """ + global _CONFIG_CACHE, settings + logger.info(f"Attempting to save setting: [{section}].{key} = {repr(value)}") + + # Ensure the parent directory for the config file exists. + try: + DEFAULT_CONFIG_PATH.parent.mkdir(parents=True, exist_ok=True) + except OSError as e: + logger.error(f"Could not create config directory {DEFAULT_CONFIG_PATH.parent}: {e}") + return False + + # Step 1: Read the current configuration from the user's file. + # If the file doesn't exist, we start with an empty dictionary. + config_data: Dict[str, Any] = {} + if DEFAULT_CONFIG_PATH.exists(): + try: + with open(DEFAULT_CONFIG_PATH, "rb") as f: + config_data = tomllib.load(f) + except tomllib.TOMLDecodeError as e: + logger.error(f"Corrupted config file at {DEFAULT_CONFIG_PATH}. Cannot save. Please fix or delete it. Error: {e}") + # Consider creating a backup of the corrupt file for the user. + return False + except Exception as e: + logger.error(f"Unexpected error reading {DEFAULT_CONFIG_PATH}: {e}", exc_info=True) + return False + + # Step 2: Modify the configuration data in memory. + # This handles nested sections by splitting the section string. + keys = section.split('.') + current_level = config_data + + try: + for part in keys: + # Traverse or create the nested dictionary structure. + current_level = current_level.setdefault(part, {}) + # Assign the new value to the key in the target section. + current_level[key] = value + except (TypeError, AttributeError): + # This error occurs if a key in the path (e.g., 'api_settings') is a value, not a table. + logger.error( + f"Configuration structure conflict. Could not set '{key}' in section '{section}' " + f"because a part of the path is not a table/dictionary. Please check your config file." + ) + return False + + # Step 3: Write the updated configuration back to the TOML file. + try: + with open(DEFAULT_CONFIG_PATH, "w", encoding="utf-8") as f: + toml.dump(config_data, f) + logger.success(f"Successfully saved setting to {DEFAULT_CONFIG_PATH}") + + # Step 4: Invalidate and reload global config caches to reflect changes immediately. 
+ load_cli_config_and_ensure_existence(force_reload=True) + settings = load_settings() + logger.info("Global configuration caches reloaded.") + + return True + except (IOError, toml.TomlDecodeError) as e: + logger.error(f"Failed to write updated config to {DEFAULT_CONFIG_PATH}: {e}", exc_info=True) + return False + + # --- CLI Setting Getter --- def get_cli_setting(section: str, key: str, default: Any = None) -> Any: """Helper to get a specific setting from the loaded CLI configuration.""" diff --git a/tldw_chatbook/tldw_api/client.py b/tldw_chatbook/tldw_api/client.py index f966fa97..434ce28a 100644 --- a/tldw_chatbook/tldw_api/client.py +++ b/tldw_chatbook/tldw_api/client.py @@ -28,14 +28,19 @@ class TLDWAPIClient: def __init__(self, base_url: str, token: Optional[str] = None, timeout: float = 300.0): self.base_url = base_url.rstrip('/') self.token = token + self.bearer_token = None self.timeout = timeout self._client: Optional[httpx.AsyncClient] = None async def _get_client(self) -> httpx.AsyncClient: if self._client is None or self._client.is_closed: headers = {} + if self.bearer_token: + # Bearer Auth + headers["Authorization"] = f"Bearer {self.bearer_token}" if self.token: - headers["Authorization"] = f"Bearer {self.token}" + # Token Auth + headers["X-API-KEY"] = self.token self._client = httpx.AsyncClient( base_url=self.base_url, headers=headers, @@ -130,31 +135,31 @@ async def _stream_request( async def process_video(self, request_data: ProcessVideoRequest, file_paths: Optional[List[str]] = None) -> BatchMediaProcessResponse: form_data = model_to_form_data(request_data) httpx_files = prepare_files_for_httpx(file_paths, upload_field_name="files") - response_dict = await self._request("POST", "/api/v1/process-videos", data=form_data, files=httpx_files) + response_dict = await self._request("POST", "/api/v1/media/process-videos", data=form_data, files=httpx_files) return BatchMediaProcessResponse(**response_dict) async def process_audio(self, request_data: ProcessAudioRequest, file_paths: Optional[List[str]] = None) -> BatchMediaProcessResponse: form_data = model_to_form_data(request_data) httpx_files = prepare_files_for_httpx(file_paths, upload_field_name="files") - response_dict = await self._request("POST", "/api/v1/process-audios", data=form_data, files=httpx_files) + response_dict = await self._request("POST", "/api/v1/media/process-audios", data=form_data, files=httpx_files) return BatchMediaProcessResponse(**response_dict) async def process_pdf(self, request_data: ProcessPDFRequest, file_paths: Optional[List[str]] = None) -> BatchMediaProcessResponse: form_data = model_to_form_data(request_data) httpx_files = prepare_files_for_httpx(file_paths, upload_field_name="files") - response_dict = await self._request("POST", "/api/v1/process-pdfs", data=form_data, files=httpx_files) + response_dict = await self._request("POST", "/api/v1/media/process-pdfs", data=form_data, files=httpx_files) return BatchMediaProcessResponse(**response_dict) async def process_ebook(self, request_data: ProcessEbookRequest, file_paths: Optional[List[str]] = None) -> BatchMediaProcessResponse: form_data = model_to_form_data(request_data) httpx_files = prepare_files_for_httpx(file_paths, upload_field_name="files") - response_dict = await self._request("POST", "/api/v1/process-ebooks", data=form_data, files=httpx_files) + response_dict = await self._request("POST", "/api/v1/media/process-ebooks", data=form_data, files=httpx_files) return BatchMediaProcessResponse(**response_dict) async def 
process_document(self, request_data: ProcessDocumentRequest, file_paths: Optional[List[str]] = None) -> BatchMediaProcessResponse: form_data = model_to_form_data(request_data) httpx_files = prepare_files_for_httpx(file_paths, upload_field_name="files") - response_dict = await self._request("POST", "/api/v1/process-documents", data=form_data, files=httpx_files) + response_dict = await self._request("POST", "/api/v1/media/process-documents", data=form_data, files=httpx_files) return BatchMediaProcessResponse(**response_dict) async def process_xml(self, request_data: ProcessXMLRequest, file_path: str) -> BatchProcessXMLResponse: # XML expects single file
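For reference, a minimal usage sketch of the reworked client under these changes: media endpoints now live under `/api/v1/media/...`, a plain `token` is sent as an `X-API-KEY` header, and the new `bearer_token` attribute produces an `Authorization: Bearer` header instead. The token value below is the config template's placeholder default, and the `ProcessVideoRequest` construction is omitted because its fields are defined outside this diff.

```python
import asyncio

# Module path matches tldw_chatbook/tldw_api/client.py from this diff.
from tldw_chatbook.tldw_api.client import TLDWAPIClient


async def main() -> None:
    # A plain `token` is now sent as an X-API-KEY header; assign `client.bearer_token`
    # instead if the server expects an `Authorization: Bearer ...` header.
    client = TLDWAPIClient(
        base_url="http://127.0.0.1:8000",
        token="default-secret-key-for-single-user",  # placeholder; substitute a real key
    )

    # process_video() now posts to /api/v1/media/process-videos. The call is left
    # commented out because building a ProcessVideoRequest is not shown in this diff.
    # response = await client.process_video(request_data, file_paths=["example.mp4"])
    # print(response)


if __name__ == "__main__":
    asyncio.run(main())
```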