#!/usr/bin/env python3
"""
PaddleOCR-VL Full Pipeline API Server (Transformers backend)

Provides REST API for document parsing using:
- PP-DocLayoutV2 for layout detection
- PaddleOCR-VL (transformers) for recognition
- Structured JSON/Markdown output
"""

import os
import io
import base64
import logging
import tempfile
import threading
import time
import json
from typing import Optional, List, Union
from pathlib import Path

from fastapi import FastAPI, HTTPException, UploadFile, File, Form
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from PIL import Image
import torch

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# Environment configuration
SERVER_HOST = os.environ.get('SERVER_HOST', '0.0.0.0')
SERVER_PORT = int(os.environ.get('SERVER_PORT', '8000'))
MODEL_NAME = "PaddlePaddle/PaddleOCR-VL"

# Device configuration
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
logger.info(f"Using device: {DEVICE}")

# Prompt sent to PaddleOCR-VL for each recognition task.
TASK_PROMPTS = {
    "ocr": "OCR:",
    "table": "Table Recognition:",
    "formula": "Formula Recognition:",
    "chart": "Chart Recognition:",
}

# Initialize FastAPI app
app = FastAPI(
    title="PaddleOCR-VL Full Pipeline Server",
    description="Document parsing with PP-DocLayoutV2 + PaddleOCR-VL (transformers)",
    version="1.0.0"
)

# Global model instances, populated lazily by load_vl_model / load_layout_model.
vl_model = None
vl_processor = None
layout_model = None

# Set once layout-model loading has failed, so later requests do not retry the
# (slow, failing) import/construction and re-log the same warning every time.
# A restart is needed to retry.
_layout_load_failed = False

# Endpoints run the pipeline in worker threads so the event loop stays free.
# This lock keeps lazy model loading and inference one-at-a-time, so
# concurrent requests never drive the models simultaneously.
_pipeline_lock = threading.Lock()


def load_vl_model():
    """Load the PaddleOCR-VL model and processor for element recognition.

    No-op when the model is already loaded. Uses bfloat16 on CUDA and
    float32 (with low_cpu_mem_usage) on CPU.
    """
    global vl_model, vl_processor
    if vl_model is not None:
        return

    logger.info(f"Loading PaddleOCR-VL model: {MODEL_NAME}")
    from transformers import AutoModelForCausalLM, AutoProcessor

    vl_processor = AutoProcessor.from_pretrained(MODEL_NAME, trust_remote_code=True)
    if DEVICE == "cuda":
        vl_model = AutoModelForCausalLM.from_pretrained(
            MODEL_NAME,
            trust_remote_code=True,
            torch_dtype=torch.bfloat16,
        ).to(DEVICE).eval()
    else:
        vl_model = AutoModelForCausalLM.from_pretrained(
            MODEL_NAME,
            trust_remote_code=True,
            torch_dtype=torch.float32,
            low_cpu_mem_usage=True,
        ).eval()
    logger.info("PaddleOCR-VL model loaded successfully")


def load_layout_model():
    """Load the LayoutDetection model for layout detection.

    Failure is non-fatal: ``layout_model`` stays None and detect_layout falls
    back to a single whole-image region. A failed load is remembered and not
    retried on subsequent calls.
    """
    global layout_model, _layout_load_failed
    if layout_model is not None or _layout_load_failed:
        return
    try:
        logger.info("Loading LayoutDetection model (PP-DocLayout_plus-L)...")
        from paddleocr import LayoutDetection
        layout_model = LayoutDetection()
        logger.info("LayoutDetection model loaded successfully")
    except Exception as e:
        _layout_load_failed = True
        logger.warning(f"Could not load LayoutDetection: {e}")
        logger.info("Falling back to VL-only mode (no layout detection)")


