Skip to content

Message

GenerateLLMRequest

Bases: BaseModel

Request body for LLM-only generation (EVE-Instruct v5, no RAG, no conversation).

Source code in routers/message.py
150
151
152
153
class GenerateLLMRequest(BaseModel):
    """Request body for LLM-only generation (EVE-Instruct v5, no RAG, no conversation)."""

    # Free-form user prompt, forwarded verbatim to the model by /generate-llm.
    query: str = Field(..., description="User prompt to send to the LLM")

build_equivalence_sentence(usage_kg) async

Build an equivalence sentence for CO2 usage.

Source code in routers/message.py
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
async def build_equivalence_sentence(usage_kg: float) -> CO2EquivalenceResult:
    """
    Build a human-readable CO2 equivalence sentence for a usage amount.

    Args:
        usage_kg (float): CO2 usage in kilograms to compare against known items.

    Returns:
        CO2EquivalenceResult: Matched comparison item, the equivalent count
        (None when the item's co2eq_kg is zero), and the sentence text.

    Raises:
        RuntimeError: If no comparison items are available in the database.
    """
    doc = await closest_with_direction(usage_kg)
    if not doc:
        raise RuntimeError("No comparison items in DB")
    base = doc.co2eq_kg
    title = doc.title

    # A zero-valued comparison item cannot produce a meaningful ratio:
    # return an empty sentence with no count instead of dividing by zero.
    if base == 0:
        # Plain empty string (was a pointless f-string with no placeholders).
        text = ""
        return CO2EquivalenceResult(
            title=title,
            co2eq_kg=base,
            equivalent_count=None,
            text=text
        )

    # Round to the nearest whole number of items, but never report zero.
    count = max(1, int(round(usage_kg / base)))
    unit = pluralize(count, doc.unit_singular, doc.unit_plural)

    text = f"This is equivalent to: {count} {unit}"

    return CO2EquivalenceResult(
        title=title,
        co2eq_kg=base,
        equivalent_count=count,
        text=text
    )

closest_with_direction(usage_kg) async

Returns the closest comparison at or below the usage value; falls back to the nearest item above it when no positive lower bound exists (i.e. the usage is smaller than the smallest positive co2eq_kg).

Source code in routers/message.py
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
async def closest_with_direction(usage_kg: float) -> Optional[CO2EquivalenceComparison]:
    """
    Returns either lower bound or upper bound if lower is missing (usage smaller than min >0).
    """
    try:
        # Prefer the largest enabled item at or below the usage value.
        lower_bound = await get_lower_bound(usage_kg)
        if lower_bound and lower_bound.co2eq_kg > 0:
            return lower_bound

        # No positive lower bound: fall back to the smallest enabled item
        # strictly above the usage value.
        collection = CO2EQComparison.get_collection()
        candidate = await collection.find_one(
            {"enabled": True, "co2eq_kg": {"$gt": usage_kg}},
            sort=[("co2eq_kg", 1)],
            projection={"_id": 0}
        )
        return CO2EquivalenceComparison(**candidate) if candidate else None
    except Exception as e:
        logger.error(f"Error in closest_with_direction: {str(e)}", exc_info=True)
        return None

create_message(request, conversation_id, background_tasks, requesting_user=Depends(get_current_user)) async

Create a new message in a conversation and generate an answer.

Validates conversation ownership, normalizes requested public collections, persists a placeholder Message, runs generation, updates the message with answer and retrieval metadata, and schedules rollup/trimming of history.

Parameters:

Name Type Description Default
request GenerationRequest

Generation parameters including query, collections, and model settings.

required
conversation_id str

Target conversation identifier.

required
background_tasks BackgroundTasks

Background task runner used to schedule rollups.

required
requesting_user User

Authenticated user injected by dependency.

Depends(get_current_user)

Returns:

Type Description
CreateMessageResponse

Message id, query, answer, documents, flags, and metadata.

Raises:

Type Description
HTTPException

404 if conversation is not found; 403 if ownership/collections invalid; 500 for server errors.

Source code in routers/message.py
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
@router.post(
    "/conversations/{conversation_id}/messages", response_model=CreateMessageResponse
)
async def create_message(
    request: GenerationRequest,
    conversation_id: str,
    background_tasks: BackgroundTasks,
    requesting_user: User = Depends(get_current_user),
) -> CreateMessageResponse:
    """
    Create a new message in a conversation and generate an answer.

    Validates conversation ownership, normalizes requested public collections, persists a placeholder `Message`, runs generation, updates the message with answer and retrieval metadata, and schedules rollup/trimming of history.

    Args:
        request (GenerationRequest): Generation parameters including query, collections, and model settings.
        conversation_id (str): Target conversation identifier.
        background_tasks (BackgroundTasks): Background task runner used to schedule rollups.
        requesting_user (User): Authenticated user injected by dependency.

    Returns:
        Message id, query, answer, documents, flags, and metadata.

    Raises:
        HTTPException: 404 if conversation is not found; 403 if ownership/collections invalid; 500 for server errors.
    """
    # Attach user/conversation identifiers to the logging/telemetry context.
    set_user_context(requesting_user.id)
    set_conversation_context(conversation_id)

    # Created later in the body; the except handlers check it so a partially
    # created message can still record the error in its metadata.
    message = None
    try:
        conversation = await Conversation.find_by_id(conversation_id)
        if not conversation:
            raise HTTPException(status_code=404, detail="Conversation not found")

        if conversation.user_id != requesting_user.id:
            raise HTTPException(
                status_code=403,
                detail="You are not allowed to add a message to this conversation",
            )

        # Normalize and validate requested public collections against allowed lists
        # (the allowed set is environment-dependent: production vs staging).
        allowed_source = PUBLIC_COLLECTIONS if IS_PROD else STAGING_PUBLIC_COLLECTIONS
        try:
            allowed_names = {
                item.get("name")
                for item in (allowed_source + WILEY_PUBLIC_COLLECTIONS)
                if isinstance(item, dict) and item.get("name")
            }
        except Exception:
            # Malformed config: treat as "no public collections allowed".
            allowed_names = set()

        # Silently drop any requested public collection not in the allowed set.
        public_collections = [
            n for n in request.public_collections if n in allowed_names
        ]
        request.public_collections = public_collections

        # lookup query to check if some of the collection ids from other users are in the request.collection_ids
        # NOTE(review): this matches the DB "id" field against entries of
        # request.public_collections, which at this point holds allowed *names*
        # — confirm whether ids and names coincide here.
        other_users_collections = await CollectionModel.find_all(
            filter_dict={
                "id": {"$in": request.public_collections},
                "user_id": {"$ne": requesting_user.id},
            }
        )

        if len(other_users_collections) > 0:
            raise HTTPException(
                status_code=403,
                detail="You are not allowed to use collections from other users",
            )

        request.collection_ids = request.collection_ids + request.public_collections

        # All user collections are used by default
        user_collections = await CollectionModel.find_all(
            filter_dict={"user_id": requesting_user.id}
        )

        # Map private collection ids to their display names for downstream use.
        request.private_collections_map = {c.id: c.name for c in user_collections}
        if len(user_collections) > 0:
            request.collection_ids = request.collection_ids + [
                c.id for c in user_collections
            ]
        # remove "Wiley AI Gateway" from collection_ids
        request.collection_ids = [
            c for c in request.collection_ids if c != "Wiley AI Gateway"
        ]
        logger.info(f"Collection IDs: {request.collection_ids}")

        # Extract year range from filters for MCP usage
        try:
            request.year = extract_year_range_from_filters(request.filters)
        except Exception:
            # Best-effort: a bad filter shape just disables the year filter.
            request.year = None

        # Persist a placeholder message before generation so the error
        # handlers below have somewhere to record failures.
        message = await Message.create(
            conversation_id=conversation_id,
            input=request.query,
            output="",
            documents=[],
            use_rag=False,
            request_input=request,
            metadata={},
        )

        set_message_context(message.id)

        answer, results, is_rag, latencies, prompts, retrieved_docs = (
            await generate_answer(request, conversation_id=conversation_id)
        )

        documents_data = []
        if results:
            documents_data = [extract_document_data(result) for result in results]

        # Update the placeholder with the generation outcome; merge metadata
        # rather than overwrite, in case generation already wrote some.
        message.output = answer
        message.documents = documents_data
        message.use_rag = is_rag
        existing_metadata = dict(getattr(message, "metadata", {}) or {})
        existing_metadata.update(
            {
                "latencies": latencies,
                "prompts": prompts,
                "retrieved_docs": retrieved_docs,
            }
        )
        message.metadata = existing_metadata
        await message.save()

        # Schedule rollup as background task to avoid blocking response
        background_tasks.add_task(maybe_rollup_and_trim_history, conversation_id)

        return {
            "id": message.id,
            "query": request.query,
            "answer": answer,
            "documents": documents_data,
            "use_rag": is_rag,
            "conversation_id": conversation_id,
            "collection_ids": request.collection_ids,
            "metadata": {
                "latencies": latencies,
            },
        }
    except HTTPException as http_exc:
        # Best-effort: record the HTTP error on the message, then re-raise
        # with the original status code intact.
        if message:
            try:
                existing_metadata = dict(getattr(message, "metadata", {}) or {})
                existing_metadata["error"] = str(getattr(http_exc, "detail", http_exc))
                message.metadata = existing_metadata
                await message.save()
            except Exception:
                pass
        raise
    except Exception as e:
        # Unexpected failure: record it on the message (best-effort) and
        # surface a generic 500 to the client.
        if message:
            try:
                existing_metadata = dict(getattr(message, "metadata", {}) or {})
                existing_metadata["error"] = str(e)
                message.metadata = existing_metadata
                await message.save()
            except Exception:
                pass
        raise HTTPException(status_code=500, detail=f"Server error: {str(e)}")

