From 379b5c19ebb27900f98ab7ebf2d9b096cd0c6b37 Mon Sep 17 00:00:00 2001
From: Juergen Kunz
Date: Fri, 16 Jan 2026 10:22:15 +0000
Subject: [PATCH] feat(ocr): add PaddleOCR GPU Docker image and FastAPI OCR
 server with entrypoint; implement OCR endpoints and consensus extraction
 testing

---
 Dockerfile_paddleocr                        |  51 +++
 Dockerfile_paddleocr_cpu                    |  54 +++
 changelog.md                                |  16 +
 image_support_files/paddleocr-entrypoint.sh |  25 ++
 image_support_files/paddleocr-server.py     | 258 ++++++++++++++
 test/test.invoices.ts                       | 377 ++++++++++++++++++++
 test/test.node.ts                           |  80 ++++-
 7 files changed, 847 insertions(+), 14 deletions(-)
 create mode 100644 Dockerfile_paddleocr
 create mode 100644 Dockerfile_paddleocr_cpu
 create mode 100644 changelog.md
 create mode 100644 image_support_files/paddleocr-entrypoint.sh
 create mode 100644 image_support_files/paddleocr-server.py
 create mode 100644 test/test.invoices.ts

diff --git a/Dockerfile_paddleocr b/Dockerfile_paddleocr
new file mode 100644
index 0000000..89cc5a2
--- /dev/null
+++ b/Dockerfile_paddleocr
@@ -0,0 +1,51 @@
+# PaddleOCR GPU Variant
+# OCR processing with NVIDIA GPU support using PaddlePaddle
+FROM paddlepaddle/paddle:3.0.0-gpu-cuda11.8-cudnn8.9-trt8.6
+
+LABEL maintainer="Task Venture Capital GmbH "
+LABEL description="PaddleOCR PP-OCRv4 - GPU optimized"
+LABEL org.opencontainers.image.source="https://code.foss.global/host.today/ht-docker-ai"
+
+# Environment configuration
+ENV OCR_LANGUAGE="en"
+ENV SERVER_PORT="5000"
+ENV SERVER_HOST="0.0.0.0"
+ENV PYTHONUNBUFFERED=1
+
+# Set working directory
+WORKDIR /app
+
+# Install system dependencies
+RUN apt-get update && apt-get install -y --no-install-recommends \
+    libgl1-mesa-glx \
+    libglib2.0-0 \
+    curl \
+    && rm -rf /var/lib/apt/lists/*
+
+# Install Python dependencies
+RUN pip install --no-cache-dir \
+    paddleocr \
+    fastapi \
+    uvicorn[standard] \
+    python-multipart \
+    opencv-python-headless \
+    pillow
+
+# Copy server files
+COPY image_support_files/paddleocr-server.py /app/paddleocr-server.py
+COPY image_support_files/paddleocr-entrypoint.sh /usr/local/bin/paddleocr-entrypoint.sh
+RUN chmod +x /usr/local/bin/paddleocr-entrypoint.sh
+
+# Pre-download OCR models during build (PP-OCRv4)
+RUN python -c "from paddleocr import PaddleOCR; \
+    ocr = PaddleOCR(use_angle_cls=True, lang='en', use_gpu=False, show_log=True); \
+    print('English model downloaded')"
+
+# Expose API port
+EXPOSE 5000
+
+# Health check
+HEALTHCHECK --interval=30s --timeout=10s --start-period=60s --retries=3 \
+    CMD curl -f http://localhost:5000/health || exit 1
+
+ENTRYPOINT ["/usr/local/bin/paddleocr-entrypoint.sh"]
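
[Editor's note] Build-and-run sketch for the GPU variant. The image tag is illustrative and not part of this patch; GPU passthrough assumes the NVIDIA container toolkit is installed on the host:

    # Build the GPU image from the repo root
    docker build -f Dockerfile_paddleocr -t ht-docker-ai:paddleocr-gpu .

    # Run with GPU access; the server listens on port 5000 (see EXPOSE above)
    docker run --rm --gpus all -p 5000:5000 ht-docker-ai:paddleocr-gpu
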
diff --git a/Dockerfile_paddleocr_cpu b/Dockerfile_paddleocr_cpu
new file mode 100644
index 0000000..2ea9cbc
--- /dev/null
+++ b/Dockerfile_paddleocr_cpu
@@ -0,0 +1,54 @@
+# PaddleOCR CPU Variant
+# OCR processing optimized for CPU-only inference
+FROM python:3.10-slim
+
+LABEL maintainer="Task Venture Capital GmbH "
+LABEL description="PaddleOCR PP-OCRv4 - CPU optimized"
+LABEL org.opencontainers.image.source="https://code.foss.global/host.today/ht-docker-ai"
+
+# Environment configuration for CPU-only mode
+ENV OCR_LANGUAGE="en"
+ENV SERVER_PORT="5000"
+ENV SERVER_HOST="0.0.0.0"
+ENV PYTHONUNBUFFERED=1
+# Disable GPU usage for CPU-only variant
+ENV CUDA_VISIBLE_DEVICES="-1"
+
+# Set working directory
+WORKDIR /app
+
+# Install system dependencies
+RUN apt-get update && apt-get install -y --no-install-recommends \
+    libgl1-mesa-glx \
+    libglib2.0-0 \
+    curl \
+    && rm -rf /var/lib/apt/lists/*
+
+# Install Python dependencies (CPU version of PaddlePaddle)
+RUN pip install --no-cache-dir \
+    paddlepaddle \
+    paddleocr \
+    fastapi \
+    uvicorn[standard] \
+    python-multipart \
+    opencv-python-headless \
+    pillow
+
+# Copy server files
+COPY image_support_files/paddleocr-server.py /app/paddleocr-server.py
+COPY image_support_files/paddleocr-entrypoint.sh /usr/local/bin/paddleocr-entrypoint.sh
+RUN chmod +x /usr/local/bin/paddleocr-entrypoint.sh
+
+# Pre-download OCR models during build (PP-OCRv4)
+RUN python -c "from paddleocr import PaddleOCR; \
+    ocr = PaddleOCR(use_angle_cls=True, lang='en', use_gpu=False, show_log=True); \
+    print('English model downloaded')"
+
+# Expose API port
+EXPOSE 5000
+
+# Health check (longer start-period for CPU variant)
+HEALTHCHECK --interval=30s --timeout=10s --start-period=120s --retries=3 \
+    CMD curl -f http://localhost:5000/health || exit 1
+
+ENTRYPOINT ["/usr/local/bin/paddleocr-entrypoint.sh"]
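
[Editor's note] The CPU variant builds the same way (tag again illustrative). CUDA_VISIBLE_DEVICES=-1 is baked into the image, so no GPU flags are needed; the healthcheck grants a 120s start-period because model warm-up is slower on CPU:

    docker build -f Dockerfile_paddleocr_cpu -t ht-docker-ai:paddleocr-cpu .
    docker run --rm -p 5000:5000 ht-docker-ai:paddleocr-cpu
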
diff --git a/changelog.md b/changelog.md
new file mode 100644
index 0000000..acd83bd
--- /dev/null
+++ b/changelog.md
@@ -0,0 +1,16 @@
+# Changelog
+
+## 2026-01-16 - 1.1.0 - feat(ocr)
+add PaddleOCR GPU Docker image and FastAPI OCR server with entrypoint; implement OCR endpoints and consensus extraction testing
+
+- Add Dockerfile_paddleocr (GPU) and Dockerfile_paddleocr_cpu (CPU-only) PaddleOCR images (pre-download PP-OCRv4 models at build time, expose port 5000, healthcheck, entrypoint)
+- Add image_support_files/paddleocr-server.py: FastAPI app providing /ocr (base64), /ocr/upload (file), and /health endpoints; model warm-up on startup; structured JSON responses and error handling
+- Add image_support_files/paddleocr-entrypoint.sh to configure environment, detect GPU/CPU mode, and launch uvicorn
+- Update test/test.node.ts to replace streaming extraction with a consensus-based extraction flow (multiple passes, hashing of results, majority voting) and improve logging/prompt text
+- Add test/test.invoices.ts: integration tests for invoice extraction that call PaddleOCR, build prompts with optional OCR text, run consensus extraction, and produce a summary report
+
+## 2026-01-16 - 1.0.0 - initial release
+Initial project files added with two small follow-up updates.
+
+- initial: base project commit.
+- update: two minor follow-up updates refining the initial commit.
\ No newline at end of file
diff --git a/image_support_files/paddleocr-entrypoint.sh b/image_support_files/paddleocr-entrypoint.sh
new file mode 100644
index 0000000..aff8737
--- /dev/null
+++ b/image_support_files/paddleocr-entrypoint.sh
@@ -0,0 +1,25 @@
+#!/bin/bash
+set -e
+
+# Configuration from environment
+OCR_LANGUAGE="${OCR_LANGUAGE:-en}"
+SERVER_PORT="${SERVER_PORT:-5000}"
+SERVER_HOST="${SERVER_HOST:-0.0.0.0}"
+
+echo "Starting PaddleOCR Server..."
+echo "  Language: ${OCR_LANGUAGE}"
+echo "  Host: ${SERVER_HOST}"
+echo "  Port: ${SERVER_PORT}"
+
+# Check GPU availability
+if [ "${CUDA_VISIBLE_DEVICES}" = "-1" ]; then
+    echo "  GPU: Disabled (CPU mode)"
+else
+    echo "  GPU: Enabled"
+fi
+
+# Start the FastAPI server with uvicorn
+exec python -m uvicorn paddleocr-server:app \
+    --host "${SERVER_HOST}" \
+    --port "${SERVER_PORT}" \
+    --workers 1
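
[Editor's note] Because the entrypoint resolves all settings from the environment at startup, a container can be reconfigured without rebuilding. A hypothetical run; values are illustrative, and OCR_LANGUAGE must be a language key PaddleOCR accepts (e.g. 'german'):

    docker run --rm -p 8080:8080 \
      -e OCR_LANGUAGE=german \
      -e SERVER_PORT=8080 \
      ht-docker-ai:paddleocr-cpu
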
diff --git a/image_support_files/paddleocr-server.py b/image_support_files/paddleocr-server.py
new file mode 100644
index 0000000..be23a2a
--- /dev/null
+++ b/image_support_files/paddleocr-server.py
@@ -0,0 +1,258 @@
+#!/usr/bin/env python3
+"""
+PaddleOCR FastAPI Server
+Provides REST API for OCR operations using PaddleOCR
+"""
+
+import os
+import io
+import base64
+import logging
+from typing import Optional, List, Any
+
+from fastapi import FastAPI, File, UploadFile, Form, HTTPException
+from fastapi.responses import JSONResponse
+from pydantic import BaseModel
+import numpy as np
+from PIL import Image
+from paddleocr import PaddleOCR
+
+# Configure logging
+logging.basicConfig(
+    level=logging.INFO,
+    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
+)
+logger = logging.getLogger(__name__)
+
+# Environment configuration
+OCR_LANGUAGE = os.environ.get('OCR_LANGUAGE', 'en')
+USE_GPU = os.environ.get('CUDA_VISIBLE_DEVICES', '') != '-1'
+
+# Initialize FastAPI app
+app = FastAPI(
+    title="PaddleOCR Server",
+    description="REST API for OCR operations using PaddleOCR PP-OCRv4",
+    version="1.0.0"
+)
+
+# Global OCR instance
+ocr_instance: Optional[PaddleOCR] = None
+
+
+class OCRRequest(BaseModel):
+    """Request model for base64 image OCR"""
+    image: str
+    language: Optional[str] = None
+
+
+class BoundingBox(BaseModel):
+    """Bounding box for detected text"""
+    points: List[List[float]]
+
+
+class OCRResult(BaseModel):
+    """Single OCR detection result"""
+    text: str
+    confidence: float
+    box: List[List[float]]
+
+
+class OCRResponse(BaseModel):
+    """OCR response model"""
+    success: bool
+    results: List[OCRResult]
+    error: Optional[str] = None
+
+
+class HealthResponse(BaseModel):
+    """Health check response"""
+    status: str
+    model: str
+    language: str
+    gpu_enabled: bool
+
+
+def get_ocr() -> PaddleOCR:
+    """Get or initialize the OCR instance"""
+    global ocr_instance
+    if ocr_instance is None:
+        logger.info(f"Initializing PaddleOCR with language={OCR_LANGUAGE}, use_gpu={USE_GPU}")
+        ocr_instance = PaddleOCR(
+            use_angle_cls=True,
+            lang=OCR_LANGUAGE,
+            use_gpu=USE_GPU,
+            show_log=False
+        )
+        logger.info("PaddleOCR initialized successfully")
+    return ocr_instance
+
+
+def decode_base64_image(base64_string: str) -> np.ndarray:
+    """Decode base64 string to numpy array"""
+    # Remove data URL prefix if present
+    if ',' in base64_string:
+        base64_string = base64_string.split(',')[1]
+
+    image_data = base64.b64decode(base64_string)
+    image = Image.open(io.BytesIO(image_data))
+
+    # Convert to RGB if necessary
+    if image.mode != 'RGB':
+        image = image.convert('RGB')
+
+    return np.array(image)
+
+
+def process_ocr_result(result: Any) -> List[OCRResult]:
+    """Process PaddleOCR result into structured format"""
+    results = []
+
+    if result is None or len(result) == 0:
+        return results
+
+    # PaddleOCR returns list of results per image
+    # Each result is a list of [box, (text, confidence)]
+    for line in result[0] if result[0] else []:
+        if line is None:
+            continue
+
+        box = line[0]  # [[x1,y1], [x2,y2], [x3,y3], [x4,y4]]
+        text_info = line[1]  # (text, confidence)
+
+        results.append(OCRResult(
+            text=text_info[0],
+            confidence=float(text_info[1]),
+            box=[[float(p[0]), float(p[1])] for p in box]
+        ))
+
+    return results
+
+
+@app.on_event("startup")
+async def startup_event():
+    """Pre-warm the OCR model on startup"""
+    logger.info("Pre-warming OCR model...")
+    try:
+        ocr = get_ocr()
+        # Create a small test image to warm up the model
+        test_image = np.zeros((100, 100, 3), dtype=np.uint8)
+        test_image.fill(255)  # White image
+        ocr.ocr(test_image, cls=True)
+        logger.info("OCR model pre-warmed successfully")
+    except Exception as e:
+        logger.error(f"Failed to pre-warm OCR model: {e}")
+
+
+@app.get("/health", response_model=HealthResponse)
+async def health_check():
+    """Health check endpoint"""
+    try:
+        # Ensure OCR is initialized
+        get_ocr()
+        return HealthResponse(
+            status="healthy",
+            model="PP-OCRv4",
+            language=OCR_LANGUAGE,
+            gpu_enabled=USE_GPU
+        )
+    except Exception as e:
+        logger.error(f"Health check failed: {e}")
+        raise HTTPException(status_code=503, detail=str(e))
+
+
+@app.post("/ocr", response_model=OCRResponse)
+async def ocr_base64(request: OCRRequest):
+    """
+    Perform OCR on a base64-encoded image
+
+    Args:
+        request: OCRRequest with base64 image and optional language
+
+    Returns:
+        OCRResponse with detected text, confidence scores, and bounding boxes
+    """
+    try:
+        # Decode image
+        image = decode_base64_image(request.image)
+
+        # Get OCR instance (use request language if provided)
+        ocr = get_ocr()
+
+        # If a different language is requested, create a new instance
+        if request.language and request.language != OCR_LANGUAGE:
+            logger.info(f"Creating OCR instance for language: {request.language}")
+            temp_ocr = PaddleOCR(
+                use_angle_cls=True,
+                lang=request.language,
+                use_gpu=USE_GPU,
+                show_log=False
+            )
+            result = temp_ocr.ocr(image, cls=True)
+        else:
+            result = ocr.ocr(image, cls=True)
+
+        # Process results
+        results = process_ocr_result(result)
+
+        return OCRResponse(success=True, results=results)
+
+    except Exception as e:
+        logger.error(f"OCR processing failed: {e}")
+        return OCRResponse(success=False, results=[], error=str(e))
+
+
+@app.post("/ocr/upload", response_model=OCRResponse)
+async def ocr_upload(
+    img: UploadFile = File(...),
+    language: Optional[str] = Form(None)
+):
+    """
+    Perform OCR on an uploaded image file
+
+    Args:
+        img: Uploaded image file
+        language: Optional language code (default: env OCR_LANGUAGE)
+
+    Returns:
+        OCRResponse with detected text, confidence scores, and bounding boxes
+    """
+    try:
+        # Read image
+        contents = await img.read()
+        image = Image.open(io.BytesIO(contents))
+
+        # Convert to RGB if necessary
+        if image.mode != 'RGB':
+            image = image.convert('RGB')
+
+        image_array = np.array(image)
+
+        # Get OCR instance
+        ocr = get_ocr()
+
+        # If a different language is requested, create a new instance
+        if language and language != OCR_LANGUAGE:
+            logger.info(f"Creating OCR instance for language: {language}")
+            temp_ocr = PaddleOCR(
+                use_angle_cls=True,
+                lang=language,
+                use_gpu=USE_GPU,
+                show_log=False
+            )
+            result = temp_ocr.ocr(image_array, cls=True)
+        else:
+            result = ocr.ocr(image_array, cls=True)
+
+        # Process results
+        results = process_ocr_result(result)
+
+        return OCRResponse(success=True, results=results)
+
+    except Exception as e:
+        logger.error(f"OCR processing failed: {e}")
+        return OCRResponse(success=False, results=[], error=str(e))
+
+
+if __name__ == "__main__":
+    import uvicorn
+    uvicorn.run(app, host="0.0.0.0", port=5000)
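
[Editor's note] Once a container is up, the three routes above can be smoke-tested from the host. A sketch with curl (invoice.png is any local sample image; -w0 is GNU coreutils' no-wrap flag for base64):

    # Readiness / model status
    curl -s http://localhost:5000/health

    # Multipart upload; the field name must be 'img' per the route signature
    curl -s -X POST -F 'img=@invoice.png' http://localhost:5000/ocr/upload

    # Base64 JSON body for the /ocr route
    curl -s -X POST http://localhost:5000/ocr \
      -H 'Content-Type: application/json' \
      -d "{\"image\": \"$(base64 -w0 invoice.png)\"}"
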
diff --git a/test/test.invoices.ts b/test/test.invoices.ts
new file mode 100644
index 0000000..db5d60e
--- /dev/null
+++ b/test/test.invoices.ts
@@ -0,0 +1,377 @@
+import { tap, expect } from '@git.zone/tstest/tapbundle';
+import * as fs from 'fs';
+import * as path from 'path';
+import { execSync } from 'child_process';
+import * as os from 'os';
+
+const OLLAMA_URL = 'http://localhost:11434';
+const MODEL = 'openbmb/minicpm-v4.5:q8_0';
+const PADDLEOCR_URL = 'http://localhost:5000';
+
+interface IInvoice {
+  invoice_number: string;
+  invoice_date: string;
+  vendor_name: string;
+  currency: string;
+  net_amount: number;
+  vat_amount: number;
+  total_amount: number;
+}
+
+/**
+ * Extract OCR text from an image using PaddleOCR
+ */
+async function extractOcrText(imageBase64: string): Promise<string> {
+  const formData = new FormData();
+  const imageBuffer = Buffer.from(imageBase64, 'base64');
+  const blob = new Blob([imageBuffer], { type: 'image/png' });
+  // The upload route expects the multipart field to be named 'img'
+  formData.append('img', blob, 'image.png');
+
+  try {
+    const response = await fetch(`${PADDLEOCR_URL}/ocr/upload`, {
+      method: 'POST',
+      body: formData,
+    });
+
+    if (!response.ok) return '';
+
+    const data = await response.json();
+    if (data.success && data.results) {
+      return data.results.map((r: { text: string }) => r.text).join('\n');
+    }
+  } catch {
+    // PaddleOCR unavailable
+  }
+  return '';
+}
+
+/**
+ * Build prompt with optional OCR text
+ */
+function buildPrompt(ocrText: string): string {
+  const base = `You are an invoice parser. Extract the following fields from this invoice:
+
+1. invoice_number: The invoice/receipt number
+2. invoice_date: Date in YYYY-MM-DD format
+3. vendor_name: Company that issued the invoice
+4. currency: EUR, USD, etc.
+5. net_amount: Amount before tax (if shown)
+6. vat_amount: Tax/VAT amount (if shown, 0 if reverse charge or no tax)
+7. total_amount: Final amount due
+
+Return ONLY valid JSON in this exact format:
+{"invoice_number":"XXX","invoice_date":"YYYY-MM-DD","vendor_name":"Company Name","currency":"EUR","net_amount":100.00,"vat_amount":19.00,"total_amount":119.00}
+
+If a field is not visible, use null for strings or 0 for numbers.
+No explanation, just the JSON object.`;
+
+  if (ocrText) {
+    return `${base}
+
+OCR text extracted from the invoice:
+---
+${ocrText}
+---
+
+Cross-reference the image with the OCR text above for accuracy.`;
+  }
+  return base;
+}
+
+/**
+ * Convert PDF to PNG images using ImageMagick
+ */
+function convertPdfToImages(pdfPath: string): string[] {
+  const tempDir = fs.mkdtempSync(path.join(os.tmpdir(), 'pdf-convert-'));
+  const outputPattern = path.join(tempDir, 'page-%d.png');
+
+  try {
+    execSync(
+      `convert -density 200 -quality 90 "${pdfPath}" -background white -alpha remove "${outputPattern}"`,
+      { stdio: 'pipe' }
+    );
+
+    const files = fs.readdirSync(tempDir).filter((f) => f.endsWith('.png')).sort();
+    const images: string[] = [];
+
+    for (const file of files) {
+      const imagePath = path.join(tempDir, file);
+      const imageData = fs.readFileSync(imagePath);
+      images.push(imageData.toString('base64'));
+    }
+
+    return images;
+  } finally {
+    fs.rmSync(tempDir, { recursive: true, force: true });
+  }
+}
+
+/**
+ * Single extraction pass
+ */
+async function extractOnce(images: string[], passNum: number, ocrText: string = ''): Promise<IInvoice> {
+  const payload = {
+    model: MODEL,
+    prompt: buildPrompt(ocrText),
+    images,
+    stream: true,
+    options: {
+      num_predict: 2048,
+      temperature: 0.1,
+    },
+  };
+
+  const response = await fetch(`${OLLAMA_URL}/api/generate`, {
+    method: 'POST',
+    headers: { 'Content-Type': 'application/json' },
+    body: JSON.stringify(payload),
+  });
+
+  if (!response.ok) {
+    throw new Error(`Ollama API error: ${response.status}`);
+  }
+
+  const reader = response.body?.getReader();
+  if (!reader) {
+    throw new Error('No response body');
+  }
+
+  const decoder = new TextDecoder();
+  let fullText = '';
+
+  while (true) {
+    const { done, value } = await reader.read();
+    if (done) break;
+
+    const chunk = decoder.decode(value, { stream: true });
+    const lines = chunk.split('\n').filter((l) => l.trim());
+
+    for (const line of lines) {
+      try {
+        const json = JSON.parse(line);
+        if (json.response) {
+          fullText += json.response;
+        }
+      } catch {
+        // Skip invalid JSON lines
+      }
+    }
+  }
+
+  // Extract JSON from response
+  const startIdx = fullText.indexOf('{');
+  const endIdx = fullText.lastIndexOf('}') + 1;
+
+  if (startIdx < 0 || endIdx <= startIdx) {
+    throw new Error(`No JSON object found in response: ${fullText.substring(0, 200)}`);
+  }
+
+  const jsonStr = fullText.substring(startIdx, endIdx);
+  return JSON.parse(jsonStr);
+}
+
+/**
+ * Create a hash of invoice for comparison (using key fields)
+ */
+function hashInvoice(invoice: IInvoice): string {
+  return `${invoice.invoice_number}|${invoice.invoice_date}|${invoice.total_amount.toFixed(2)}`;
+}
+
+/**
+ * Extract with majority voting - run until 2 passes match
+ */
+async function extractWithConsensus(images: string[], invoiceName: string, maxPasses: number = 5): Promise<IInvoice> {
+  const results: Array<{ invoice: IInvoice; hash: string }> = [];
+  const hashCounts: Map<string, number> = new Map();
+
+  // Extract OCR text from first page
+  const ocrText = await extractOcrText(images[0]);
+  if (ocrText) {
+    console.log(`  [OCR] Extracted ${ocrText.split('\n').length} text lines`);
+  }
+
+  for (let pass = 1; pass <= maxPasses; pass++) {
+    try {
+      const invoice = await extractOnce(images, pass, ocrText);
+      const hash = hashInvoice(invoice);
+
+      results.push({ invoice, hash });
+      hashCounts.set(hash, (hashCounts.get(hash) || 0) + 1);
+
+      console.log(`  [Pass ${pass}] ${invoice.invoice_number} | ${invoice.invoice_date} | ${invoice.total_amount} ${invoice.currency}`);
+
+      // Check if we have consensus (2+ matching)
+      const count = hashCounts.get(hash)!;
+      if (count >= 2) {
+        console.log(`  [Consensus] Reached after ${pass} passes`);
+        return invoice;
+      }
+    } catch (err) {
+      console.log(`  [Pass ${pass}] Error: ${err}`);
+    }
+  }
+
+  // No consensus reached - return the most common result
+  let bestHash = '';
+  let bestCount = 0;
+  for (const [hash, count] of hashCounts) {
+    if (count > bestCount) {
+      bestCount = count;
+      bestHash = hash;
+    }
+  }
+
+  if (!bestHash) {
+    throw new Error(`No valid results for ${invoiceName}`);
+  }
+
+  const best = results.find((r) => r.hash === bestHash)!;
+  console.log(`  [No consensus] Using most common result (${bestCount}/${maxPasses} passes)`);
+  return best.invoice;
+}
+
+/**
+ * Compare extracted invoice against expected
+ */
+function compareInvoice(
+  extracted: IInvoice,
+  expected: IInvoice
+): { match: boolean; errors: string[] } {
+  const errors: string[] = [];
+
+  // Compare invoice number (normalize by removing spaces and case)
+  const extNum = extracted.invoice_number?.replace(/\s/g, '').toLowerCase() || '';
+  const expNum = expected.invoice_number?.replace(/\s/g, '').toLowerCase() || '';
+  if (extNum !== expNum) {
+    errors.push(`invoice_number: expected "${expected.invoice_number}", got "${extracted.invoice_number}"`);
+  }
+
+  // Compare date
+  if (extracted.invoice_date !== expected.invoice_date) {
+    errors.push(`invoice_date: expected "${expected.invoice_date}", got "${extracted.invoice_date}"`);
+  }
+
+  // Compare total amount (with tolerance)
+  if (Math.abs(extracted.total_amount - expected.total_amount) > 0.02) {
+    errors.push(`total_amount: expected ${expected.total_amount}, got ${extracted.total_amount}`);
+  }
+
+  // Compare currency
+  if (extracted.currency?.toUpperCase() !== expected.currency?.toUpperCase()) {
+    errors.push(`currency: expected "${expected.currency}", got "${extracted.currency}"`);
+  }
+
+  return { match: errors.length === 0, errors };
+}
+
+/**
+ * Find all test cases (PDF + JSON pairs) in .nogit/invoices/
+ */
+function findTestCases(): Array<{ name: string; pdfPath: string; jsonPath: string }> {
+  const testDir = path.join(process.cwd(), '.nogit/invoices');
+  if (!fs.existsSync(testDir)) {
+    return [];
+  }
+
+  const files = fs.readdirSync(testDir);
+  const pdfFiles = files.filter((f) => f.endsWith('.pdf'));
+  const testCases: Array<{ name: string; pdfPath: string; jsonPath: string }> = [];
+
+  for (const pdf of pdfFiles) {
+    const baseName = pdf.replace('.pdf', '');
+    const jsonFile = `${baseName}.json`;
+    if (files.includes(jsonFile)) {
+      testCases.push({
+        name: baseName,
+        pdfPath: path.join(testDir, pdf),
+        jsonPath: path.join(testDir, jsonFile),
+      });
+    }
+  }
+
+  return testCases;
+}
+
+// Tests
+
+tap.test('should connect to Ollama API', async () => {
+  const response = await fetch(`${OLLAMA_URL}/api/tags`);
+  expect(response.ok).toBeTrue();
+  const data = await response.json();
+  expect(data.models).toBeArray();
+});
+
+tap.test('should have MiniCPM-V 4.5 model loaded', async () => {
+  const response = await fetch(`${OLLAMA_URL}/api/tags`);
+  const data = await response.json();
+  const modelNames = data.models.map((m: { name: string }) => m.name);
+  expect(modelNames.some((name: string) => name.includes('minicpm-v4.5'))).toBeTrue();
+});
+
+// Dynamic test for each PDF/JSON pair
+const testCases = findTestCases();
+console.log(`\nFound ${testCases.length} invoice test cases\n`);
+
+let passedCount = 0;
+let failedCount = 0;
+const processingTimes: number[] = [];
+
+for (const testCase of testCases) {
+  tap.test(`should extract invoice: ${testCase.name}`, async () => {
+    // Load expected data
+    const expected: IInvoice = JSON.parse(fs.readFileSync(testCase.jsonPath, 'utf-8'));
+    console.log(`\n=== ${testCase.name} ===`);
+    console.log(`Expected: ${expected.invoice_number} | ${expected.invoice_date} | ${expected.total_amount} ${expected.currency}`);
+
+    const startTime = Date.now();
+
+    // Convert PDF to images
+    const images = convertPdfToImages(testCase.pdfPath);
+    console.log(`  Pages: ${images.length}`);
+
+    // Extract with consensus voting
+    const extracted = await extractWithConsensus(images, testCase.name);
+
+    const endTime = Date.now();
+    const elapsedMs = endTime - startTime;
+    processingTimes.push(elapsedMs);
+
+    // Compare results
+    const result = compareInvoice(extracted, expected);
+
+    if (result.match) {
+      passedCount++;
+      console.log(`  Result: MATCH (${(elapsedMs / 1000).toFixed(1)}s)`);
+    } else {
+      failedCount++;
+      console.log(`  Result: MISMATCH (${(elapsedMs / 1000).toFixed(1)}s)`);
+      result.errors.forEach((e) => console.log(`    - ${e}`));
+    }
+
+    // Assert match
+    expect(result.match).toBeTrue();
+  });
+}
+
+tap.test('summary', async () => {
+  const totalInvoices = testCases.length;
+  const accuracy = totalInvoices > 0 ? (passedCount / totalInvoices) * 100 : 0;
+  const totalTimeMs = processingTimes.reduce((a, b) => a + b, 0);
+  const avgTimeMs = processingTimes.length > 0 ? totalTimeMs / processingTimes.length : 0;
+  const avgTimeSec = avgTimeMs / 1000;
+  const totalTimeSec = totalTimeMs / 1000;
+
+  console.log(`\n========================================`);
+  console.log(`  Invoice Extraction Summary`);
+  console.log(`========================================`);
+  console.log(`  Passed:   ${passedCount}/${totalInvoices}`);
+  console.log(`  Failed:   ${failedCount}/${totalInvoices}`);
+  console.log(`  Accuracy: ${accuracy.toFixed(1)}%`);
+  console.log(`----------------------------------------`);
+  console.log(`  Total time:  ${totalTimeSec.toFixed(1)}s`);
+  console.log(`  Avg per inv: ${avgTimeSec.toFixed(1)}s`);
+  console.log(`========================================\n`);
+});
+
+export default tap.start();
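
[Editor's note] Since findTestCases() discovers fixtures at runtime, adding a case is just dropping a PDF/JSON pair into .nogit/invoices/; with no directory present the suite runs only the connectivity tests. A sketch with made-up values matching the IInvoice shape (file names are hypothetical):

    mkdir -p .nogit/invoices
    cp ~/samples/acme-001.pdf .nogit/invoices/
    echo '{"invoice_number":"ACME-001","invoice_date":"2026-01-05","vendor_name":"ACME GmbH","currency":"EUR","net_amount":100.00,"vat_amount":19.00,"total_amount":119.00}' \
      > .nogit/invoices/acme-001.json
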
diff --git a/test/test.node.ts b/test/test.node.ts
index 0872c72..4dab37c 100644
--- a/test/test.node.ts
+++ b/test/test.node.ts
@@ -7,7 +7,7 @@ import * as os from 'os';
 const OLLAMA_URL = 'http://localhost:11434';
 const MODEL = 'openbmb/minicpm-v4.5:q8_0';
 
-const BANK_STATEMENT_PROMPT = `You are a bank statement parser. Extract EVERY transaction from the table.
+const EXTRACT_PROMPT = `You are a bank statement parser. Extract EVERY transaction from the table.
 
 Read the Amount column carefully:
 - "- 21,47 €" means DEBIT, output as: -21.47
@@ -16,7 +16,7 @@ Read the Amount column carefully:
 
 For each row output: {"date":"YYYY-MM-DD","counterparty":"NAME","amount":-21.47}
 
-Do not skip any rows. Return complete JSON array:`;
+Do not skip any rows. Return ONLY the JSON array, no explanation.`;
 
 interface ITransaction {
   date: string;
@@ -53,12 +53,12 @@ function convertPdfToImages(pdfPath: string): string[] {
 }
 
 /**
- * Extract transactions from images using Ollama with streaming
+ * Single extraction pass
  */
-async function extractTransactionsStreaming(images: string[]): Promise<ITransaction[]> {
+async function extractOnce(images: string[], passNum: number): Promise<ITransaction[]> {
   const payload = {
     model: MODEL,
-    prompt: BANK_STATEMENT_PROMPT,
+    prompt: EXTRACT_PROMPT,
     images,
     stream: true,
     options: {
@@ -86,7 +86,8 @@
+
+/**
+ * Create a hash of transactions for comparison (date+amount pairs)
+ */
+function hashTransactions(transactions: ITransaction[]): string {
+  return transactions
+    .map((t) => `${t.date}|${t.amount.toFixed(2)}`)
+    .sort()
+    .join(';');
+}
+
+/**
+ * Extract with majority voting - run until 2 passes match
+ */
+async function extractWithConsensus(images: string[], maxPasses: number = 5): Promise<ITransaction[]> {
+  const results: Array<{ transactions: ITransaction[]; hash: string }> = [];
+  const hashCounts: Map<string, number> = new Map();
+
+  for (let pass = 1; pass <= maxPasses; pass++) {
+    const transactions = await extractOnce(images, pass);
+    const hash = hashTransactions(transactions);
+
+    results.push({ transactions, hash });
+    hashCounts.set(hash, (hashCounts.get(hash) || 0) + 1);
+
+    console.log(`[Pass ${pass}] Got ${transactions.length} transactions (hash: ${hash.substring(0, 20)}...)`);
+
+    // Check if we have consensus (2+ matching)
+    const count = hashCounts.get(hash)!;
+    if (count >= 2) {
+      console.log(`[Consensus] Reached after ${pass} passes (${count} matching results)`);
+      return transactions;
+    }
+
+    // After 2 passes, if no match yet, continue
+    if (pass >= 2) {
+      console.log(`[Pass ${pass}] No consensus yet, trying again...`);
+    }
+  }
+
+  // No consensus reached - return the most common result
+  let bestHash = '';
+  let bestCount = 0;
+  for (const [hash, count] of hashCounts) {
+    if (count > bestCount) {
+      bestCount = count;
+      bestHash = hash;
+    }
+  }
+
+  const best = results.find((r) => r.hash === bestHash)!;
+  console.log(`[No consensus] Using most common result (${bestCount}/${maxPasses} passes)`);
+  return best.transactions;
+}
+
 /**
  * Compare extracted transactions against expected
  */
@@ -227,16 +280,15 @@ for (const testCase of testCases) {
     // Convert PDF to images
     console.log('Converting PDF to images...');
     const images = convertPdfToImages(testCase.pdfPath);
-    console.log(`Converted: ${images.length} pages`);
+    console.log(`Converted: ${images.length} pages\n`);
 
-    // Extract transactions with streaming output
-    console.log('Extracting transactions (streaming)...\n');
-    const extracted = await extractTransactionsStreaming(images);
-    console.log(`Extracted: ${extracted.length} transactions`);
+    // Extract with consensus voting
+    const extracted = await extractWithConsensus(images);
+    console.log(`\nFinal: ${extracted.length} transactions`);
 
     // Compare results
     const result = compareTransactions(extracted, expected);
-    console.log(`Matches: ${result.matches}/${result.total}`);
+    console.log(`Accuracy: ${result.matches}/${result.total}`);
 
     if (result.errors.length > 0) {
       console.log('Errors:');
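
[Editor's note] For reference, each consensus pass in these tests is a single Ollama /api/generate call. A minimal hand-run equivalent of one pass (page.png stands in for a converted PDF page; model and options as in the tests, streaming disabled for readability):

    curl -s http://localhost:11434/api/generate -d "{
      \"model\": \"openbmb/minicpm-v4.5:q8_0\",
      \"prompt\": \"Describe this page.\",
      \"images\": [\"$(base64 -w0 page.png)\"],
      \"stream\": false,
      \"options\": {\"num_predict\": 2048, \"temperature\": 0.1}
    }"
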