def recognize_element(image: Image.Image, task: str = "ocr") -> str:
    """Recognize a single element using PaddleOCR-VL.

    Args:
        image: The (cropped) region to recognize.
        task: A TASK_PROMPTS key; unknown values fall back to "ocr".

    Returns:
        The text generated by the model for this region, excluding the prompt.
    """
    load_vl_model()
    prompt = TASK_PROMPTS.get(task, TASK_PROMPTS["ocr"])
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": prompt},
            ],
        }
    ]
    inputs = vl_processor.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_dict=True,
        return_tensors="pt",
    )
    if DEVICE == "cuda":
        inputs = {k: v.to(DEVICE) for k, v in inputs.items()}

    with torch.inference_mode():
        outputs = vl_model.generate(
            **inputs,
            max_new_tokens=4096,
            do_sample=False,
            use_cache=True,
        )

    # generate() on a causal LM returns prompt + completion. Decode only the
    # newly generated tokens. Splitting the full transcript on "Assistant:"
    # would discard real content whenever the recognized text contains it.
    prompt_length = inputs["input_ids"].shape[-1]
    response = vl_processor.batch_decode(
        outputs[:, prompt_length:], skip_special_tokens=True
    )[0].strip()

    # Defensive: drop a leading role marker if one survives in the generated
    # span. Only a prefix is removed, so recognized content is never lost.
    if response[:10].lower() == "assistant:":
        response = response[10:].strip()
    return response


def _whole_image_region(image: Image.Image) -> dict:
    """Return a single "text" region covering the entire image."""
    return {
        "type": "text",
        "bbox": [0, 0, image.width, image.height],
        "score": 1.0,
    }


def detect_layout(image: Image.Image) -> List[dict]:
    """Detect layout regions in the image.

    Returns:
        A list of {"type", "bbox", "score"} dicts sorted by the top edge of
        the bbox (top to bottom). If no layout model is available, or it
        finds nothing, a single whole-image "text" region is returned.
    """
    load_layout_model()
    if layout_model is None:
        # No layout model - return a single region covering the whole image
        return [_whole_image_region(image)]

    # The layout model is given a file path. Create the temp file first and
    # save inside the try, so the file is removed even if saving fails.
    fd, tmp_path = tempfile.mkstemp(suffix=".png")
    os.close(fd)
    try:
        image.save(tmp_path, "PNG")
        results = layout_model.predict(tmp_path)
        regions = []
        for res in results:
            # LayoutDetection returns boxes in 'boxes' key
            for box in res.get("boxes", []):
                coord = box.get("coordinate", [0, 0, image.width, image.height])
                regions.append({
                    "type": box.get("label", "text"),
                    # Convert numpy floats to regular floats (JSON-serializable)
                    "bbox": [float(c) for c in coord],
                    "score": float(box.get("score", 1.0)),
                })
    finally:
        os.unlink(tmp_path)

    # Sort regions by vertical position (top to bottom)
    regions.sort(key=lambda r: r["bbox"][1])
    return regions or [_whole_image_region(image)]


def _task_for_region(region_type: str) -> str:
    """Map a lower-cased layout label to the TASK_PROMPTS key to use."""
    if "table" in region_type:
        return "table"
    if "formula" in region_type or "math" in region_type:
        return "formula"
    if "chart" in region_type or "figure" in region_type:
        return "chart"
    return "ocr"


def _clamped_box(bbox, width: int, height: int):
    """Truncate a bbox to integer pixels and clip it to the image bounds."""
    x1, y1, x2, y2 = (int(c) for c in bbox[:4])
    x1, x2 = max(0, min(x1, width)), max(0, min(x2, width))
    y1, y2 = max(0, min(y1, height)), max(0, min(y2, height))
    return x1, y1, x2, y2