create_message_stream(request, conversation_id, background_tasks, requesting_user=Depends(get_current_user)) async

Create a new message and stream generation via Server-Sent Events (SSE).

Sets up a per-message stream bus and runs generation in a decoupled task. Yields SSE-formatted chunks including status updates, tokens, and final payloads.

Parameters:

Name Type Description Default
request GenerationRequest

Generation parameters including query, collections, and model settings.

required
conversation_id str

Target conversation identifier.

required
background_tasks BackgroundTasks

Background task runner used to schedule rollups.

required
requesting_user User

Authenticated user injected by dependency.

Depends(get_current_user)

Returns:

Type Description
StreamingResponse

SSE stream for the generation lifecycle.

Raises:

Type Description
HTTPException

404 if conversation is not found; 403 if ownership/collections invalid; 500 for server errors.

Source code in routers/message.py
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
@router.post(
    "/conversations/{conversation_id}/stream_messages",
    response_class=StreamingResponse,
)
async def create_message_stream(
    request: GenerationRequest,
    conversation_id: str,
    background_tasks: BackgroundTasks,
    requesting_user: User = Depends(get_current_user),
) -> StreamingResponse:
    """
    Create a new message and stream generation via Server-Sent Events (SSE).

    Sets up a per-message stream bus and runs generation in a decoupled task. Yields SSE-formatted chunks including status updates, tokens, and final payloads.

    Args:
        request (GenerationRequest): Generation parameters including query, collections, and model settings.
        conversation_id (str): Target conversation identifier.
        background_tasks (BackgroundTasks): Background task runner used to schedule rollups.
        requesting_user (User): Authenticated user injected by dependency.

    Returns:
        SSE stream for the generation lifecycle.

    Raises:
        HTTPException: 404 if conversation is not found; 403 if ownership/collections invalid; 500 for server errors.
    """
    # Attach user/conversation identifiers to the logging/telemetry context.
    set_user_context(requesting_user.id)
    set_conversation_context(conversation_id)

    # Created later in the body; the except handlers check it to decide
    # whether an error is worth logging against a persisted message.
    message = None
    try:
        conversation = await Conversation.find_by_id(conversation_id)
        if not conversation:
            raise HTTPException(status_code=404, detail="Conversation not found")

        if conversation.user_id != requesting_user.id:
            raise HTTPException(
                status_code=403,
                detail="You are not allowed to add a message to this conversation",
            )

        # Normalize and validate requested public collections against allowed lists
        # (the allowed set is environment-dependent: production vs staging).
        allowed_source = PUBLIC_COLLECTIONS if IS_PROD else STAGING_PUBLIC_COLLECTIONS
        try:
            allowed_names = {
                item.get("name")
                for item in (allowed_source + WILEY_PUBLIC_COLLECTIONS)
                if isinstance(item, dict) and item.get("name")
            }
        except Exception:
            # Malformed config: treat as "no public collections allowed".
            allowed_names = set()

        # Silently drop any requested public collection not in the allowed set.
        public_collections = [
            n for n in request.public_collections if n in allowed_names
        ]
        request.public_collections = public_collections

        # lookup query to check if some of the collection ids from other users are in the request.collection_ids
        # NOTE(review): this matches the DB "id" field against entries of
        # request.public_collections, which at this point holds allowed *names*
        # — confirm whether ids and names coincide here.
        other_users_collections = await CollectionModel.find_all(
            filter_dict={
                "id": {"$in": request.public_collections},
                "user_id": {"$ne": requesting_user.id},
            }
        )

        if len(other_users_collections) > 0:
            raise HTTPException(
                status_code=403,
                detail="You are not allowed to use collections from other users",
            )

        request.collection_ids = request.collection_ids + request.public_collections

        # All user collections are used by default
        user_collections = await CollectionModel.find_all(
            filter_dict={"user_id": requesting_user.id}
        )

        # Map private collection ids to their display names for downstream use.
        request.private_collections_map = {c.id: c.name for c in user_collections}
        if len(user_collections) > 0:
            request.collection_ids = request.collection_ids + [
                c.id for c in user_collections
            ]
        # remove "Wiley AI Gateway" from collection_ids
        request.collection_ids = [
            c for c in request.collection_ids if c != "Wiley AI Gateway"
        ]
        logger.info(f"Collection IDs: {request.collection_ids}")

        # Extract year range from filters for MCP usage
        try:
            request.year = extract_year_range_from_filters(request.filters)
        except Exception:
            # Best-effort: a bad filter shape just disables the year filter.
            request.year = None

        # Persist a placeholder message; generation writes into it from the
        # background task while clients read from the stream bus.
        message = await Message.create(
            conversation_id=conversation_id,
            input=request.query,
            output="",
            documents=[],
            use_rag=False,
            request_input=request,
            metadata={},
        )

        set_message_context(message.id)

        # Start decoupled background job that publishes to bus
        # (the cancel manager lets a later request abort this generation).
        cancel_mgr = get_cancel_manager()
        cancel_event = cancel_mgr.create(message.id)
        cancel_mgr.link_conversation(conversation_id, message.id)
        gen_task = asyncio.create_task(
            run_generation_to_bus(
                request=request,
                conversation_id=conversation_id,
                message_id=message.id,
                background_tasks=background_tasks,
                cancel_event=cancel_event,
            )
        )
        cancel_mgr.set_task(message.id, gen_task)

        bus = get_stream_bus()

        async def _gen():
            # Optional catch-up from currently saved output (usually empty right after create)
            try:
                if message.output:
                    yield f"data: {json.dumps({'type':'partial','content': message.output})}\n\n"
            except Exception:
                pass
            # Relay every bus event for this message as-is to the client.
            async for data in bus.subscribe(message.id):
                yield data

        response = StreamingResponse(_gen(), media_type="text/event-stream")
        # Set SSE-friendly headers to prevent proxy/client reconnect loops
        response.headers["Cache-Control"] = "no-cache"
        response.headers["Connection"] = "keep-alive"
        response.headers["X-Accel-Buffering"] = "no"  # Nginx buffering off if present
        return response
    except HTTPException as http_exc:
        # Log against the persisted message (if any), then re-raise with the
        # original status code intact.
        if message:
            error_logger = get_error_logger()
            await error_logger.log_error_sync(
                error=http_exc,
                component=Component.ROUTER,
                pipeline_stage=PipelineStage.ROUTER,
                description="HTTPException in create_message_stream",
                error_type=type(http_exc).__name__,
            )
        raise http_exc
    except Exception as e:
        # Unexpected failure: log it (if a message exists) and surface a
        # generic 500 to the client.
        if message:
            error_logger = get_error_logger()
            await error_logger.log_error_sync(
                error=e,
                component=Component.ROUTER,
                pipeline_stage=PipelineStage.ROUTER,
                description="Exception in create_message_stream",
                error_type=type(e).__name__,
            )
        raise HTTPException(status_code=500, detail=f"Server error: {str(e)}")

