diff --git a/.env.template b/.env.template index e65d691..95a024f 100644 --- a/.env.template +++ b/.env.template @@ -57,6 +57,11 @@ AI_FOUNDRY_EMBEDDING_DATA_TYPE=float32 AI_FOUNDRY_EMBEDDING_DISTANCE_FUNCTION=cosine COSMOS_DB_FULL_TEXT_LANGUAGE=en-US +# Embed raw conversation turns on write so they can be vector-searched via +# search(target="turns"). The turns container is always provisioned with a +# vector index, so toggling this never requires recreating the container. +ENABLE_TURN_EMBEDDINGS=false + AI_FOUNDRY_CHAT_DEPLOYMENT_NAME= # Optional. Pin the Azure OpenAI REST API version used by chat and embeddings # clients. Leave blank to use the toolkit default ("2024-12-01-preview"). diff --git a/Docs/concepts.md b/Docs/concepts.md index 958ab19..7ce0c59 100644 --- a/Docs/concepts.md +++ b/Docs/concepts.md @@ -94,6 +94,8 @@ Memories stored in Cosmos DB include embeddings generated by Microsoft AI Foundr Facts work especially well for vector search because each fact is stored as a small, self-contained document. +By default raw conversation turns are *not* embedded — only derived memories (facts, episodic, procedural, summaries) carry vectors. Set `enable_turn_embeddings=True` (env `ENABLE_TURN_EMBEDDINGS`) to also embed turns on write, then call `search_cosmos(target="turns")` to vector-search the raw conversation log. The turns container is always provisioned with a `quantizedFlat` vector index, so this flag only toggles embedding generation and can be turned on or off at any time without recreating the container. 
+ --- ## Processing Pipeline diff --git a/Docs/public_api.md b/Docs/public_api.md index 1a99f3c..1eb427b 100644 --- a/Docs/public_api.md +++ b/Docs/public_api.md @@ -14,7 +14,7 @@ ### Connection -- `__init__(cosmos_endpoint=None, cosmos_credential=None, cosmos_key=None, cosmos_database=None, cosmos_container=None, cosmos_turns_container='memories_turns', cosmos_summaries_container='memories_summaries', cosmos_counter_container=None, cosmos_lease_container=None, cosmos_throughput_mode=None, cosmos_autoscale_max_ru=None, ai_foundry_endpoint=None, ai_foundry_credential=None, ai_foundry_api_key=None, embedding_deployment_name='text-embedding-3-large', embedding_dimensions=None, chat_deployment_name='gpt-4o-mini', use_default_credential=True, processor=None) -> None` — configure local state, model clients, optional Cosmos auto-connect, and optional processing backend. The SDK uses a hard 3-container topology: turns in `memories_turns`, facts/episodic/procedural in `memories`, and summaries in `memories_summaries` (or the names you pass). +- `__init__(cosmos_endpoint=None, cosmos_credential=None, cosmos_key=None, cosmos_database=None, cosmos_container=None, cosmos_turns_container='memories_turns', cosmos_summaries_container='memories_summaries', cosmos_counter_container=None, cosmos_lease_container=None, cosmos_throughput_mode=None, cosmos_autoscale_max_ru=None, ai_foundry_endpoint=None, ai_foundry_credential=None, ai_foundry_api_key=None, embedding_deployment_name='text-embedding-3-large', embedding_dimensions=None, chat_deployment_name='gpt-4o-mini', use_default_credential=True, enable_turn_embeddings=None, processor=None) -> None` — configure local state, model clients, optional Cosmos auto-connect, and optional processing backend. The SDK uses a hard 3-container topology: turns in `memories_turns`, facts/episodic/procedural in `memories`, and summaries in `memories_summaries` (or the names you pass). 
`enable_turn_embeddings` (default `False`, env `ENABLE_TURN_EMBEDDINGS`) embeds raw turns on write so they can be vector-searched via `search_cosmos(target="turns")`; the turns container is always provisioned with a vector index, so toggling this never requires recreating it. - `close() -> None` — close Cosmos/model clients and owned credentials. - `connect_cosmos(endpoint=None, credential=None, key=None, database=None, container=None, turns_container=None, summaries_container=None) -> None` — connect to existing memory, turns, and summaries containers. - `create_memory_store(database=None, container=None, turns_container=None, summaries_container=None, counter_container=None, lease_container=None, endpoint=None, credential=None, key=None, embedding_dimensions=None, embedding_data_type=None, distance_function=None, full_text_language=None, throughput_mode=None, autoscale_max_ru=None) -> None` — create/connect the memory, turns, summaries, counter, and lease containers. @@ -37,7 +37,7 @@ ### Retrieval -- `search_cosmos(search_terms, memory_id=None, user_id=None, role=None, memory_types=None, thread_id=None, hybrid_search=False, top_k=5, tags_all=None, tags_any=None, exclude_tags=None, include_superseded=False, min_salience=None, min_confidence=None, created_after=None, created_before=None) -> list[dict]` — vector or hybrid search memories. +- `search_cosmos(search_terms, memory_id=None, user_id=None, role=None, memory_types=None, thread_id=None, hybrid_search=False, top_k=5, tags_all=None, tags_any=None, exclude_tags=None, include_superseded=False, min_salience=None, min_confidence=None, created_after=None, created_before=None, target='memories') -> list[dict]` — vector or hybrid search memories. Set `target="turns"` to search the raw conversation log instead of facts/episodic/procedural (requires turn embeddings; see `enable_turn_embeddings`). - `get_procedural_prompt(user_id) -> Optional[str]` — read the active procedural prompt. 
- `get_procedural_history(user_id, limit=10) -> list[dict]` — read procedural prompt history. - `get_procedural_memories(user_id, priority=None, category=None, min_salience=None, include_superseded=False) -> list[dict]` — retrieve procedural memory documents. @@ -67,7 +67,7 @@ Local-buffer methods remain synchronous in-memory operations; Cosmos, retrieval, ### Connection -- `__init__(cosmos_endpoint=None, cosmos_credential=None, cosmos_key=None, cosmos_database=None, cosmos_container=None, cosmos_turns_container='memories_turns', cosmos_summaries_container='memories_summaries', cosmos_counter_container=None, cosmos_lease_container=None, cosmos_throughput_mode=None, cosmos_autoscale_max_ru=None, ai_foundry_endpoint=None, ai_foundry_credential=None, ai_foundry_api_key=None, embedding_deployment_name='text-embedding-3-large', embedding_dimensions=None, chat_deployment_name='gpt-4o-mini', use_default_credential=True, processor=None) -> None` — configure async local state, model clients, and optional processing backend. The async SDK uses the same hard 3-container topology as the sync client. +- `__init__(cosmos_endpoint=None, cosmos_credential=None, cosmos_key=None, cosmos_database=None, cosmos_container=None, cosmos_turns_container='memories_turns', cosmos_summaries_container='memories_summaries', cosmos_counter_container=None, cosmos_lease_container=None, cosmos_throughput_mode=None, cosmos_autoscale_max_ru=None, ai_foundry_endpoint=None, ai_foundry_credential=None, ai_foundry_api_key=None, embedding_deployment_name='text-embedding-3-large', embedding_dimensions=None, chat_deployment_name='gpt-4o-mini', use_default_credential=True, enable_turn_embeddings=None, processor=None) -> None` — configure async local state, model clients, and optional processing backend. The async SDK uses the same hard 3-container topology as the sync client. 
`enable_turn_embeddings` (default `False`, env `ENABLE_TURN_EMBEDDINGS`) embeds raw turns on write so they can be vector-searched via `search_cosmos(target="turns")`. - `async close() -> None` — close async/sync resources and owned credentials. - `async connect_cosmos(endpoint=None, credential=None, key=None, database=None, container=None, turns_container=None, summaries_container=None) -> None` — connect to existing memory, turns, and summaries containers. - `async create_memory_store(database=None, container=None, turns_container=None, summaries_container=None, counter_container=None, lease_container=None, endpoint=None, credential=None, key=None, embedding_dimensions=None, embedding_data_type=None, distance_function=None, full_text_language=None, throughput_mode=None, autoscale_max_ru=None) -> None` — create/connect memory, turns, summaries, counter, and lease containers. @@ -90,7 +90,7 @@ Local-buffer methods remain synchronous in-memory operations; Cosmos, retrieval, ### Retrieval -- `async search_cosmos(search_terms, memory_id=None, user_id=None, role=None, memory_types=None, thread_id=None, hybrid_search=False, top_k=5, tags_all=None, tags_any=None, exclude_tags=None, include_superseded=False, min_salience=None, min_confidence=None, created_after=None, created_before=None) -> list[dict]` — vector or hybrid search memories. +- `async search_cosmos(search_terms, memory_id=None, user_id=None, role=None, memory_types=None, thread_id=None, hybrid_search=False, top_k=5, tags_all=None, tags_any=None, exclude_tags=None, include_superseded=False, min_salience=None, min_confidence=None, created_after=None, created_before=None, target='memories') -> list[dict]` — vector or hybrid search memories. Set `target="turns"` to search the raw conversation log instead of facts/episodic/procedural (requires turn embeddings; see `enable_turn_embeddings`). - `async get_procedural_prompt(user_id) -> Optional[str]` — read the active procedural prompt. 
- `async get_procedural_history(user_id, limit=10) -> list[dict]` — read procedural prompt history. - `async get_procedural_memories(user_id, priority=None, category=None, min_salience=None, include_superseded=False) -> list[dict]` — retrieve procedural memory documents. diff --git a/azure/cosmos/agent_memory/_base/base_client.py b/azure/cosmos/agent_memory/_base/base_client.py index 2fef0a5..54b0481 100644 --- a/azure/cosmos/agent_memory/_base/base_client.py +++ b/azure/cosmos/agent_memory/_base/base_client.py @@ -17,6 +17,7 @@ ) from azure.cosmos.agent_memory.exceptions import CosmosNotConnectedError, MemoryNotFoundError, ValidationError from azure.cosmos.agent_memory.logging import configure_logging, get_logger +from azure.cosmos.agent_memory.thresholds import get_enable_turn_embeddings logger = get_logger(__name__) @@ -45,6 +46,7 @@ def _init_base_config( embedding_dimensions: Optional[int], chat_deployment_name: str, use_default_credential: bool, + enable_turn_embeddings: Optional[bool] = None, default_credential_module: str = "azure.identity", ) -> None: """Initialize shared local state, config values, and default credentials.""" @@ -75,6 +77,9 @@ def _init_base_config( self._embedding_deployment_name = embedding_deployment_name self._embedding_dimensions = _resolve_embedding_dimensions(embedding_dimensions) self._chat_deployment_name = chat_deployment_name + self._enable_turn_embeddings = ( + enable_turn_embeddings if enable_turn_embeddings is not None else get_enable_turn_embeddings() + ) self._owns_cosmos_credential = False self._owns_ai_foundry_credential = False diff --git a/azure/cosmos/agent_memory/_container_routing.py b/azure/cosmos/agent_memory/_container_routing.py index fd439b5..f4adb00 100644 --- a/azure/cosmos/agent_memory/_container_routing.py +++ b/azure/cosmos/agent_memory/_container_routing.py @@ -32,6 +32,24 @@ class ContainerKey(str, Enum): USER_SCOPED_MEMORIES_TYPES: frozenset[str] = frozenset({"episodic", "procedural"}) +# Containers 
that expose a vector index and may be targeted by ``search``. +_SEARCH_TARGETS: dict[str, ContainerKey] = { + "memories": ContainerKey.MEMORIES, + "turns": ContainerKey.TURNS, +} + + +def resolve_search_target(target: str) -> ContainerKey: + """Map a public ``search(target=...)`` value to its ``ContainerKey``. + + ``"memories"`` (the default) targets facts/episodic/procedural; ``"turns"`` + targets the raw conversation log (requires turn embeddings to be enabled). + """ + try: + return _SEARCH_TARGETS[target] + except KeyError as exc: + raise ValueError(f"Unknown search target {target!r}; valid targets: {sorted(_SEARCH_TARGETS)}") from exc + def container_key_for_type(memory_type: str) -> ContainerKey: """Return the ``ContainerKey`` that owns documents of ``memory_type``.""" diff --git a/azure/cosmos/agent_memory/_utils.py b/azure/cosmos/agent_memory/_utils.py index 8fc8e8b..f27b53a 100644 --- a/azure/cosmos/agent_memory/_utils.py +++ b/azure/cosmos/agent_memory/_utils.py @@ -145,7 +145,8 @@ def _resolve_embedding_dimensions(val: Optional[int]) -> int: """Resolve embedding dimensions from explicit value or ``AI_FOUNDRY_EMBEDDING_DIMENSIONS`` env var. Defaults to 1536 (the dimension we ship with for ``text-embedding-3-large`` - truncated to 1536, which is the size DiskANN is tuned for in our containers). + truncated to 1536, which is the size our quantizedFlat vector indexes are + tuned for in our containers). Raises :class:`ConfigurationError` if the env var is set but cannot be parsed as a positive integer. @@ -368,8 +369,14 @@ def _container_policies( embedding_data_type: str, distance_function: str, full_text_language: str, + include_salience_composite: bool = True, ) -> tuple[dict, dict, dict]: - """Build the vector, indexing, and full-text policies for container creation.""" + """Build the vector, indexing, and full-text policies for container creation. 
+ + ``include_salience_composite`` adds the ``(salience, created_at, id)`` + composite index required by procedural synthesis on the MEMORIES container. + Turns reuse this builder with it disabled (turns are never synthesized). + """ vector_embedding_policy = { "vectorEmbeddings": [ { @@ -384,25 +391,27 @@ def _container_policies( indexing_policy = { "includedPaths": [{"path": "/*"}], "excludedPaths": [ - {"path": "/embedding/*"}, {"path": "/source_memory_ids/*"}, {"path": "/supersedes_ids/*"}, + {"path": '/"_etag"/?'}, ], - "vectorIndexes": [{"path": "/embedding", "type": "diskANN"}], + "vectorIndexes": [{"path": "/embedding", "type": "quantizedFlat"}], "fullTextIndexes": [{"path": "/content"}], + } + + if include_salience_composite: # Procedural synthesis selects TOP N by (salience DESC, created_at ASC, id ASC). # Cosmos requires a composite index for multi-property ORDER BY; without it the # query returns a non-deterministic 50 of N when many docs share the default # salience (0.5), which makes the source-id short-circuit in synthesize_procedural # thrash and burn LLM calls on every reconcile. 
- "compositeIndexes": [ + indexing_policy["compositeIndexes"] = [ [ {"path": "/salience", "order": "descending"}, {"path": "/created_at", "order": "ascending"}, {"path": "/id", "order": "ascending"}, ] - ], - } + ] full_text_policy = { "defaultLanguage": full_text_language, diff --git a/azure/cosmos/agent_memory/aio/cosmos_memory_client.py b/azure/cosmos/agent_memory/aio/cosmos_memory_client.py index a2c1743..0ecc27a 100644 --- a/azure/cosmos/agent_memory/aio/cosmos_memory_client.py +++ b/azure/cosmos/agent_memory/aio/cosmos_memory_client.py @@ -36,18 +36,6 @@ logger = get_logger(__name__) -_TURNS_INDEXING_POLICY = { - "indexingMode": "consistent", - "automatic": True, - "includedPaths": [{"path": "/*"}], - "excludedPaths": [ - {"path": "/embedding/?"}, - {"path": "/source_memory_ids/*"}, - {"path": "/supersedes_ids/*"}, - {"path": '/"_etag"/?'}, - ], -} - _SUMMARIES_INDEXING_POLICY = { "indexingMode": "consistent", "automatic": True, @@ -96,6 +84,7 @@ def __init__( embedding_dimensions: Optional[int] = None, chat_deployment_name: str = "gpt-4o-mini", use_default_credential: bool = True, + enable_turn_embeddings: Optional[bool] = None, processor: Optional[AsyncMemoryProcessor] = None, transcript_metadata_keys: Optional[Iterable[str]] = None, ) -> None: @@ -118,6 +107,7 @@ def __init__( embedding_dimensions=embedding_dimensions, chat_deployment_name=chat_deployment_name, use_default_credential=use_default_credential, + enable_turn_embeddings=enable_turn_embeddings, default_credential_module="azure.identity.aio", ) self._background_tasks: set[asyncio.Task[Any]] = set() @@ -305,12 +295,18 @@ async def create_memory_store( autoscale_max_ru=self._cosmos_autoscale_max_ru, throughput_properties_cls=ThroughputProperties, ) - vec_policy, idx_policy, ft_policy = _container_policies( + _policy_kwargs = dict( embedding_dimensions=embedding_dimensions or self._embedding_dimensions or 1536, embedding_data_type=_resolve_embedding_data_type(embedding_data_type), 
distance_function=_resolve_distance_function(distance_function), full_text_language=_resolve_full_text_language(full_text_language), ) + vec_policy, idx_policy, ft_policy = _container_policies(**_policy_kwargs) + # Turns always carry the vector index (primed for search) but skip the + # salience composite index, which only procedural synthesis needs. + turns_vec_policy, turns_idx_policy, turns_ft_policy = _container_policies( + **_policy_kwargs, include_salience_composite=False + ) self._memories_container_client = await db.create_container_if_not_exists( **_build_container_kwargs( container_id=self._cosmos_container, @@ -328,7 +324,9 @@ async def create_memory_store( partition_key=partition_key, offer_throughput=offer, default_ttl=DEFAULT_TTL_BY_TYPE["turn"], - indexing_policy=_TURNS_INDEXING_POLICY, + indexing_policy=turns_idx_policy, + vector_embedding_policy=turns_vec_policy, + full_text_policy=turns_ft_policy, ) ) logger.info("Created turns container: %s/%s", self._cosmos_database, self._cosmos_turns_container) @@ -397,7 +395,11 @@ async def validate_topology(self) -> None: ) from exc def _build_store(self) -> AsyncMemoryStore: - return AsyncMemoryStore(containers=self._containers, embeddings_client=self._embeddings_client) + return AsyncMemoryStore( + containers=self._containers, + embeddings_client=self._embeddings_client, + enable_turn_embeddings=self._enable_turn_embeddings, + ) def _build_pipeline(self, store: AsyncMemoryStore) -> AsyncPipelineService: return AsyncPipelineService( @@ -664,6 +666,7 @@ async def search_cosmos( min_confidence: Optional[float] = None, created_after: Optional[str | datetime] = None, created_before: Optional[str | datetime] = None, + target: str = "memories", ) -> list[dict[str, Any]]: return await self._get_store().search( search_terms=search_terms, @@ -682,6 +685,7 @@ async def search_cosmos( min_confidence=min_confidence, created_after=created_after, created_before=created_before, + target=target, ) async def get_thread( 
diff --git a/azure/cosmos/agent_memory/aio/processors/inprocess.py b/azure/cosmos/agent_memory/aio/processors/inprocess.py index 9e10071..4b08851 100644 --- a/azure/cosmos/agent_memory/aio/processors/inprocess.py +++ b/azure/cosmos/agent_memory/aio/processors/inprocess.py @@ -45,13 +45,18 @@ def __init__( from azure.cosmos.agent_memory._container_routing import ContainerKey from azure.cosmos.agent_memory.aio.services.pipeline import AsyncPipelineService from azure.cosmos.agent_memory.aio.store import AsyncMemoryStore + from azure.cosmos.agent_memory.thresholds import get_enable_turn_embeddings containers = { ContainerKey.TURNS: turns_container, ContainerKey.MEMORIES: cosmos_container, ContainerKey.SUMMARIES: summaries_container, } - store = AsyncMemoryStore(containers=containers, embeddings_client=embeddings_client) + store = AsyncMemoryStore( + containers=containers, + embeddings_client=embeddings_client, + enable_turn_embeddings=get_enable_turn_embeddings(), + ) pipeline = AsyncPipelineService(store, chat_client, embeddings_client, containers=containers) self._pipeline = pipeline diff --git a/azure/cosmos/agent_memory/aio/store/memory_store.py b/azure/cosmos/agent_memory/aio/store/memory_store.py index 27104a1..184b9a8 100644 --- a/azure/cosmos/agent_memory/aio/store/memory_store.py +++ b/azure/cosmos/agent_memory/aio/store/memory_store.py @@ -12,6 +12,7 @@ USER_SCOPED_MEMORIES_TYPES, ContainerKey, container_key_for_type, + resolve_search_target, ) from azure.cosmos.agent_memory._query_builder import _QueryBuilder from azure.cosmos.agent_memory._utils import ( @@ -58,12 +59,14 @@ def __init__( *, containers: dict[ContainerKey, Any], embeddings_client: Any = None, + enable_turn_embeddings: bool = False, ) -> None: self._containers = containers self._turns_container = containers[ContainerKey.TURNS] self._memories_container = containers[ContainerKey.MEMORIES] self._summaries_container = containers[ContainerKey.SUMMARIES] self._embeddings_client = embeddings_client + 
self._enable_turn_embeddings = enable_turn_embeddings @property def container(self) -> Any: @@ -199,7 +202,7 @@ async def add( body = record.to_cosmos_dict() if embed is None: - embed = memory_type != "turn" + embed = memory_type != "turn" or self._enable_turn_embeddings if embedding is not None: body["embedding"] = embedding elif embed and content and self._embeddings_client is not None: @@ -242,7 +245,8 @@ async def push(self, local_memory: list[dict[str, Any]], batch_size: int = 25) - to_embed_idx: list[int] = [] to_embed_text: list[str] = [] for i, body in enumerate(bodies): - if body.get("type") != "turn" and body.get("content") and not body.get("embedding"): + embeddable_type = body.get("type") != "turn" or self._enable_turn_embeddings + if embeddable_type and body.get("content") and not body.get("embedding"): to_embed_idx.append(i) to_embed_text.append(body["content"]) if to_embed_text and self._embeddings_client is not None: @@ -811,8 +815,15 @@ async def search( created_before: Optional[str | datetime] = None, *, query: Optional[str] = None, + target: str = "memories", ) -> list[dict[str, Any]]: - """Search memories using vector similarity with optional full-text hybrid ranking.""" + """Search memories using vector similarity with optional full-text hybrid ranking. + + ``target`` selects the container to search: ``"memories"`` (default) for + facts/episodic/procedural, or ``"turns"`` for the raw conversation log + (requires turn embeddings to have been enabled when the turns were written). 
+ """ + container_key = resolve_search_target(target) terms = require_search_terms(search_terms, query) _validate_hybrid_search(hybrid_search, terms) top = top_literal(top_k, name="top_k") @@ -843,13 +854,17 @@ async def search( parameters.append({"name": "@key_terms", "value": terms}) partition_key, _ = query_scope(user_id, thread_id) - if thread_id is not None and (not memory_types or set(memory_types) & USER_SCOPED_MEMORIES_TYPES): + if ( + container_key == ContainerKey.MEMORIES + and thread_id is not None + and (not memory_types or set(memory_types) & USER_SCOPED_MEMORIES_TYPES) + ): partition_key = None logger.debug("AsyncMemoryStore.search query: %s", sql) return await self.query( sql, parameters, - container_key=ContainerKey.MEMORIES, + container_key=container_key, partition_key=partition_key, ) diff --git a/azure/cosmos/agent_memory/cosmos_memory_client.py b/azure/cosmos/agent_memory/cosmos_memory_client.py index fc54a40..dababbe 100644 --- a/azure/cosmos/agent_memory/cosmos_memory_client.py +++ b/azure/cosmos/agent_memory/cosmos_memory_client.py @@ -36,18 +36,6 @@ logger = get_logger(__name__) -_TURNS_INDEXING_POLICY = { - "indexingMode": "consistent", - "automatic": True, - "includedPaths": [{"path": "/*"}], - "excludedPaths": [ - {"path": "/embedding/?"}, - {"path": "/source_memory_ids/*"}, - {"path": "/supersedes_ids/*"}, - {"path": '/"_etag"/?'}, - ], -} - _SUMMARIES_INDEXING_POLICY = { "indexingMode": "consistent", "automatic": True, @@ -91,6 +79,7 @@ def __init__( embedding_dimensions: Optional[int] = None, chat_deployment_name: str = "gpt-4o-mini", use_default_credential: bool = True, + enable_turn_embeddings: Optional[bool] = None, processor: Optional[MemoryProcessor] = None, transcript_metadata_keys: Optional[Iterable[str]] = None, ) -> None: @@ -113,6 +102,7 @@ def __init__( embedding_dimensions=embedding_dimensions, chat_deployment_name=chat_deployment_name, use_default_credential=use_default_credential, + 
enable_turn_embeddings=enable_turn_embeddings, ) self._embeddings_client = EmbeddingsClient( endpoint=self._ai_foundry_endpoint, @@ -276,12 +266,18 @@ def create_memory_store( autoscale_max_ru=self._cosmos_autoscale_max_ru, throughput_properties_cls=ThroughputProperties, ) - vec_policy, idx_policy, ft_policy = _container_policies( + _policy_kwargs = dict( embedding_dimensions=embedding_dimensions or self._embedding_dimensions or 1536, embedding_data_type=_resolve_embedding_data_type(embedding_data_type), distance_function=_resolve_distance_function(distance_function), full_text_language=_resolve_full_text_language(full_text_language), ) + vec_policy, idx_policy, ft_policy = _container_policies(**_policy_kwargs) + # Turns always carry the vector index (primed for search) but skip the + # salience composite index, which only procedural synthesis needs. + turns_vec_policy, turns_idx_policy, turns_ft_policy = _container_policies( + **_policy_kwargs, include_salience_composite=False + ) self._memories_container_client = db.create_container_if_not_exists( **_build_container_kwargs( container_id=self._cosmos_container, @@ -299,7 +295,9 @@ def create_memory_store( partition_key=partition_key, offer_throughput=offer, default_ttl=DEFAULT_TTL_BY_TYPE["turn"], - indexing_policy=_TURNS_INDEXING_POLICY, + indexing_policy=turns_idx_policy, + vector_embedding_policy=turns_vec_policy, + full_text_policy=turns_ft_policy, ) ) logger.info("Created turns container: %s/%s", self._cosmos_database, self._cosmos_turns_container) @@ -368,7 +366,11 @@ def validate_topology(self) -> None: ) from exc def _build_store(self) -> MemoryStore: - return MemoryStore(containers=self._containers, embeddings_client=self._embeddings_client) + return MemoryStore( + containers=self._containers, + embeddings_client=self._embeddings_client, + enable_turn_embeddings=self._enable_turn_embeddings, + ) def _build_pipeline(self, store: MemoryStore) -> PipelineService: return PipelineService( @@ -625,8 +627,13 @@ 
def search_cosmos( min_confidence: Optional[float] = None, created_after: Optional[str | datetime] = None, created_before: Optional[str | datetime] = None, + target: str = "memories", ) -> list[dict[str, Any]]: - """Search memories in Cosmos DB using vector similarity.""" + """Search memories in Cosmos DB using vector similarity. + + Set ``target="turns"`` to vector-search the raw conversation log + instead of facts/episodic/procedural (requires turn embeddings). + """ return self._get_store().search( search_terms=search_terms, memory_id=memory_id, @@ -644,6 +651,7 @@ def search_cosmos( min_confidence=min_confidence, created_after=created_after, created_before=created_before, + target=target, ) def get_thread( diff --git a/azure/cosmos/agent_memory/processors/inprocess.py b/azure/cosmos/agent_memory/processors/inprocess.py index 5e66e4e..4a113c6 100644 --- a/azure/cosmos/agent_memory/processors/inprocess.py +++ b/azure/cosmos/agent_memory/processors/inprocess.py @@ -42,13 +42,18 @@ def __init__( from .._container_routing import ContainerKey from ..services.pipeline import PipelineService from ..store import MemoryStore + from ..thresholds import get_enable_turn_embeddings containers = { ContainerKey.TURNS: turns_container, ContainerKey.MEMORIES: cosmos_container, ContainerKey.SUMMARIES: summaries_container, } - store = MemoryStore(containers=containers, embeddings_client=embeddings_client) + store = MemoryStore( + containers=containers, + embeddings_client=embeddings_client, + enable_turn_embeddings=get_enable_turn_embeddings(), + ) pipeline = PipelineService(store, chat_client, embeddings_client, containers=containers) self._pipeline = pipeline diff --git a/azure/cosmos/agent_memory/store/memory_store.py b/azure/cosmos/agent_memory/store/memory_store.py index b49e4c1..a63221e 100644 --- a/azure/cosmos/agent_memory/store/memory_store.py +++ b/azure/cosmos/agent_memory/store/memory_store.py @@ -10,6 +10,7 @@ USER_SCOPED_MEMORIES_TYPES, ContainerKey, 
container_key_for_type, + resolve_search_target, ) from azure.cosmos.agent_memory._query_builder import _QueryBuilder from azure.cosmos.agent_memory._utils import ( @@ -71,12 +72,14 @@ def __init__( *, containers: dict[ContainerKey, Any], embeddings_client: Any = None, + enable_turn_embeddings: bool = False, ) -> None: self._containers = containers self._turns_container = containers[ContainerKey.TURNS] self._memories_container = containers[ContainerKey.MEMORIES] self._summaries_container = containers[ContainerKey.SUMMARIES] self._embeddings_client = embeddings_client + self._enable_turn_embeddings = enable_turn_embeddings @property def container(self) -> Any: @@ -216,7 +219,7 @@ def add( body = record.to_cosmos_dict() if embed is None: - embed = memory_type != "turn" + embed = memory_type != "turn" or self._enable_turn_embeddings if embedding is not None: body["embedding"] = embedding elif embed and content and self._embeddings_client is not None: @@ -255,7 +258,8 @@ def push(self, local_memory: list[dict[str, Any]], batch_size: int = 25) -> None to_embed_idx: list[int] = [] to_embed_text: list[str] = [] for i, body in enumerate(bodies): - if body.get("type") != "turn" and body.get("content") and not body.get("embedding"): + embeddable_type = body.get("type") != "turn" or self._enable_turn_embeddings + if embeddable_type and body.get("content") and not body.get("embedding"): to_embed_idx.append(i) to_embed_text.append(body["content"]) if to_embed_text and self._embeddings_client is not None: @@ -830,8 +834,15 @@ def search( created_before: Optional[str | datetime] = None, *, query: Optional[str] = None, + target: str = "memories", ) -> list[dict[str, Any]]: - """Search memories using vector similarity with optional full-text hybrid ranking.""" + """Search memories using vector similarity with optional full-text hybrid ranking. 
+ + ``target`` selects the container to search: ``"memories"`` (default) for + facts/episodic/procedural, or ``"turns"`` for the raw conversation log + (requires turn embeddings to have been enabled when the turns were written). + """ + container_key = resolve_search_target(target) terms = require_search_terms(search_terms, query) _validate_hybrid_search(hybrid_search, terms) top = top_literal(top_k, name="top_k") @@ -862,13 +873,17 @@ def search( parameters.append({"name": "@key_terms", "value": terms}) partition_key, cross_partition = query_scope(user_id, thread_id) - if thread_id is not None and (not memory_types or set(memory_types) & USER_SCOPED_MEMORIES_TYPES): + if ( + container_key == ContainerKey.MEMORIES + and thread_id is not None + and (not memory_types or set(memory_types) & USER_SCOPED_MEMORIES_TYPES) + ): partition_key, cross_partition = None, True logger.debug("MemoryStore.search query: %s", sql) return self.query( sql, parameters, - container_key=ContainerKey.MEMORIES, + container_key=container_key, partition_key=partition_key, cross_partition=cross_partition, ) diff --git a/azure/cosmos/agent_memory/thresholds.py b/azure/cosmos/agent_memory/thresholds.py index 104e1e3..79c40c8 100644 --- a/azure/cosmos/agent_memory/thresholds.py +++ b/azure/cosmos/agent_memory/thresholds.py @@ -42,6 +42,12 @@ DEFAULT_PROCEDURAL_SYNTHESIS_AUTO = True +# Whether raw conversation turns are embedded on write so they can be vector +# searched. Default ``False`` preserves today's behavior (turns are stored +# without an ``embedding`` field). The turns container always carries the +# vector index, so this only governs whether vectors are generated/searched. +DEFAULT_ENABLE_TURN_EMBEDDINGS = False + # Owner exclusivity — declares which backend is authoritative for the shared # memories + counter container. When set, the *other* backend skips its # auto-trigger and logs a loud warning. 
Default unset preserves today's @@ -143,6 +149,17 @@ def get_procedural_synthesis_auto() -> bool: return _parse_bool("PROCEDURAL_SYNTHESIS_AUTO", DEFAULT_PROCEDURAL_SYNTHESIS_AUTO) +def get_enable_turn_embeddings() -> bool: + """Whether raw turns are embedded on write and made vector-searchable. + + Set ``ENABLE_TURN_EMBEDDINGS=true`` to generate embeddings for ``turn`` + documents (so ``search(target="turns")`` returns ranked turns). Default + ``False`` keeps turns un-embedded. The turns container always carries the + vector index, so enabling this never requires recreating the container. + """ + return _parse_bool("ENABLE_TURN_EMBEDDINGS", DEFAULT_ENABLE_TURN_EMBEDDINGS) + + def get_processor_owner() -> Optional[str]: """Return the configured ``MEMORY_PROCESSOR_OWNER`` or ``None``. @@ -188,6 +205,7 @@ def get_processor_owner() -> Optional[str]: "DEFAULT_DEDUP_POOL_SIZE", "DEFAULT_TTL_BY_TYPE", "DEFAULT_PROCEDURAL_SYNTHESIS_AUTO", + "DEFAULT_ENABLE_TURN_EMBEDDINGS", "PROCESSOR_OWNER_INPROCESS", "PROCESSOR_OWNER_DURABLE", "default_ttl_for", @@ -197,5 +215,6 @@ def get_processor_owner() -> Optional[str]: "get_dedup_every_n", "get_dedup_pool_size", "get_procedural_synthesis_auto", + "get_enable_turn_embeddings", "get_processor_owner", ] diff --git a/function_app/local.settings.json.template b/function_app/local.settings.json.template index 47202ed..5e457f5 100644 --- a/function_app/local.settings.json.template +++ b/function_app/local.settings.json.template @@ -28,6 +28,9 @@ "MEMORY_PROCESSOR_OWNER": "durable", "// --- Batch knob ---": "", - "MAX_BATCH_SIZE": "20" + "MAX_BATCH_SIZE": "20", + + "// --- Turn vector search: embed raw turns on write so search(target='turns') works. Turns container is always vector-indexed; default false. 
---": "", + "ENABLE_TURN_EMBEDDINGS": "false" } } diff --git a/function_app/shared/config.py b/function_app/shared/config.py index f15ed35..c898409 100644 --- a/function_app/shared/config.py +++ b/function_app/shared/config.py @@ -62,6 +62,7 @@ from azure.cosmos.agent_memory.thresholds import ( # noqa: E402 DEFAULT_DEDUP_EVERY_N, + DEFAULT_ENABLE_TURN_EMBEDDINGS, DEFAULT_FACT_EXTRACTION_EVERY_N, DEFAULT_PROCEDURAL_SYNTHESIS_AUTO, DEFAULT_THREAD_SUMMARY_EVERY_N, @@ -188,6 +189,18 @@ def get_procedural_synthesis_auto() -> bool: ) +def get_enable_turn_embeddings() -> bool: + """Embed raw turns on write so they can be vector-searched. + + Default ``false`` keeps turns un-embedded. The turns container always + carries the vector index, so enabling this never requires recreating it. + """ + return _parse_bool( + "ENABLE_TURN_EMBEDDINGS", + DEFAULT_ENABLE_TURN_EMBEDDINGS, + ) + + def get_cosmos_endpoint() -> str: """Return the Cosmos data-plane endpoint. diff --git a/function_app/shared/pipeline_factory.py b/function_app/shared/pipeline_factory.py index dc43fd0..1b881aa 100644 --- a/function_app/shared/pipeline_factory.py +++ b/function_app/shared/pipeline_factory.py @@ -73,7 +73,11 @@ def get_pipeline(): ContainerKey.MEMORIES: memories_container, ContainerKey.SUMMARIES: summaries_container, } - store = MemoryStore(containers=containers, embeddings_client=embeddings) + store = MemoryStore( + containers=containers, + embeddings_client=embeddings, + enable_turn_embeddings=config.get_enable_turn_embeddings(), + ) _pipeline = PipelineService( store, chat, diff --git a/infra/main.bicep b/infra/main.bicep index 1f4b7e0..b5b6731 100644 --- a/infra/main.bicep +++ b/infra/main.bicep @@ -63,7 +63,7 @@ param chatDeploymentName string = '' @description('Azure OpenAI REST API version pinned for both chat and embedding clients (SDK + function-app). 
Newer preview versions are required for strict JSON-schema response_format on gpt-5.x models.') param azureOpenAiApiVersion string = '2024-12-01-preview' -@description('Embedding output dimensions. MUST equal the dimensions configured on the Cosmos memories container vectorEmbeddingPolicy. text-embedding-3-large natively returns 3072; we set 1536 here so DiskANN (also 1536 in cosmos.bicep) can match. Change this only if you also change cosmos.bicep.') +@description('Embedding output dimensions. MUST equal the dimensions configured on the Cosmos memories container vectorEmbeddingPolicy. text-embedding-3-large natively returns 3072; we set 1536 here so the quantizedFlat vector indexes (also 1536 in cosmos.bicep) can match. Change this only if you also change cosmos.bicep.') param embeddingDimensions int = 1536 @description('Run thread-summary orchestration every N turns within a (user_id, thread_id). 0 = disabled.') @@ -88,6 +88,9 @@ param maxBatchSize int = 20 ]) param memoryProcessorOwner string = 'durable' +@description('Embed raw conversation turns on write so they can be vector-searched (search target="turns"). The turns container is always provisioned with a vector index, so this only toggles whether embeddings are generated. 
Default false.') +param enableTurnEmbeddings bool = false + // --- Naming --------------------------------------------------------------- var abbrs = loadJsonContent('./abbreviations.json') @@ -195,6 +198,7 @@ module functions 'modules/functions.bicep' = if (deployFunctionApp) { userSummaryEveryN: userSummaryEveryN maxBatchSize: maxBatchSize memoryProcessorOwner: memoryProcessorOwner + enableTurnEmbeddings: enableTurnEmbeddings tags: commonTags } } diff --git a/infra/main.parameters.json b/infra/main.parameters.json index 4207199..8e2ead3 100644 --- a/infra/main.parameters.json +++ b/infra/main.parameters.json @@ -64,6 +64,9 @@ }, "memoryProcessorOwner": { "value": "${MEMORY_PROCESSOR_OWNER=durable}" + }, + "enableTurnEmbeddings": { + "value": "${ENABLE_TURN_EMBEDDINGS=false}" } } } diff --git a/infra/modules/cosmos.bicep b/infra/modules/cosmos.bicep index e9e6043..7f5c07e 100644 --- a/infra/modules/cosmos.bicep +++ b/infra/modules/cosmos.bicep @@ -105,9 +105,6 @@ resource memoriesContainer 'Microsoft.DocumentDB/databaseAccounts/sqlDatabases/c } ] excludedPaths: [ - { - path: '/embedding/?' - } { path: '/source_memory_ids/*' } @@ -121,7 +118,7 @@ resource memoriesContainer 'Microsoft.DocumentDB/databaseAccounts/sqlDatabases/c vectorIndexes: [ { path: '/embedding' - type: 'diskANN' + type: 'quantizedFlat' } ] fullTextIndexes: [ @@ -193,9 +190,6 @@ resource memoriesTurnsContainer 'Microsoft.DocumentDB/databaseAccounts/sqlDataba } ] excludedPaths: [ - { - path: '/embedding/?' - } { path: '/source_memory_ids/*' } @@ -206,6 +200,39 @@ resource memoriesTurnsContainer 'Microsoft.DocumentDB/databaseAccounts/sqlDataba path: '/"_etag"/?' } ] + // Turns always carry the vector index so they are primed for search + // even when ENABLE_TURN_EMBEDDINGS is off. Vector policy is immutable + // after container creation, so priming here avoids a recreate later. 
+ vectorIndexes: [ + { + path: '/embedding' + type: 'quantizedFlat' + } + ] + fullTextIndexes: [ + { + path: '/content' + } + ] + } + vectorEmbeddingPolicy: { + vectorEmbeddings: [ + { + path: '/embedding' + dataType: 'float32' + distanceFunction: 'cosine' + dimensions: embeddingDimensions + } + ] + } + fullTextPolicy: { + defaultLanguage: 'en-US' + fullTextPaths: [ + { + path: '/content' + language: 'en-US' + } + ] } } } diff --git a/infra/modules/functions.bicep b/infra/modules/functions.bicep index aad4f63..8725c19 100644 --- a/infra/modules/functions.bicep +++ b/infra/modules/functions.bicep @@ -92,6 +92,9 @@ param maxBatchSize int ]) param memoryProcessorOwner string +@description('Embed raw conversation turns on write so they can be vector-searched. The turns container is always provisioned with a vector index; this only toggles embedding generation. Default false.') +param enableTurnEmbeddings bool = false + @description('Tags to apply.') param tags object = {} @@ -329,6 +332,10 @@ resource functionApp 'Microsoft.Web/sites@2023-12-01' = { name: 'MEMORY_PROCESSOR_OWNER' value: memoryProcessorOwner } + { + name: 'ENABLE_TURN_EMBEDDINGS' + value: string(enableTurnEmbeddings) + } ] } } diff --git a/tests/unit/aio/store/test_memory_store.py b/tests/unit/aio/store/test_memory_store.py index c8b6ef9..7e3f7f2 100644 --- a/tests/unit/aio/store/test_memory_store.py +++ b/tests/unit/aio/store/test_memory_store.py @@ -466,3 +466,79 @@ async def test_search_fact_only_with_thread_id_uses_partition_path(): call_kwargs = memories.query_items.call_args.kwargs assert call_kwargs.get("partition_key") == ["u1", "t1"] + + +async def test_add_turn_skips_embedding_by_default(): + turns = MagicMock() + turns.upsert_item = AsyncMock() + embeddings = MagicMock() + embeddings.generate = AsyncMock(return_value=[0.1, 0.2]) + store = AsyncMemoryStore(containers=_containers(turns=turns), embeddings_client=embeddings) + + await store.add(user_id="u1", role="user", content="hello", 
thread_id="t1") + + embeddings.generate.assert_not_awaited() + body = turns.upsert_item.call_args.kwargs["body"] + assert "embedding" not in body + + +async def test_add_turn_embeds_when_enabled(): + turns = MagicMock() + turns.upsert_item = AsyncMock() + embeddings = MagicMock() + embeddings.generate = AsyncMock(return_value=[0.1, 0.2]) + store = AsyncMemoryStore( + containers=_containers(turns=turns), + embeddings_client=embeddings, + enable_turn_embeddings=True, + ) + + await store.add(user_id="u1", role="user", content="hello", thread_id="t1") + + embeddings.generate.assert_awaited_once_with("hello") + body = turns.upsert_item.call_args.kwargs["body"] + assert body["embedding"] == [0.1, 0.2] + + +async def test_push_embeds_turns_when_enabled(): + turns = MagicMock() + turns.upsert_item = AsyncMock() + embeddings = MagicMock() + embeddings.generate_batch = AsyncMock(return_value=[[0.1, 0.2]]) + local = [_doc(id="x1", type="turn", content="hello", thread_id="t1")] + store = AsyncMemoryStore( + containers=_containers(turns=turns), + embeddings_client=embeddings, + enable_turn_embeddings=True, + ) + + await store.push(local, batch_size=10) + + embeddings.generate_batch.assert_awaited_once_with(["hello"]) + body = turns.upsert_item.call_args.kwargs["body"] + assert body["embedding"] == [0.1, 0.2] + + +async def test_search_target_turns_queries_turns_container(): + turns = MagicMock() + turns.query_items.return_value = AsyncIterator([]) + memories = MagicMock() + memories.query_items.return_value = AsyncIterator([]) + embeddings = MagicMock() + embeddings.generate = AsyncMock(return_value=[0.1, 0.2]) + store = AsyncMemoryStore( + containers=_containers(turns=turns, memories=memories), + embeddings_client=embeddings, + ) + + await store.search(search_terms="hello", user_id="u1", thread_id="t1", target="turns") + + turns.query_items.assert_called_once() + memories.query_items.assert_not_called() + + +async def test_search_invalid_target_raises(): + store = 
AsyncMemoryStore(containers=_containers()) + + with pytest.raises(ValueError): + await store.search(search_terms="hello", user_id="u1", target="bogus") diff --git a/tests/unit/aio/test_cosmos_memory_client.py b/tests/unit/aio/test_cosmos_memory_client.py index eab2f84..c251c8a 100644 --- a/tests/unit/aio/test_cosmos_memory_client.py +++ b/tests/unit/aio/test_cosmos_memory_client.py @@ -360,8 +360,12 @@ async def test_create_memory_store_turns_container_uses_30_day_ttl(self): turns_call = mock_db.create_container_if_not_exists.await_args_list[1] assert turns_call.kwargs["id"] == "memories_turns" assert turns_call.kwargs["default_ttl"] == 2_592_000 - assert "vector_embedding_policy" not in turns_call.kwargs - assert "full_text_policy" not in turns_call.kwargs + # The turns container is always provisioned with a vector index + full-text + # policy so it is primed for search(target="turns") even when turn + # embeddings are disabled. Vector indexes use quantizedFlat. + assert "vector_embedding_policy" in turns_call.kwargs + assert "full_text_policy" in turns_call.kwargs + assert turns_call.kwargs["indexing_policy"]["vectorIndexes"][0]["type"] == "quantizedFlat" assert mem._turns_container_client is mock_turns_container async def test_create_memory_store_defaults_to_serverless(self): diff --git a/tests/unit/function_app/test_pipeline_factory.py b/tests/unit/function_app/test_pipeline_factory.py index 37b7cce..9ead078 100644 --- a/tests/unit/function_app/test_pipeline_factory.py +++ b/tests/unit/function_app/test_pipeline_factory.py @@ -101,6 +101,7 @@ def test_builds_pipeline_from_complete_env(mocks): mocks.store_ctor.assert_called_once_with( containers=expected_containers, embeddings_client=mocks.embed_instance, + enable_turn_embeddings=False, ) mocks.pipeline_ctor.assert_called_once_with( mocks.store_instance, diff --git a/tests/unit/store/test_memory_store.py b/tests/unit/store/test_memory_store.py index bef3d76..b5cdf3e 100644 --- 
a/tests/unit/store/test_memory_store.py +++ b/tests/unit/store/test_memory_store.py @@ -455,3 +455,90 @@ def test_search_fact_only_with_thread_id_uses_partition_path(): call_kwargs = memories.query_items.call_args.kwargs assert call_kwargs.get("partition_key") == ["u1", "t1"] assert "enable_cross_partition_query" not in call_kwargs + + +def test_add_turn_skips_embedding_by_default(): + turns = MagicMock() + embeddings = MagicMock() + embeddings.generate.return_value = [0.1, 0.2] + store = MemoryStore(containers=_containers(turns=turns), embeddings_client=embeddings) + + store.add(user_id="u1", role="user", content="hello", thread_id="t1") + + embeddings.generate.assert_not_called() + body = turns.upsert_item.call_args.kwargs["body"] + assert "embedding" not in body + + +def test_add_turn_embeds_when_enabled(): + turns = MagicMock() + embeddings = MagicMock() + embeddings.generate.return_value = [0.1, 0.2] + store = MemoryStore( + containers=_containers(turns=turns), + embeddings_client=embeddings, + enable_turn_embeddings=True, + ) + + store.add(user_id="u1", role="user", content="hello", thread_id="t1") + + embeddings.generate.assert_called_once_with("hello") + body = turns.upsert_item.call_args.kwargs["body"] + assert body["embedding"] == [0.1, 0.2] + + +def test_push_skips_turn_embedding_by_default(): + turns = MagicMock() + embeddings = MagicMock() + embeddings.generate_batch.return_value = [[0.1, 0.2]] + local = [_doc(id="x1", type="turn", content="hello", thread_id="t1")] + store = MemoryStore(containers=_containers(turns=turns), embeddings_client=embeddings) + + store.push(local, batch_size=10) + + embeddings.generate_batch.assert_not_called() + body = turns.upsert_item.call_args.kwargs["body"] + assert "embedding" not in body + + +def test_push_embeds_turns_when_enabled(): + turns = MagicMock() + embeddings = MagicMock() + embeddings.generate_batch.return_value = [[0.1, 0.2]] + local = [_doc(id="x1", type="turn", content="hello", thread_id="t1")] + store = 
MemoryStore( + containers=_containers(turns=turns), + embeddings_client=embeddings, + enable_turn_embeddings=True, + ) + + store.push(local, batch_size=10) + + embeddings.generate_batch.assert_called_once_with(["hello"]) + body = turns.upsert_item.call_args.kwargs["body"] + assert body["embedding"] == [0.1, 0.2] + + +def test_search_target_turns_queries_turns_container(): + turns = MagicMock() + turns.query_items.return_value = [] + memories = MagicMock() + memories.query_items.return_value = [] + embeddings = MagicMock() + embeddings.generate.return_value = [0.1, 0.2] + store = MemoryStore( + containers=_containers(turns=turns, memories=memories), + embeddings_client=embeddings, + ) + + store.search(search_terms="hello", user_id="u1", thread_id="t1", target="turns") + + turns.query_items.assert_called_once() + memories.query_items.assert_not_called() + + +def test_search_invalid_target_raises(): + store = MemoryStore(containers=_containers()) + + with pytest.raises(ValueError): + store.search(search_terms="hello", user_id="u1", target="bogus") diff --git a/tests/unit/test_container_routing.py b/tests/unit/test_container_routing.py index 3508b2c..4365d38 100644 --- a/tests/unit/test_container_routing.py +++ b/tests/unit/test_container_routing.py @@ -15,6 +15,7 @@ ContainerKey, container_key_for_type, container_keys_for_types, + resolve_search_target, ) from azure.cosmos.agent_memory._utils import VALID_TYPES @@ -110,3 +111,16 @@ def test_summary_types_both_route_to_summaries(self) -> None: assert container_keys_for_types(["thread_summary", "user_summary"]) == [ ContainerKey.SUMMARIES, ] + + +class TestResolveSearchTarget: + def test_memories_target_routes_to_memories_container(self) -> None: + assert resolve_search_target("memories") is ContainerKey.MEMORIES + + def test_turns_target_routes_to_turns_container(self) -> None: + assert resolve_search_target("turns") is ContainerKey.TURNS + + def test_unknown_target_raises_value_error(self) -> None: + with 
pytest.raises(ValueError) as exc_info: + resolve_search_target("summaries") + assert "summaries" in str(exc_info.value) diff --git a/tests/unit/test_cosmos_memory_client.py b/tests/unit/test_cosmos_memory_client.py index a991517..1c7e130 100644 --- a/tests/unit/test_cosmos_memory_client.py +++ b/tests/unit/test_cosmos_memory_client.py @@ -426,8 +426,12 @@ def test_create_memory_store_turns_container_uses_30_day_ttl(self): turns_call = mock_db.create_container_if_not_exists.call_args_list[1] assert turns_call.kwargs["id"] == "memories_turns" assert turns_call.kwargs["default_ttl"] == 2_592_000 - assert "vector_embedding_policy" not in turns_call.kwargs - assert "full_text_policy" not in turns_call.kwargs + # The turns container is always provisioned with a vector index + full-text + # policy so it is primed for search(target="turns") even when turn + # embeddings are disabled. Vector indexes use quantizedFlat. + assert "vector_embedding_policy" in turns_call.kwargs + assert "full_text_policy" in turns_call.kwargs + assert turns_call.kwargs["indexing_policy"]["vectorIndexes"][0]["type"] == "quantizedFlat" assert mem._turns_container_client is mock_turns_container def test_create_memory_store_defaults_to_serverless(self): diff --git a/tests/unit/test_thresholds.py b/tests/unit/test_thresholds.py index 77671b9..1add114 100644 --- a/tests/unit/test_thresholds.py +++ b/tests/unit/test_thresholds.py @@ -1,6 +1,13 @@ from __future__ import annotations -from azure.cosmos.agent_memory.thresholds import DEFAULT_TTL_BY_TYPE, default_ttl_for +import pytest + +from azure.cosmos.agent_memory.thresholds import ( + DEFAULT_ENABLE_TURN_EMBEDDINGS, + DEFAULT_TTL_BY_TYPE, + default_ttl_for, + get_enable_turn_embeddings, +) def test_default_ttl_table_values() -> None: @@ -25,3 +32,21 @@ def test_default_ttl_for_never_and_unknown_types() -> None: assert default_ttl_for("fact") is None assert default_ttl_for("procedural") is None assert default_ttl_for("unknown") is None + + +def 
test_enable_turn_embeddings_defaults_to_false(monkeypatch) -> None: + monkeypatch.delenv("ENABLE_TURN_EMBEDDINGS", raising=False) + assert DEFAULT_ENABLE_TURN_EMBEDDINGS is False + assert get_enable_turn_embeddings() is False + + +@pytest.mark.parametrize("raw", ["true", "True", "1", "yes", "on"]) +def test_enable_turn_embeddings_truthy_values(monkeypatch, raw) -> None: + monkeypatch.setenv("ENABLE_TURN_EMBEDDINGS", raw) + assert get_enable_turn_embeddings() is True + + +@pytest.mark.parametrize("raw", ["false", "False", "0", "no", "off"]) +def test_enable_turn_embeddings_falsy_values(monkeypatch, raw) -> None: + monkeypatch.setenv("ENABLE_TURN_EMBEDDINGS", raw) + assert get_enable_turn_embeddings() is False