def _process_document_locked(image: Image.Image) -> dict:
    """Run layout detection + per-region recognition (caller holds the lock)."""
    logger.info(f"Processing document: {image.size}")

    # Step 1: Detect layout
    regions = detect_layout(image)
    logger.info(f"Detected {len(regions)} layout regions")

    # Step 2: Recognize each region
    blocks = []
    for i, region in enumerate(regions):
        region_type = region["type"].lower()
        bbox = region["bbox"]
        task = _task_for_region(region_type)
        try:
            x1, y1, x2, y2 = _clamped_box(bbox, image.width, image.height)
            if x2 <= x1 or y2 <= y1:
                raise ValueError(f"empty region after clipping to image: {bbox}")
            content = recognize_element(image.crop((x1, y1, x2, y2)), task)
            blocks.append({
                "index": i,
                "type": region_type,
                "bbox": bbox,
                "content": content,
                "task": task,
            })
            logger.info(f"  Region {i} ({region_type}): {len(content)} chars")
        except Exception as e:
            # One bad region must not abort the whole document.
            logger.error(f"  Region {i} error: {e}")
            blocks.append({
                "index": i,
                "type": region_type,
                "bbox": bbox,
                "content": "",
                "task": task,
                "error": str(e),
            })

    return {"blocks": blocks, "image_size": list(image.size)}


def process_document(image: Image.Image) -> dict:
    """Process a document through the full pipeline.

    Per-region failures are recorded on that block (empty "content" plus an
    "error" message) instead of failing the document. Calls are serialized
    with a module-level lock because they may run in worker threads.

    Returns:
        {"blocks": [...], "image_size": [width, height]}
    """
    with _pipeline_lock:
        return _process_document_locked(image)


def result_to_markdown(result: dict) -> str:
    """Convert a process_document() result to Markdown.

    Table blocks are emitted verbatim, formula blocks are wrapped in $$
    display math, and everything else becomes a plain paragraph. Blocks with
    no content (e.g. failed regions) are skipped so they do not produce stray
    blank paragraphs.
    """
    parts = []
    for block in result.get("blocks", []):
        content = (block.get("content") or "").strip()
        if not content:
            continue
        block_type = (block.get("type") or "text").lower()
        if "table" in block_type:
            parts.append(content)
        elif "formula" in block_type:
            parts.append(f"$$\n{content}\n$$")
        else:
            parts.append(content)
    return "\n\n".join(parts)


# Request/Response models
class ParseRequest(BaseModel):
    image: str  # base64 encoded image (raw base64 or a data: URL)
    output_format: Optional[str] = "json"  # "markdown"; anything else yields JSON


class ParseResponse(BaseModel):
    success: bool
    format: str
    result: Union[dict, str]
    processing_time: float  # seconds
    error: Optional[str] = None


def decode_image(image_source: str) -> Image.Image:
    """Decode an RGB image from raw base64 or a ``data:...;base64,`` URL.

    Raises:
        ValueError: malformed data URL, or invalid base64 (binascii.Error).
        OSError: the bytes are not a readable image (e.g.
            PIL.UnidentifiedImageError).
    """
    image_source = image_source.strip()
    if image_source.startswith("data:"):
        _header, sep, data = image_source.partition(",")
        if not sep:
            raise ValueError("Malformed data URL: missing ',' before image data")
        image_source = data
    image_data = base64.b64decode(image_source)
    return Image.open(io.BytesIO(image_data)).convert("RGB")


@app.on_event("startup")
async def startup_event():
    """Pre-load models on startup (failures are logged; loading retries lazily)."""
    logger.info("Starting PaddleOCR-VL Full Pipeline Server...")
    try:
        load_vl_model()
        load_layout_model()
        logger.info("Models loaded successfully")
    except Exception as e:
        logger.error(f"Failed to pre-load models: {e}")


@app.get("/health")
async def health_check():
    """Health check endpoint: "healthy" once the VL model is loaded."""
    return {
        "status": "healthy" if vl_model is not None else "loading",
        "service": "PaddleOCR-VL Full Pipeline (Transformers)",
        "device": DEVICE,
        "vl_model_loaded": vl_model is not None,
        "layout_model_loaded": layout_model is not None
    }


@app.get("/formats")
async def supported_formats():
    """List supported output formats"""
    return {
        "output_formats": ["json", "markdown"],
        "image_formats": ["PNG", "JPEG", "WebP", "BMP", "GIF", "TIFF"],
        "capabilities": [
            "Layout detection (PP-DocLayoutV2)",
            "Text recognition (OCR)",
            "Table recognition",
            "Formula recognition (LaTeX)",
            "Chart recognition",
            "Multi-language support (109 languages)"
        ]
    }