generate(request, requesting_user=Depends(get_current_user)) async

Run a one-off generation (testing only) and return the full answer and metadata.

Normalizes and validates requested public collections against allowed lists, ensures the user does not reference other users' collections, merges the user's collections and public collections (excluding "Wiley AI Gateway"), extracts year range from filters, then runs the full generation pipeline via generate_answer and returns the answer, documents, RAG flag, latencies, prompts, and retrieved docs.

Parameters:

Name Type Description Default
request GenerationRequest

Generation parameters including query, collections, and model settings.

required
requesting_user User

Authenticated user injected by dependency.

Depends(get_current_user)

Returns:

Type Description
dict

Dictionary containing: - answer: Generated answer text. - documents: Extracted document data from retrieval results. - use_rag: Whether RAG was used for this generation. - latencies: Timing information for pipeline steps. - prompts: Prompt data from generation. - retrieved_docs: Raw retrieved documents from RAG.

Raises:

Type Description
HTTPException

403 if the request references collections owned by other users.

HTTPException

500 for server errors during generation.

Source code in routers/message.py
1364
1365
1366
1367
1368
1369
1370
1371
1372
1373
1374
1375
1376
1377
1378
1379
1380
1381
1382
1383
1384
1385
1386
1387
1388
1389
1390
1391
1392
1393
1394
1395
1396
1397
1398
1399
1400
1401
1402
1403
1404
1405
1406
1407
1408
1409
1410
1411
1412
1413
1414
1415
1416
1417
1418
1419
1420
1421
1422
1423
1424
1425
1426
1427
1428
1429
1430
1431
1432
1433
1434
1435
1436
1437
1438
1439
1440
1441
1442
1443
1444
1445
1446
1447
1448
1449
1450
1451
1452
1453
1454
1455
1456
1457
1458
1459
1460
1461
1462
1463
1464
1465
1466
1467
1468
1469
@router.post("/generate")
async def generate(
    request: GenerationRequest,
    requesting_user: User = Depends(get_current_user),
) -> dict:
    """
    Run a one-off generation (testing only) and return the full answer and metadata.

    Normalizes and validates requested public collections against allowed lists,
    ensures the user does not reference other users' collections, merges the user's
    collections and public collections (excluding "Wiley AI Gateway"), extracts year
    range from filters, then runs the full generation pipeline via generate_answer
    and returns the answer, documents, RAG flag, latencies, prompts, and retrieved docs.

    Args:
        request (GenerationRequest): Generation parameters including query, collections,
            and model settings.
        requesting_user (User): Authenticated user injected by dependency.

    Returns:
        Dictionary containing:
            - answer: Generated answer text.
            - documents: Extracted document data from retrieval results.
            - use_rag: Whether RAG was used for this generation.
            - latencies: Timing information for pipeline steps.
            - prompts: Prompt data from generation.
            - retrieved_docs: Raw retrieved documents from RAG.

    Raises:
        HTTPException: 403 if the request references collections owned by other users.
        HTTPException: 500 for server errors during generation.
    """
    try:
        # Normalize and validate requested public collections against allowed lists
        # (the allowed set is environment-dependent: production vs staging).
        allowed_source = PUBLIC_COLLECTIONS if IS_PROD else STAGING_PUBLIC_COLLECTIONS
        try:
            allowed_names = {
                item.get("name")
                for item in (allowed_source + WILEY_PUBLIC_COLLECTIONS)
                if isinstance(item, dict) and item.get("name")
            }
        except Exception:
            # Malformed config: treat as "no public collections allowed".
            allowed_names = set()

        # Silently drop any requested public collection not in the allowed set.
        public_collections = [
            n for n in request.public_collections if n in allowed_names
        ]
        request.public_collections = public_collections

        # lookup query to check if some of the collection ids from other users are in the request.collection_ids
        # NOTE(review): this matches the DB "id" field against entries of
        # request.public_collections, which at this point holds allowed *names*
        # — confirm whether ids and names coincide here.
        other_users_collections = await CollectionModel.find_all(
            filter_dict={
                "id": {"$in": request.public_collections},
                "user_id": {"$ne": requesting_user.id},
            }
        )

        if len(other_users_collections) > 0:
            raise HTTPException(
                status_code=403,
                detail="You are not allowed to use collections from other users",
            )

        request.collection_ids = request.collection_ids + request.public_collections

        # All user collections are used by default
        user_collections = await CollectionModel.find_all(
            filter_dict={"user_id": requesting_user.id}
        )

        # Map private collection ids to their display names for downstream use.
        request.private_collections_map = {c.id: c.name for c in user_collections}
        if len(user_collections) > 0:
            request.collection_ids = request.collection_ids + [
                c.id for c in user_collections
            ]
        # remove "Wiley AI Gateway" from collection_ids
        request.collection_ids = [
            c for c in request.collection_ids if c != "Wiley AI Gateway"
        ]
        logger.info(f"Collection IDs: {request.collection_ids}")

        # Extract year range from filters for MCP usage
        try:
            request.year = extract_year_range_from_filters(request.filters)
        except Exception:
            # Best-effort: a bad filter shape just disables the year filter.
            request.year = None

        answer, results, is_rag, latencies, prompts, retrieved_docs = (
            await generate_answer(request)
        )

        documents_data = []
        if results:
            documents_data = [extract_document_data(result) for result in results]

        return {
            "answer": answer,
            "documents": documents_data,
            "use_rag": is_rag,
            "latencies": latencies,
            "prompts": prompts,
            "retrieved_docs": retrieved_docs,
        }
    except HTTPException:
        # Propagate intentional HTTP errors (e.g. the 403 above) with their
        # original status code instead of masking them as 500s.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Server error: {str(e)}")

generate_llm(request) async

Call EVE-Instruct (v5) with a single query. No RAG, no conversation context.

Body: query. Returns the model reply only.

Source code in routers/message.py
1343
1344
1345
1346
1347
1348
1349
1350
1351
1352
1353
1354
1355
1356
1357
1358
1359
1360
1361
@router.post("/generate-llm")
async def generate_llm(
    request: GenerateLLMRequest,
) -> dict:
    """
    Call EVE-Instruct (v5) with a single query. No RAG, no conversation context.

    Body: query. Returns the model reply only.
    """
    try:
        # Resolve the shared client for the EVE v5 model and send the prompt
        # as a single human turn.
        client = get_shared_llm_manager().get_client_for_model("eve_v05")
        reply = await client.ainvoke([HumanMessage(content=request.query)])
        # Prefer the structured .content attribute; fall back to stringifying.
        return {"answer": getattr(reply, "content", str(reply))}
    except Exception as e:
        logger.exception("generate_llm failed: %s", e)
        raise HTTPException(status_code=500, detail=str(e))

get_average_latencies(start_date=None, end_date=None) async

Return average latencies aggregated across all messages.

Optionally filters the aggregation by a timestamp window.

Parameters:

Name Type Description Default
start_date datetime | None

Optional start of the time window (inclusive).

None
end_date datetime | None

Optional end of the time window (inclusive).

None

Returns:

Type Description
dict

Mapping of latency metric name to average value.

Raises:

Type Description
HTTPException

500 for server errors during aggregation.

