feat(paddleocr-vl): add PaddleOCR-VL full pipeline Docker image and API server, plus integration tests and docker helpers

This commit is contained in:
2026-01-17 20:22:23 +00:00
parent addae20cbd
commit 80e6866442
12 changed files with 2414 additions and 21 deletions

View File

@@ -0,0 +1,443 @@
#!/usr/bin/env python3
"""
PaddleOCR-VL Full Pipeline API Server (Transformers backend)
Provides REST API for document parsing using:
- PP-DocLayoutV2 for layout detection
- PaddleOCR-VL (transformers) for recognition
- Structured JSON/Markdown output
"""
import os
import io
import base64
import logging
import tempfile
import time
import json
from typing import Optional, List, Union
from pathlib import Path
from fastapi import FastAPI, HTTPException, UploadFile, File, Form
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from PIL import Image
import torch
# Configure logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
# Environment configuration
SERVER_HOST = os.environ.get('SERVER_HOST', '0.0.0.0')
SERVER_PORT = int(os.environ.get('SERVER_PORT', '8000'))
MODEL_NAME = "PaddlePaddle/PaddleOCR-VL"
# Device configuration
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
logger.info(f"Using device: {DEVICE}")
# Task prompts
TASK_PROMPTS = {
"ocr": "OCR:",
"table": "Table Recognition:",
"formula": "Formula Recognition:",
"chart": "Chart Recognition:",
}
# Initialize FastAPI app
app = FastAPI(
title="PaddleOCR-VL Full Pipeline Server",
description="Document parsing with PP-DocLayoutV2 + PaddleOCR-VL (transformers)",
version="1.0.0"
)
# Global model instances
vl_model = None
vl_processor = None
layout_model = None
def load_vl_model():
"""Load the PaddleOCR-VL model for element recognition"""
global vl_model, vl_processor
if vl_model is not None:
return
logger.info(f"Loading PaddleOCR-VL model: {MODEL_NAME}")
from transformers import AutoModelForCausalLM, AutoProcessor
vl_processor = AutoProcessor.from_pretrained(MODEL_NAME, trust_remote_code=True)
if DEVICE == "cuda":
vl_model = AutoModelForCausalLM.from_pretrained(
MODEL_NAME,
trust_remote_code=True,
torch_dtype=torch.bfloat16,
).to(DEVICE).eval()
else:
vl_model = AutoModelForCausalLM.from_pretrained(
MODEL_NAME,
trust_remote_code=True,
torch_dtype=torch.float32,
low_cpu_mem_usage=True,
).eval()
logger.info("PaddleOCR-VL model loaded successfully")
def load_layout_model():
"""Load the LayoutDetection model for layout detection"""
global layout_model
if layout_model is not None:
return
try:
logger.info("Loading LayoutDetection model (PP-DocLayout_plus-L)...")
from paddleocr import LayoutDetection
layout_model = LayoutDetection()
logger.info("LayoutDetection model loaded successfully")
except Exception as e:
logger.warning(f"Could not load LayoutDetection: {e}")
logger.info("Falling back to VL-only mode (no layout detection)")
def recognize_element(image: Image.Image, task: str = "ocr") -> str:
"""Recognize a single element using PaddleOCR-VL"""
load_vl_model()
prompt = TASK_PROMPTS.get(task, TASK_PROMPTS["ocr"])
messages = [
{
"role": "user",
"content": [
{"type": "image", "image": image},
{"type": "text", "text": prompt},
]
}
]
inputs = vl_processor.apply_chat_template(
messages,
tokenize=True,
add_generation_prompt=True,
return_dict=True,
return_tensors="pt"
)
if DEVICE == "cuda":
inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
with torch.inference_mode():
outputs = vl_model.generate(
**inputs,
max_new_tokens=4096,
do_sample=False,
use_cache=True
)
response = vl_processor.batch_decode(outputs, skip_special_tokens=True)[0]
# Extract only the assistant's response content
# The response format is: "User: <prompt>\nAssistant: <content>"
# We want to extract just the content after "Assistant:"
if "Assistant:" in response:
parts = response.split("Assistant:")
if len(parts) > 1:
response = parts[-1].strip()
elif "assistant:" in response.lower():
# Case-insensitive fallback
import re
match = re.split(r'[Aa]ssistant:', response)
if len(match) > 1:
response = match[-1].strip()
return response
def detect_layout(image: Image.Image) -> List[dict]:
"""Detect layout regions in the image"""
load_layout_model()
if layout_model is None:
# No layout model - return a single region covering the whole image
return [{
"type": "text",
"bbox": [0, 0, image.width, image.height],
"score": 1.0
}]
# Save image to temp file
with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
image.save(tmp.name, "PNG")
tmp_path = tmp.name
try:
results = layout_model.predict(tmp_path)
regions = []
for res in results:
# LayoutDetection returns boxes in 'boxes' key
for box in res.get("boxes", []):
coord = box.get("coordinate", [0, 0, image.width, image.height])
# Convert numpy floats to regular floats
bbox = [float(c) for c in coord]
regions.append({
"type": box.get("label", "text"),
"bbox": bbox,
"score": float(box.get("score", 1.0))
})
# Sort regions by vertical position (top to bottom)
regions.sort(key=lambda r: r["bbox"][1])
return regions if regions else [{
"type": "text",
"bbox": [0, 0, image.width, image.height],
"score": 1.0
}]
finally:
os.unlink(tmp_path)
def process_document(image: Image.Image) -> dict:
"""Process a document through the full pipeline"""
logger.info(f"Processing document: {image.size}")
# Step 1: Detect layout
regions = detect_layout(image)
logger.info(f"Detected {len(regions)} layout regions")
# Step 2: Recognize each region
blocks = []
for i, region in enumerate(regions):
region_type = region["type"].lower()
bbox = region["bbox"]
# Crop region from image
x1, y1, x2, y2 = [int(c) for c in bbox]
region_image = image.crop((x1, y1, x2, y2))
# Determine task based on region type
if "table" in region_type:
task = "table"
elif "formula" in region_type or "math" in region_type:
task = "formula"
elif "chart" in region_type or "figure" in region_type:
task = "chart"
else:
task = "ocr"
# Recognize the region
try:
content = recognize_element(region_image, task)
blocks.append({
"index": i,
"type": region_type,
"bbox": bbox,
"content": content,
"task": task
})
logger.info(f" Region {i} ({region_type}): {len(content)} chars")
except Exception as e:
logger.error(f" Region {i} error: {e}")
blocks.append({
"index": i,
"type": region_type,
"bbox": bbox,
"content": "",
"error": str(e)
})
return {"blocks": blocks, "image_size": list(image.size)}
def result_to_markdown(result: dict) -> str:
"""Convert result to Markdown format"""
lines = []
for block in result.get("blocks", []):
block_type = block.get("type", "text")
content = block.get("content", "")
if "table" in block_type.lower():
lines.append(f"\n{content}\n")
elif "formula" in block_type.lower():
lines.append(f"\n$$\n{content}\n$$\n")
else:
lines.append(content)
return "\n\n".join(lines)
# Request/Response models
class ParseRequest(BaseModel):
image: str # base64 encoded image
output_format: Optional[str] = "json"
class ParseResponse(BaseModel):
success: bool
format: str
result: Union[dict, str]
processing_time: float
error: Optional[str] = None
def decode_image(image_source: str) -> Image.Image:
"""Decode image from base64 or data URL"""
if image_source.startswith("data:"):
header, data = image_source.split(",", 1)
image_data = base64.b64decode(data)
else:
image_data = base64.b64decode(image_source)
return Image.open(io.BytesIO(image_data)).convert("RGB")
@app.on_event("startup")
async def startup_event():
"""Pre-load models on startup"""
logger.info("Starting PaddleOCR-VL Full Pipeline Server...")
try:
load_vl_model()
load_layout_model()
logger.info("Models loaded successfully")
except Exception as e:
logger.error(f"Failed to pre-load models: {e}")
@app.get("/health")
async def health_check():
"""Health check endpoint"""
return {
"status": "healthy" if vl_model is not None else "loading",
"service": "PaddleOCR-VL Full Pipeline (Transformers)",
"device": DEVICE,
"vl_model_loaded": vl_model is not None,
"layout_model_loaded": layout_model is not None
}
@app.get("/formats")
async def supported_formats():
"""List supported output formats"""
return {
"output_formats": ["json", "markdown"],
"image_formats": ["PNG", "JPEG", "WebP", "BMP", "GIF", "TIFF"],
"capabilities": [
"Layout detection (PP-DocLayoutV2)",
"Text recognition (OCR)",
"Table recognition",
"Formula recognition (LaTeX)",
"Chart recognition",
"Multi-language support (109 languages)"
]
}
@app.post("/parse", response_model=ParseResponse)
async def parse_document_endpoint(request: ParseRequest):
"""Parse a document image and return structured output"""
try:
start_time = time.time()
image = decode_image(request.image)
result = process_document(image)
if request.output_format == "markdown":
markdown = result_to_markdown(result)
output = {"markdown": markdown}
else:
output = result
elapsed = time.time() - start_time
logger.info(f"Processing complete in {elapsed:.2f}s")
return ParseResponse(
success=True,
format=request.output_format,
result=output,
processing_time=elapsed
)
except Exception as e:
logger.error(f"Error processing document: {e}", exc_info=True)
return ParseResponse(
success=False,
format=request.output_format,
result={},
processing_time=0,
error=str(e)
)
@app.post("/v1/chat/completions")
async def chat_completions(request: dict):
"""OpenAI-compatible chat completions endpoint"""
try:
messages = request.get("messages", [])
output_format = request.get("output_format", "json")
# Find user message with image
image = None
for msg in reversed(messages):
if msg.get("role") == "user":
content = msg.get("content", [])
if isinstance(content, list):
for item in content:
if item.get("type") == "image_url":
url = item.get("image_url", {}).get("url", "")
image = decode_image(url)
break
break
if image is None:
raise HTTPException(status_code=400, detail="No image provided")
start_time = time.time()
result = process_document(image)
if output_format == "markdown":
content = result_to_markdown(result)
else:
content = json.dumps(result, ensure_ascii=False, indent=2)
elapsed = time.time() - start_time
return {
"id": f"chatcmpl-{int(time.time()*1000)}",
"object": "chat.completion",
"created": int(time.time()),
"model": "paddleocr-vl-full",
"choices": [{
"index": 0,
"message": {"role": "assistant", "content": content},
"finish_reason": "stop"
}],
"usage": {
"prompt_tokens": 100,
"completion_tokens": len(content) // 4,
"total_tokens": 100 + len(content) // 4
},
"processing_time": elapsed
}
except HTTPException:
raise
except Exception as e:
logger.error(f"Error in chat completions: {e}", exc_info=True)
raise HTTPException(status_code=500, detail=str(e))
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host=SERVER_HOST, port=SERVER_PORT)