@app.post("/parse", response_model=ParseResponse)
async def parse_document_endpoint(request: ParseRequest):
    """Parse a document image and return structured output.

    Errors are reported in-band (success=False, error=...) with HTTP 200.
    """
    # Only "markdown" changes the output; anything else (including None)
    # yields JSON. Normalizing also keeps ParseResponse.format a valid str.
    output_format = "markdown" if request.output_format == "markdown" else "json"
    start_time = time.time()
    try:
        image = decode_image(request.image)
        # Inference is blocking; run it off the event loop so other requests
        # (e.g. /health) stay responsive while a document is processed.
        result = await run_in_threadpool(process_document, image)

        if output_format == "markdown":
            output = {"markdown": result_to_markdown(result)}
        else:
            output = result

        elapsed = time.time() - start_time
        logger.info(f"Processing complete in {elapsed:.2f}s")
        return ParseResponse(
            success=True,
            format=output_format,
            result=output,
            processing_time=elapsed
        )
    except Exception as e:
        logger.error(f"Error processing document: {e}", exc_info=True)
        return ParseResponse(
            success=False,
            format=output_format,
            result={},
            processing_time=time.time() - start_time,
            error=str(e)
        )


def _extract_image_url(messages) -> str:
    """Return the first image_url in the last user message, or "" if none.

    Accepts both the OpenAI shape {"image_url": {"url": "..."}} and a bare
    string {"image_url": "..."}. Non-dict messages/items are ignored.
    """
    if not isinstance(messages, list):
        return ""
    for msg in reversed(messages):
        if not isinstance(msg, dict) or msg.get("role") != "user":
            continue
        content = msg.get("content", [])
        if isinstance(content, list):
            for item in content:
                if isinstance(item, dict) and item.get("type") == "image_url":
                    image_url = item.get("image_url", {})
                    if isinstance(image_url, dict):
                        return image_url.get("url", "") or ""
                    return image_url if isinstance(image_url, str) else ""
        # Only the most recent user message is considered.
        return ""
    return ""


@app.post("/v1/chat/completions")
async def chat_completions(request: dict):
    """OpenAI-compatible chat completions endpoint.

    Runs the full pipeline on the first image_url of the last user message and
    returns the result as the assistant message: pretty-printed JSON by
    default, or Markdown when request["output_format"] == "markdown".
    Missing or undecodable images yield HTTP 400; other failures yield 500.
    """
    try:
        messages = request.get("messages", [])
        output_format = request.get("output_format", "json")

        url = _extract_image_url(messages)
        if not url:
            raise HTTPException(status_code=400, detail="No image provided")
        try:
            image = decode_image(url)
        except (ValueError, OSError) as e:
            # Bad base64 / unreadable image data is a client error.
            raise HTTPException(status_code=400, detail=f"Invalid image: {e}") from e

        start_time = time.time()
        # Blocking inference runs off the event loop (see /parse).
        result = await run_in_threadpool(process_document, image)

        if output_format == "markdown":
            content = result_to_markdown(result)
        else:
            content = json.dumps(result, ensure_ascii=False, indent=2)

        elapsed = time.time() - start_time

        return {
            "id": f"chatcmpl-{int(time.time()*1000)}",
            "object": "chat.completion",
            "created": int(time.time()),
            "model": "paddleocr-vl-full",
            "choices": [{
                "index": 0,
                "message": {"role": "assistant", "content": content},
                "finish_reason": "stop"
            }],
            # Token counts are rough placeholders, not real tokenizer counts.
            "usage": {
                "prompt_tokens": 100,
                "completion_tokens": len(content) // 4,
                "total_tokens": 100 + len(content) // 4
            },
            "processing_time": elapsed
        }
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error in chat completions: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))


if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host=SERVER_HOST, port=SERVER_PORT)