Source code in routers/message.py
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
@router.get("/conversations/messages/average-latencies")
async def get_average_latencies(
    start_date: datetime | None = None, end_date: datetime | None = None
) -> dict:
    """
    Return average latencies aggregated across all messages.

    Optionally filters the aggregation by a timestamp window.

    Args:
        start_date (datetime | None): Optional start of the time window (inclusive).
        end_date (datetime | None): Optional end of the time window (inclusive).

    Returns:
        Mapping of latency metric name to average value; empty dict when no
        messages match the window.

    Raises:
        HTTPException: 500 for server errors during aggregation.
    """
    try:
        messages_col = Message.get_collection()
        pipeline = []
        # Only add a $match stage when at least one bound was supplied.
        if start_date is not None or end_date is not None:
            time_filter = {}
            if start_date is not None:
                time_filter["$gte"] = start_date
            if end_date is not None:
                time_filter["$lte"] = end_date
            pipeline.append({"$match": {"timestamp": time_filter}})
        # Single $group over everything: $avg ignores missing/null fields.
        pipeline.append(
            {
                "$group": {
                    "_id": None,
                    "rag_decision_latency": {
                        "$avg": "$metadata.latencies.rag_decision_latency"
                    },
                    "query_embedding_latency": {
                        "$avg": "$metadata.latencies.query_embedding_latency"
                    },
                    "qdrant_retrieval_latency": {
                        "$avg": "$metadata.latencies.qdrant_retrieval_latency"
                    },
                    "mcp_retrieval_latency": {
                        "$avg": "$metadata.latencies.mcp_retrieval_latency"
                    },
                    "reranking_latency": {
                        "$avg": "$metadata.latencies.reranking_latency"
                    },
                    "first_token_latency": {
                        "$avg": "$metadata.latencies.first_token_latency"
                    },
                    "mistral_first_token_latency": {
                        "$avg": "$metadata.latencies.mistral_first_token_latency"
                    },
                    "base_generation_latency": {
                        "$avg": "$metadata.latencies.base_generation_latency"
                    },
                }
            }
        )
        cursor = messages_col.aggregate(pipeline, allowDiskUse=True)
        results = await cursor.to_list(length=1)
        # An empty window yields no groups; previously this raised IndexError
        # (surfacing as a 500). Return an empty mapping instead.
        if not results:
            return {}
        return results[0]
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Server error: {str(e)}")

get_lower_bound(usage_kg) async

Find the largest co2eq_kg <= usage_kg.

Source code in routers/message.py
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
async def get_lower_bound(usage_kg: float) -> Optional[CO2EquivalenceComparison]:
    """Find the largest co2eq_kg <= usage_kg."""
    try:
        collection = CO2EQComparison.get_collection()

        # Largest enabled comparison item at or below the target usage;
        # sort descending and take the first match.
        record = await collection.find_one(
            {"enabled": True, "co2eq_kg": {"$lte": usage_kg}},
            sort=[("co2eq_kg", -1)],
            projection={"_id": 0}
        )

        return CO2EquivalenceComparison(**record) if record else None
    except Exception as e:
        logger.error(f"Error in get_lower_bound: {str(e)}", exc_info=True)
        return None

get_my_message_stats(requesting_user=Depends(get_current_user)) async

Return counts and character totals for the current user's messages.

Aggregates across all messages belonging to conversations owned by the user.

Parameters:

Name Type Description Default
requesting_user User

Authenticated user injected by dependency.

Depends(get_current_user)

Returns:

Type Description
dict

Aggregated stats including counts and character sums.

Raises:

Type Description
HTTPException

500 for server errors during aggregation.

Source code in routers/message.py
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
@router.get("/conversations/messages/me/stats")
async def get_my_message_stats(requesting_user: User = Depends(get_current_user)) -> dict:
    """
    Return counts and character totals for the current user's messages.

    Aggregates across all messages belonging to conversations owned by the
    user, then augments the result with a CO2 usage estimate and a
    human-readable equivalence sentence.

    Args:
        requesting_user (User): Authenticated user injected by dependency.

    Returns:
        Aggregated stats: message_count, input_characters, output_characters,
        total_characters, co2eq_kg, and an equivalence text.

    Raises:
        HTTPException: 500 for server errors during aggregation.
    """
    try:
        message_collection = Message.get_collection()

        # Resolve the ids of conversations owned by the requesting user.
        # (DocumentDB does not support the $lookup pipelines that would
        # otherwise join conversations and messages server-side.)
        owned_conversations = await Conversation.find_all(
            filter_dict={"user_id": requesting_user.id}
        )
        owned_ids = [c.id for c in owned_conversations if getattr(c, "id", None)]

        if not owned_ids:
            # No conversations at all: short-circuit with an all-zero payload.
            return {
                "message_count": 0,
                "input_characters": 0,
                "output_characters": 0,
                "total_characters": 0,
                "co2eq_kg": 0.0,
                "text": "",
            }

        # Count messages and sum the character lengths of the generation
        # prompt (input side) and the model output (output side).
        group_stage = {
            "$group": {
                "_id": None,
                "message_count": {"$sum": 1},
                "input_characters": {
                    "$sum": {
                        "$strLenCP": {
                            "$ifNull": [
                                "$metadata.prompts.generation_prompt",
                                "",
                            ]
                        }
                    }
                },
                "output_characters": {
                    "$sum": {"$strLenCP": {"$ifNull": ["$output", ""]}}
                },
            }
        }
        project_stage = {
            "$project": {
                "_id": 0,
                "message_count": 1,
                "input_characters": 1,
                "output_characters": 1,
                "total_characters": {
                    "$add": ["$input_characters", "$output_characters"],
                },
            }
        }
        pipeline = [
            {"$match": {"conversation_id": {"$in": owned_ids}}},
            group_stage,
            project_stage,
        ]

        cursor = message_collection.aggregate(pipeline, allowDiskUse=True)
        rows = await cursor.to_list(length=1)

        if rows:
            stats = rows[0]
        else:
            stats = {
                "message_count": 0,
                "input_characters": 0,
                "output_characters": 0,
                "total_characters": 0,
            }

        usage_kg = get_co2_usage_kg(total_chars=stats.get("total_characters", 0))

        # The equivalence sentence is best-effort decoration; failures are
        # logged and the text stays empty.
        equivalence_text = ""
        try:
            equivalence = await build_equivalence_sentence(usage_kg)
            equivalence_text = equivalence.text
        except Exception as exc:
            logger.error(f"Failed to get CO2 equivalence data: {str(exc)}", exc_info=True)

        stats["co2eq_kg"] = usage_kg
        stats["text"] = equivalence_text
        return stats
    except HTTPException:
        raise
    except Exception as exc:
        raise HTTPException(status_code=500, detail=f"Server error: {str(exc)}")

get_source_logs(conversation_id, message_id, request, requesting_user=Depends(get_current_user)) async

Append a source log entry to a message's metadata.

Stores user-attributed source inspection information such as id, url, title, and collection name, with a server-side timestamp.

Parameters:

Name Type Description Default
conversation_id str

Conversation identifier.

required
message_id str

Message identifier.

required
request SourceLogsRequest

Source log details to append.

required
requesting_user User

Authenticated user injected by dependency.

Depends(get_current_user)

Returns:

Type Description
dict

Confirmation message upon successful append.

Raises:

Type Description
HTTPException

404 if conversation/message not found or mismatched; 500 for server errors.

Source code in routers/message.py
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
@router.post("/conversations/{conversation_id}/messages/{message_id}/source_logs")
async def get_source_logs(
    conversation_id: str,
    message_id: str,
    request: SourceLogsRequest,
    requesting_user: User = Depends(get_current_user),
) -> dict:
    """
    Append a source log entry to a message's metadata.

    Stores user-attributed source inspection information such as id, url, title, and collection name, with a server-side timestamp.

    Args:
        conversation_id (str): Conversation identifier.
        message_id (str): Message identifier.
        request (SourceLogsRequest): Source log details to append.
        requesting_user (User): Authenticated user injected by dependency.

    Returns:
        Confirmation message upon successful append.

    Raises:
        HTTPException: 404 if conversation/message not found or mismatched; 403 if the conversation belongs to another user; 500 for server errors.
    """
    try:
        conversation = await Conversation.find_by_id(conversation_id)
        if not conversation:
            raise HTTPException(status_code=404, detail="Conversation not found")

        # Ownership check, consistent with the other message endpoints.
        # Without it, any authenticated user could append log entries to
        # messages in another user's conversation.
        if conversation.user_id != requesting_user.id:
            raise HTTPException(
                status_code=403,
                detail="You are not allowed to access this conversation",
            )

        message = await Message.find_by_id(message_id)
        if not message:
            raise HTTPException(status_code=404, detail="Message not found")

        if message.conversation_id != conversation_id:
            raise HTTPException(
                status_code=404, detail="Message not found in this conversation"
            )

        # Store source logs as an append-only array on the message metadata.
        existing_metadata = dict(getattr(message, "metadata", {}) or {})
        source_logs = list(existing_metadata.get("source_logs") or [])
        source_logs.append(
            {
                **request.model_dump(),
                # NOTE(review): naive local-time timestamp — confirm whether
                # consumers expect UTC (datetime.now(timezone.utc)).
                "timestamp": datetime.now().isoformat(),
                "user_id": requesting_user.id,
            }
        )
        existing_metadata["source_logs"] = source_logs
        message.metadata = existing_metadata
        await message.save()
        return {"message": "Source logs stored successfully"}
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Server error: {str(e)}")

hallucination_detect(conversation_id, message_id, requesting_user=Depends(get_current_user)) async

Detect and persist hallucination analysis for a message.

Runs a multi-step pipeline (detect, optionally rewrite, retrieve, answer) and stores the result and latency breakdown on the message metadata.

Parameters:

Name Type Description Default
conversation_id str

Conversation identifier.

required
message_id str

Message identifier to analyze.

required
requesting_user User

Authenticated user injected by dependency.

Depends(get_current_user)

Returns:

Type Description
HallucinationDetectResponse

Structured hallucination analysis with optional final answer.

Raises:

Type Description
HTTPException

404 if conversation/message not found; 400 if the message does not belong to the conversation; 403 if ownership invalid; 500 for server errors.

Source code in routers/message.py
 970
 971
 972
 973
 974
 975
 976
 977
 978
 979
 980
 981
 982
 983
 984
 985
 986
 987
 988
 989
 990
 991
 992
 993
 994
 995
 996
 997
 998
 999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
@router.post(
    "/conversations/{conversation_id}/messages/{message_id}/hallucination",
    response_model=HallucinationDetectResponse,
)
async def hallucination_detect(
    conversation_id: str,
    message_id: str,
    requesting_user: User = Depends(get_current_user),
) -> HallucinationDetectResponse:
    """
    Detect and persist hallucination analysis for a message.

    Runs a multi-step pipeline (detect, optionally rewrite, retrieve, answer) and stores the result and latency breakdown on the message metadata.

    Args:
        conversation_id (str): Conversation identifier.
        message_id (str): Message identifier to analyze.
        requesting_user (User): Authenticated user injected by dependency.

    Returns:
        Structured hallucination analysis with optional final answer.

    Raises:
        HTTPException: 404 if conversation/message not found; 400 if the message does not belong to the conversation; 403 if ownership invalid; 500 for server errors.
    """
    # Validate conversation ownership and message relationship
    conversation = await Conversation.find_by_id(conversation_id)
    if not conversation:
        raise HTTPException(status_code=404, detail="Conversation not found")
    if conversation.user_id != requesting_user.id:
        raise HTTPException(
            status_code=403,
            detail="You are not allowed to access this conversation",
        )

    message = await Message.find_by_id(message_id)
    if not message:
        raise HTTPException(status_code=404, detail="Message not found")
    if message.conversation_id != conversation_id:
        raise HTTPException(
            status_code=400, detail="Message does not belong to this conversation"
        )

    try:
        # Wall-clock timing of the whole pipeline, persisted below as
        # latencies.total alongside the detector's per-step timings.
        total_start = time.perf_counter()
        detector = HallucinationDetector()

        # Detector contract: (label, reason, original_question,
        # rewritten_question, final_answer, per-step latency dict).
        (
            label,
            reason,
            _orig_q,
            rewritten_question,
            final_answer,
            latencies,
        ) = await detector.run(
            query=message.input,
            model_response=message.output,
            docs=build_context(message.documents),
            llm_type=message.request_input.llm_type,
        )
        total_latency = time.perf_counter() - total_start

        # Persist hallucination result to Message
        # (best-effort: a failed save is logged but does not fail the request)
        try:
            existing_metadata = dict(getattr(message, "metadata", {}) or {})
            hallucination_payload = {
                "label": label,
                "reason": reason,
                "rewritten_question": rewritten_question,
                "final_answer": final_answer,
                # latencies may be None/empty depending on which steps ran.
                "latencies": {
                    "detect": latencies.get("detect") if latencies else None,
                    "rewrite": latencies.get("rewrite") if latencies else None,
                    "final_answer": (
                        latencies.get("final_answer") if latencies else None
                    ),
                    "total": total_latency,
                },
            }
            existing_metadata["hallucination"] = hallucination_payload
            message.metadata = existing_metadata
            await message.save()
        except HTTPException:
            raise
        except Exception as e:
            logger.error(f"Failed to update Message with hallucination result: {e}")

        return HallucinationDetectResponse(
            label=label,
            reason=reason,
            original_question=message.input,
            rewritten_question=rewritten_question,
            final_answer=final_answer,
            latencies=latencies,
        )

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Hallucination detection failed: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Server error: {str(e)}")

retrieve(request) async

Run the entire retrieval pipeline and return all documents.

Runs the requery/rewrite step (same as generate_answer) to refine the query for retrieval, then executes the RAG retrieval pipeline using setup_rag_and_context and returns all retrieved documents.

Parameters:

Name Type Description Default
request GenerationRequest

Generation parameters including query, collections, and model settings.

required
requesting_user User

Authenticated user injected by dependency.

required

Returns:

Type Description
dict

Dictionary containing:

  • retrieved_docs: All formatted documents from the retrieval pipeline
  • latencies: Timing information (includes rewrite and retrieval operations)
  • original_query: The query as sent in the request
  • requery: The rewritten query used for retrieval (or the original if the rewrite was skipped or failed)

Source code in routers/message.py
1472
1473
1474
1475
1476
1477
1478
1479
1480
1481
1482
1483
1484
1485
1486
1487
1488
1489
1490
1491
1492
1493
1494
1495
1496
1497
1498
1499
1500
1501
1502
1503
1504
1505
1506
1507
1508
1509
1510
1511
1512
1513
1514
1515
1516
1517
1518
1519
1520
1521
1522
1523
1524
1525
1526
1527
1528
1529
1530
1531
1532
1533
1534
1535
1536
1537
1538
1539
1540
1541
1542
1543
1544
1545
1546
1547
1548
1549
1550
1551
1552
1553
1554
1555
1556
1557
1558
1559
1560
1561
1562
1563
1564
1565
1566
1567
1568
1569
1570
1571
1572
1573
1574
1575
1576
1577
1578
1579
1580
@router.post("/retrieve")
async def retrieve(
    request: GenerationRequest,
    # requesting_user: User = Depends(get_current_user),
) -> dict:
    """
    Run the entire retrieval pipeline and return all documents.

    Runs the requery/rewrite step (same as generate_answer) to refine the query
    for retrieval, then executes the RAG retrieval pipeline using
    setup_rag_and_context and returns all retrieved documents.

    Args:
        request (GenerationRequest): Generation parameters including query, collections, and model settings.

    Returns:
        Dictionary containing:
            - retrieved_docs: All formatted documents from the retrieval pipeline
            - latencies: Timing information (includes rewrite and retrieval operations)
            - original_query: The query as sent in the request
            - requery: The rewritten query used for retrieval (or original if rewrite skipped/failed)

    Raises:
        HTTPException: 500 for server errors during retrieval.
    """
    try:
        # Restrict public collections to the environment-appropriate allowlist.
        allowed_source = PUBLIC_COLLECTIONS if IS_PROD else STAGING_PUBLIC_COLLECTIONS
        try:
            allowed_names = {
                item.get("name")
                for item in (allowed_source + WILEY_PUBLIC_COLLECTIONS)
                if isinstance(item, dict) and item.get("name")
            }
        except Exception:
            allowed_names = set()

        request.public_collections = [
            n for n in request.public_collections if n in allowed_names
        ]

        # NOTE(review): per-user collection authorization checks are currently
        # disabled here (the auth dependency above is commented out) — confirm
        # whether they should be re-enabled as in the generate endpoints.
        request.collection_ids = request.collection_ids + request.public_collections

        # "Wiley AI Gateway" entries are excluded from retrieval — presumably
        # not a retrievable index; confirm with the collection registry.
        request.collection_ids = [
            c for c in request.collection_ids if c != "Wiley AI Gateway"
        ]
        logger.info(f"Collection IDs: {request.collection_ids}")

        # Year filters are optional; malformed filter payloads are ignored.
        try:
            request.year = extract_year_range_from_filters(request.filters)
        except Exception:
            request.year = None

        original_query = request.query
        rewrite_latency = None
        requery = None
        try:
            t_rewrite = time.perf_counter()
            llm_manager = get_shared_llm_manager()
            rag_decision_result, _rag_prompt, _ = await should_use_rag(
                llm_manager,
                request.query,
                conversation="",
                llm_type=request.llm_type,
            )
            rewrite_latency = time.perf_counter() - t_rewrite
            if rag_decision_result and getattr(rag_decision_result, "requery", None):
                requery = rag_decision_result.requery
                request.query = requery
        except Exception as e:
            # Rewrite is best-effort: fall back to the original query.
            logger.warning(f"Requery/rewrite failed in /retrieve, using original query: {e}")
            requery = None

        _context, _results, latencies, formatted_results = await setup_rag_and_context(
            request
        )

        if rewrite_latency is not None:
            latencies = dict(latencies) if latencies else {}
            latencies["rewrite"] = rewrite_latency

        return {
            "retrieved_docs": formatted_results,
            "latencies": latencies,
            "original_query": original_query,
            "requery": requery or original_query,
        }
    except HTTPException:
        # Preserve intentional HTTP errors (e.g. from the retrieval pipeline)
        # instead of re-wrapping them as opaque 500s, matching sibling endpoints.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Server error: {str(e)}")

retry(conversation_id, message_id, background_tasks, requesting_user=Depends(get_current_user)) async

Retry generation for an existing message.

Re-validates conversation ownership and message relationship, reuses the original request_input stored on the message, regenerates the answer, and updates message content, documents, and metadata.

Parameters:

Name Type Description Default
conversation_id str

Conversation identifier.

required
message_id str

Message identifier to retry.

required
background_tasks BackgroundTasks

Background task runner used to schedule rollups.

required
requesting_user User

Authenticated user injected by dependency.

Depends(get_current_user)

Returns:

Type Description
dict

Response payload mirroring create_message with updated answer and metadata.

Raises:

Type Description
HTTPException

404 if conversation/message not found; 403 if ownership invalid; 400 if message cannot be retried; 500 for server errors.

Source code in routers/message.py
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
@router.post("/conversations/{conversation_id}/messages/{message_id}/retry")
async def retry(
    conversation_id: str,
    message_id: str,
    background_tasks: BackgroundTasks,
    requesting_user: User = Depends(get_current_user),
) -> dict:
    """
    Retry generation for an existing message.

    Validates conversation ownership and the conversation/message
    relationship, re-runs generation with the `request_input` originally
    stored on the message, and updates the message's output, documents,
    and metadata in place.

    Args:
        conversation_id (str): Conversation identifier.
        message_id (str): Message identifier to retry.
        background_tasks (BackgroundTasks): Background task runner used to schedule rollups.
        requesting_user (User): Authenticated user injected by dependency.

    Returns:
        Response payload mirroring create_message with updated answer and metadata.

    Raises:
        HTTPException: 404 if conversation/message not found; 403 if ownership invalid; 400 if message cannot be retried; 500 for server errors.
    """
    try:
        conversation = await Conversation.find_by_id(conversation_id)
        if conversation is None:
            raise HTTPException(status_code=404, detail="Conversation not found")
        if conversation.user_id != requesting_user.id:
            raise HTTPException(
                status_code=403,
                detail="You are not allowed to add a message to this conversation",
            )

        message = await Message.find_by_id(message_id)
        if message is None:
            raise HTTPException(status_code=404, detail="Message not found")
        if message.conversation_id != conversation_id:
            raise HTTPException(
                status_code=404, detail="Message not found in this conversation"
            )
        if not message.request_input:
            # Without the original request there is nothing to replay.
            raise HTTPException(
                status_code=400,
                detail="This message cannot be retried",
            )

        # Replay generation with the exact request that produced this message.
        answer, results, is_rag, latencies, prompts, retrieved_docs = (
            await generate_answer(
                message.request_input, conversation_id=conversation_id
            )
        )

        documents_data = (
            [extract_document_data(result) for result in results] if results else []
        )

        generation_metadata = {
            "latencies": latencies,
            "prompts": prompts,
            "retrieved_docs": retrieved_docs,
        }

        # Overwrite the message content in place, merging the new generation
        # metadata into whatever metadata the message already carried.
        message.output = answer
        message.documents = documents_data
        message.use_rag = is_rag
        merged_metadata = dict(getattr(message, "metadata", {}) or {})
        merged_metadata.update(generation_metadata)
        message.metadata = merged_metadata
        await message.save()

        # Roll up history out-of-band so the response is not delayed.
        background_tasks.add_task(maybe_rollup_and_trim_history, conversation_id)

        return {
            "id": message.id,
            "query": message.input,
            "answer": answer,
            "documents": documents_data,
            "use_rag": is_rag,
            "conversation_id": conversation_id,
            "collection_ids": message.request_input.collection_ids,
            "metadata": generation_metadata,
        }
    except HTTPException:
        raise
    except Exception as exc:
        raise HTTPException(status_code=500, detail=f"Server error: {str(exc)}")

stop_conversation(conversation_id, requesting_user=Depends(get_current_user)) async

Signal cancellation for the active generation within a conversation.

Uses the cancel manager to locate the in-flight message/task and requests cooperative cancellation, also notifying downstream subscribers via the stream bus.

Parameters:

Name Type Description Default
conversation_id str

Conversation identifier to stop generation for.

required
requesting_user User

Authenticated user injected by dependency.

Depends(get_current_user)

Returns:

Type Description
dict

Status payload indicating stop state or absence of active generation.

Raises:

Type Description
HTTPException

404 if conversation is not found; 403 if ownership invalid; 500 for server errors.

Source code in routers/message.py
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
@router.post("/conversations/{conversation_id}/stop")
async def stop_conversation(
    conversation_id: str,
    requesting_user: User = Depends(get_current_user),
) -> dict:
    """
    Signal cancellation for the active generation within a conversation.

    Looks up the in-flight message for the conversation via the cancel
    manager, requests cooperative cancellation, and notifies downstream
    subscribers through the stream bus.

    Args:
        conversation_id (str): Conversation identifier to stop generation for.
        requesting_user (User): Authenticated user injected by dependency.

    Returns:
        Status payload indicating stop state or absence of active generation.

    Raises:
        HTTPException: 404 if conversation is not found; 403 if ownership invalid; 500 for server errors.
    """
    try:
        logger.info(
            "generation.stop.requested user_id=%s conversation_id=%s",
            requesting_user.id,
            conversation_id,
        )
        conversation = await Conversation.find_by_id(conversation_id)
        if conversation is None:
            raise HTTPException(status_code=404, detail="Conversation not found")
        if conversation.user_id != requesting_user.id:
            raise HTTPException(
                status_code=403,
                detail="You are not allowed to access this conversation",
            )

        manager = get_cancel_manager()
        # The async lookup supports a Redis-backed mapping shared across
        # workers; fall back to the in-process lookup if it is unavailable.
        try:
            active_message_id = await manager.get_message_for_conversation_async(conversation_id)  # type: ignore
        except Exception:
            active_message_id = manager.get_message_for_conversation(conversation_id)

        if not active_message_id:
            # Idempotent: stopping an idle conversation is not an error.
            logger.info(
                "generation.stop.no_active user_id=%s conversation_id=%s",
                requesting_user.id,
                conversation_id,
            )
            return {"status": "no_active_generation"}

        manager.cancel(active_message_id)

        # Notify any SSE subscribers that the stream has been stopped; a
        # failure here is logged but does not fail the stop request.
        try:
            bus = get_stream_bus()
            await bus.publish(active_message_id, f"data: {json.dumps({'type':'stopped'})}\n\n")
            await bus.close(active_message_id)
            logger.info(
                "generation.stop.signaled user_id=%s conversation_id=%s message_id=%s",
                requesting_user.id,
                conversation_id,
                active_message_id,
            )
        except Exception as exc:
            logger.warning(
                "generation.stop.signal_failed conversation_id=%s message_id=%s err=%s",
                conversation_id,
                active_message_id,
                str(exc),
            )
        return {"status": "stopping", "message_id": active_message_id}
    except HTTPException:
        raise
    except Exception as exc:
        raise HTTPException(status_code=500, detail=f"Server error: {str(exc)}")

stream_hallucination(conversation_id, message_id, requesting_user=Depends(get_current_user)) async

Stream hallucination handling result as Server-Sent Events (SSE).

Streams structured events for detection, optional rewriting, retrieval, and answer generation steps.

  • If label == 0 (factual), emits a final event with the reason.
  • If label == 1 (hallucination), streams tokens for the final answer and then a final event.

Parameters:

Name Type Description Default
conversation_id str

Conversation identifier.

required
message_id str

Message identifier to analyze.

required
requesting_user User

Authenticated user injected by dependency.

Depends(get_current_user)

Returns:

Type Description
StreamingResponse

SSE events for the detection workflow.

Raises:

Type Description
HTTPException

404 if conversation/message not found; 400 if the message does not belong to the conversation; 403 if access is forbidden; 500 for streaming errors.

Source code in routers/message.py
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
1236
1237
1238
1239
1240
1241
1242
1243
1244
1245
1246
1247
1248
1249
1250
1251
1252
1253
1254
1255
1256
1257
1258
1259
1260
1261
1262
1263
1264
1265
1266
1267
1268
1269
1270
1271
1272
1273
1274
1275
1276
1277
1278
1279
1280
1281
1282
1283
1284
1285
1286
1287
1288
1289
1290
1291
1292
1293
1294
1295
1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
1310
1311
1312
1313
1314
1315
1316
1317
1318
1319
1320
1321
1322
1323
1324
1325
1326
1327
1328
1329
1330
1331
1332
1333
1334
1335
1336
1337
1338
1339
1340
1341
@router.post(
    "/conversations/{conversation_id}/messages/{message_id}/stream-hallucination",
    response_class=StreamingResponse,
)
async def stream_hallucination(
    conversation_id: str,
    message_id: str,
    requesting_user: User = Depends(get_current_user),
) -> StreamingResponse:
    """
    Stream hallucination handling result as Server-Sent Events (SSE).

    Streams structured events for detection, optional rewriting, retrieval, and answer generation steps.

    - If label == 0 (factual), emits a final event with the reason.
    - If label == 1 (hallucination), streams tokens for the final answer and then a final event.

    Args:
        conversation_id (str): Conversation identifier.
        message_id (str): Message identifier to analyze.
        requesting_user (User): Authenticated user injected by dependency.

    Returns:
        SSE events for the detection workflow.

    Raises:
        HTTPException: 404 if conversation/message not found or mismatched; 403 if access is forbidden; 500 for streaming errors.
    """
    # Validate conversation ownership and message relationship
    conversation = await Conversation.find_by_id(conversation_id)
    if not conversation:
        raise HTTPException(status_code=404, detail="Conversation not found")
    if conversation.user_id != requesting_user.id:
        raise HTTPException(
            status_code=403,
            detail="You are not allowed to access this conversation",
        )

    message = await Message.find_by_id(message_id)
    if not message:
        raise HTTPException(status_code=404, detail="Message not found")
    if message.conversation_id != conversation_id:
        raise HTTPException(
            status_code=400, detail="Message does not belong to this conversation"
        )

    async def _generator():
        # Local imports keep the module import graph light; this generator is
        # only instantiated per-request.
        import json
        import time
        from src.utils.template_loader import get_template

        total_start = time.perf_counter()
        detector = HallucinationDetector()
        try:
            # Emit an initial status event early to start the stream promptly
            yield f"data: {json.dumps({'type': 'status', 'content': 'hallucination detection started...'})}\n\n"
            # Step 1: Detect
            t0 = time.perf_counter()
            label, reason = await detector.detect(
                query=message.input,
                model_response=message.output,
                docs=build_context(message.documents),
                llm_type=message.request_input.llm_type,
            )
            detect_latency = time.perf_counter() - t0

            yield f"data: {json.dumps({'type': 'label', 'content': label})}\n\n"
            yield f"data: {json.dumps({'type': 'reason', 'content': reason})}\n\n"

            # If factual (label == 0), emit reason and finish
            if label == 0:
                total_latency = time.perf_counter() - total_start
                latencies = {
                    "detect": detect_latency,
                    "rewrite": None,
                    "final_answer": None,
                    "total": total_latency,
                }
                # Persist to message metadata (best-effort: persistence failure
                # must not break the SSE stream)
                try:
                    message.hallucination = {
                        "label": label,
                        "reason": reason,
                        "rewritten_question": None,
                        "final_answer": None,
                        "latencies": latencies,
                    }
                    await message.save()
                except Exception:
                    pass

                final_payload = {
                    "type": "final",
                    "label": label,
                    "reason": reason,
                    "rewritten_question": None,
                    "answer": None,
                    "latencies": latencies,
                    "top_k_retrieved_docs": None,
                }
                yield f"data: {json.dumps(final_payload)}\n\n"
                return

            # Step 2: Rewrite (for hallucination)
            # Transparency: emit rewriting step
            yield f"data: {json.dumps({'type': 'status', 'content': 'Rewriting query...'})}\n\n"
            t1 = time.perf_counter()
            _orig_q, rewritten_question = await detector.rewrite_query(
                query=message.input,
                answer=message.output,
                reason=reason,
                llm_type=message.request_input.llm_type,
            )
            rewrite_latency = time.perf_counter() - t1
            yield f"data: {json.dumps({'type': 'rewritten_question', 'content': rewritten_question})}\n\n"

            # Step 3: Retrieve docs for rewritten_question (Qdrant + Wiley MCP)
            # Transparency: emit retrieving step
            yield f"data: {json.dumps({'type': 'status', 'content': 'Retrieving relevant documents...'})}\n\n"
            # Build a new GenerationRequest based on original, overriding the query
            req_in = message.request_input or GenerationRequest(query=message.input)
            rewritten_request = GenerationRequest(
                query=rewritten_question or message.input,
                year=getattr(req_in, "year", None),
                filters=getattr(req_in, "filters", None),
                llm_type=getattr(req_in, "llm_type", None),
                embeddings_model=getattr(req_in, "embeddings_model", None),
                k=getattr(req_in, "k", 5),
                temperature=getattr(req_in, "temperature", 0.3),
                score_threshold=getattr(req_in, "score_threshold", 0.7),
                max_new_tokens=getattr(req_in, "max_new_tokens", 1024),
                public_collections=list(
                    getattr(req_in, "public_collections", []) or []
                ),
            )
            # Optional fields may not exist on older request models; copy
            # best-effort so an old persisted request doesn't break the stream.
            try:
                rewritten_request.collection_ids = list(
                    getattr(req_in, "collection_ids", []) or []
                )
            except Exception:
                pass
            try:
                rewritten_request.private_collections_map = dict(
                    getattr(req_in, "private_collections_map", {}) or {}
                )
            except Exception:
                pass

            context = ""
            retrieved_docs = []
            # BUGFIX: `results` was previously only bound on the success path
            # below; when retrieval raised, the later references to `results`
            # (persisted metadata and final payload) raised NameError and
            # turned a soft retrieval failure into a hard stream error.
            results = []
            rag_latencies = {}
            try:
                context, results, rag_latencies, retrieved_docs = (
                    await setup_rag_and_context(rewritten_request)
                )
            except Exception as e:
                # Soft-fail RAG retrieval; proceed without new docs
                rag_latencies = {"rag_error": str(e)}
                retrieved_docs = []
                context = ""

            # Step 4: Stream final answer from LLM using the hallucination answer template
            template = get_template(
                "llm_answer_template", filename="hallucination_detector.yaml"
            )
            # Inject retrieved context for better grounding
            prompt = template.format(query=rewritten_question or message.input)
            if context:
                prompt = f"{prompt}\n\nContext:\n{context}"

            yield f"data: {json.dumps({'type': 'status', 'content': 'Generating answer...'})}\n\n"
            final_answer_chunks = []
            t2 = time.perf_counter()
            # Try primary provider streaming first with a first-token timeout, then fallback to Mistral
            llm = detector.llm_manager.get_client_for_model(
                message.request_input.llm_type
            )
            used_stream = False
            try:
                astream = llm.astream(prompt)
                # Enforce first token timeout similar to generate_answer
                llm_instruct_timeout = MODEL_TIMEOUT
                async with asyncio.timeout(llm_instruct_timeout):
                    first = await astream.__anext__()
                    first_text = getattr(first, "content", None)
                    if first_text:
                        final_answer_chunks.append(first_text)
                        yield f"data: {json.dumps({'type':'token','content':first_text})}\n\n"
                # Continue without timeout
                async for token in astream:
                    text = getattr(token, "content", None)
                    if not text:
                        continue
                    final_answer_chunks.append(text)
                    yield f"data: {json.dumps({'type':'token','content':text})}\n\n"
                used_stream = True
            except Exception:
                # Any primary-provider failure (timeout, connection, model
                # error) routes to the Mistral fallback below.
                used_stream = False

            # Fallback to Mistral streaming if needed
            if not used_stream:
                logger.info("Hallucination Falling back to Mistral streaming")
                async for (
                    token,
                    _prompt,
                ) in detector.llm_manager.generate_answer_mistral_stream(
                    query=rewritten_question or message.input,
                    context=context or "",
                    temperature=getattr(message.request_input, "temperature", 0.3),
                    conversation_context="",
                ):
                    if not token:
                        continue
                    final_answer_chunks.append(str(token))
                    yield f"data: {json.dumps({'type':'token','content':str(token)})}\n\n"

            final_latency = time.perf_counter() - t2

            final_answer = "".join(final_answer_chunks)
            total_latency = time.perf_counter() - total_start
            latencies = {
                "detect": detect_latency,
                "rewrite": rewrite_latency,
                **(rag_latencies or {}),
                "final_answer": final_latency,
                "total": total_latency,
            }

            # Persist to message.hallucination (best-effort, as above)
            try:
                message.hallucination = {
                    "label": label,
                    "reason": reason,
                    "rewritten_question": rewritten_question,
                    "final_answer": final_answer,
                    "latencies": latencies,
                    "top_k_retrieved_docs": results,
                    "retrieved_docs": retrieved_docs,
                }
                await message.save()
            except Exception:
                pass

            final_payload = {
                "type": "final",
                "label": label,
                "reason": reason,
                "rewritten_question": rewritten_question,
                "answer": final_answer,
                "latencies": latencies,
                "top_k_retrieved_docs": results,
            }
            yield f"data: {json.dumps(final_payload)}\n\n"
        except Exception as e:
            # Persist error and stream error event; the HTTP response has
            # already started, so an SSE error event is the only channel left.
            try:
                existing_metadata = dict(getattr(message, "metadata", {}) or {})
                existing_metadata["error"] = str(e)
                message.metadata = existing_metadata
                await message.save()
            except Exception:
                pass
            yield f"data: {json.dumps({'type':'error','message':str(e)})}\n\n"

    response = StreamingResponse(_generator(), media_type="text/event-stream")
    # Disable proxy/server buffering so events reach the client immediately.
    response.headers["Cache-Control"] = "no-cache"
    response.headers["Connection"] = "keep-alive"
    response.headers["X-Accel-Buffering"] = "no"
    return response

update_message(conversation_id, message_id, request, requesting_user=Depends(get_current_user)) async

Update message feedback and related annotations.

Supports updating fields such as feedback, feedback_reason, was_copied, and hallucination feedback metadata on the target message.

Parameters:

Name Type Description Default
conversation_id str

Conversation identifier.

required
message_id str

Message identifier to update.

required
request MessageUpdate

Partial update payload for feedback fields.

required
requesting_user User

Authenticated user injected by dependency.

Depends(get_current_user)

Returns:

Type Description
dict

Success message upon update.

Raises:

Type Description
HTTPException

404 if conversation/message not found or mismatched; 403 if ownership invalid; 500 for server errors.

Source code in routers/message.py
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
@router.patch("/conversations/{conversation_id}/messages/{message_id}")
async def update_message(
    conversation_id: str,
    message_id: str,
    request: MessageUpdate,
    requesting_user: User = Depends(get_current_user),
) -> dict:
    """
    Apply a partial feedback update to a message.

    Accepts any subset of `feedback`, `feedback_reason`, `was_copied`, plus the
    hallucination-specific counterparts, and writes only the fields present in
    the payload.

    Args:
        conversation_id (str): Conversation identifier.
        message_id (str): Message identifier to update.
        request (MessageUpdate): Partial update payload for feedback fields.
        requesting_user (User): Authenticated user injected by dependency.

    Returns:
        Success message upon update.

    Raises:
        HTTPException: 404 if conversation/message not found or mismatched; 403 if ownership invalid; 500 for server errors.
    """
    try:
        # Guard clauses: existence, relationship, then ownership.
        conversation = await Conversation.find_by_id(conversation_id)
        if not conversation:
            raise HTTPException(status_code=404, detail="Conversation not found")

        message = await Message.find_by_id(message_id)
        if not message:
            raise HTTPException(status_code=404, detail="Message not found")

        if message.conversation_id != conversation_id:
            raise HTTPException(
                status_code=404, detail="Message not found in this conversation"
            )

        if conversation.user_id != requesting_user.id:
            raise HTTPException(
                status_code=403,
                detail="You are not allowed to update feedback for this message",
            )

        def _set_hallucination(key, value):
            # Lazily create the hallucination dict before writing a key.
            if message.hallucination is None:
                message.hallucination = {}
            message.hallucination[key] = value

        # Top-level feedback fields (enums are stored by their .value).
        if request.feedback is not None:
            message.feedback = request.feedback.value
        if request.was_copied is not None:
            message.was_copied = request.was_copied
        if request.feedback_reason is not None:
            message.feedback_reason = request.feedback_reason

        # Hallucination-specific feedback lives inside message.hallucination.
        if request.hallucination_feedback is not None:
            _set_hallucination("feedback", request.hallucination_feedback.value)
        if request.hallucination_feedback_reason is not None:
            _set_hallucination(
                "feedback_reason", request.hallucination_feedback_reason
            )
        if request.hallucination_was_copied is not None:
            _set_hallucination("was_copied", request.hallucination_was_copied)

        await message.save()
        return {"message": "Feedback updated successfully"}

    except HTTPException:
        # Pass through deliberate HTTP errors unchanged.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Server error: {str(e)}")