import os
import sys

# Pillow 10 removed the deprecated Image.ANTIALIAS alias. Re-create it as an
# alias of Image.LANCZOS (its documented replacement) before the heavier
# imports below, so any library that still uses the old name keeps working.
try:
    from PIL import Image
except ImportError:
    print("❌ PIL/Pillow not available")
else:
    if not hasattr(Image, 'ANTIALIAS'):
        Image.ANTIALIAS = Image.LANCZOS
        print("✅ Added Image.ANTIALIAS compatibility fix")

# Now import everything else
import json
import re
import mysql.connector
from datetime import datetime, timedelta
import easyocr
import cv2
import numpy as np
import argparse
import logging
from typing import Dict, List, Optional, Tuple

# Point EasyOCR at an app-local model cache (the same path init_ocr uses for
# model_storage_directory).
# NOTE(review): `easyocr` has already been imported above this line. If EasyOCR
# only reads EASYOCR_MODULE_PATH at import time, this assignment does not change
# its default paths — confirm, and prefer passing directories to easyocr.Reader
# explicitly rather than relying on this variable.
os.environ['EASYOCR_MODULE_PATH'] = '/var/www/html/scripts/easyocr_cache'

# Configure logging: INFO and above go both to a persistent log file and to the
# console. FileHandler opens its file immediately, so /var/www/html/logs must
# already exist and be writable when this module is imported.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler('/var/www/html/logs/ocr_processor.log'),
        logging.StreamHandler()  # no stream argument: writes to stderr
    ]
)
# Module-level logger used by ReceiptOCRProcessor and main().
logger = logging.getLogger(__name__)

class ReceiptOCRProcessor:
    """OCR uploaded receipt images and store the extracted data.

    For one receipt id: read its row from ``receipt_uploads``, OCR the image
    with EasyOCR, pull a few structured fields out of the text with regex
    heuristics, and insert one row into ``receipt_ocr_data``.
    """

    # Canonical retailer name -> pattern searched in an upper-cased line.
    # The first retailer (in this order) that matches wins. Each pattern
    # matches the plain name and also common printed/OCR variants such as
    # "DIS-CHEM" or "PICK 'N PAY".
    RETAILER_PATTERNS = (
        ('PICK N PAY', re.compile(r"PICK\s*'?\s*N\s*'?\s*PAY")),
        ('SHOPRITE', re.compile(r'SHOPRITE')),
        ('CHECKERS', re.compile(r'CHECKERS')),
        ('SPAR', re.compile(r'SPAR')),
        ('WOOLWORTHS', re.compile(r'WOOLWORTHS')),
        ('GAME', re.compile(r'GAME')),
        ('MAKRO', re.compile(r'MAKRO')),
        ('DISCHEM', re.compile(r'DIS[\s-]?CHEM')),
        ('CLICKS', re.compile(r'CLICKS')),
        ('PEP', re.compile(r'PEP')),
    )

    # Candidate date shapes, tried in order on each line. A match is only
    # accepted if it parses with one of DATE_FORMATS.
    DATE_PATTERNS = (
        re.compile(r'(\d{4}[-/]\d{2}[-/]\d{2})'),
        re.compile(r'(\d{2}[-/]\d{2}[-/]\d{4})'),
        re.compile(r'(\d{2}[-/]\d{2}[-/]\d{2})'),
    )
    DATE_FORMATS = ('%Y-%m-%d', '%Y/%m/%d', '%d-%m-%Y', '%d/%m/%Y', '%y-%m-%d', '%y/%m/%d')

    # A valid HH:MM clock time. Single-digit hours are accepted (and
    # zero-padded later). The digit look-arounds stop matches inside longer
    # numbers, such as "0123:4567".
    TIME_PATTERN = re.compile(r'(?<!\d)([01]?\d|2[0-3]):([0-5]\d)(?!\d)')

    # A money amount. It is either "1,234.56" (comma thousands separators) or
    # a plain "1234.56" / "1234,56". It always ends in a separator plus two
    # decimals.
    _AMOUNT = r'(\d{1,3}(?:,\d{3})+\.\d{2}|\d+[.,]\d{2})'
    TOTAL_PATTERNS = (
        re.compile(r'TOTAL[:\s]*R?[\s]*' + _AMOUNT),
        re.compile(r'AMOUNT[:\s]*R?[\s]*' + _AMOUNT),
        re.compile(r'R[\s]*' + _AMOUNT),
        re.compile(_AMOUNT + r'[\s]*$'),
    )
    # "SUBTOTAL" / "SUB-TOTAL" contain "TOTAL" but are not the amount paid.
    SUBTOTAL_PATTERN = re.compile(r'SUB\s*-?\s*TOTAL')

    ITEM_INDICATORS = ('R', 'ZAR', '@', 'X', '*')
    ITEM_PRICE_PATTERN = re.compile(r'\d+[.,]\d{2}')

    # (pattern, skip lines that look like phone/VAT/registration numbers).
    # Patterns are tried most-specific first across ALL lines. This way an
    # explicit "RECEIPT ..." line beats an earlier "TILL NO 5".
    RECEIPT_NUMBER_PATTERNS = (
        (re.compile(r'RECEIPT(?:\s*(?:NO|NUMBER)\b\.?)?[:\s#]*(\d+)'), False),
        (re.compile(r'\bREF[:\s#]*(\d+)'), True),
        (re.compile(r'\bNO[:\s#]*(\d+)'), True),
    )
    NOT_RECEIPT_NUMBER_PATTERN = re.compile(r'\b(?:VAT|TEL|PHONE|FAX|CELL|REG)\b')

    def __init__(self, db_config: Dict[str, str]):
        """Store the DB settings and load the EasyOCR model.

        Args:
            db_config: keyword arguments passed unchanged to
                ``mysql.connector.connect``.

        Raises:
            Exception: whatever EasyOCR raises if the reader cannot be created.
        """
        self.db_config = db_config
        self.reader = None
        self.init_ocr()

    def init_ocr(self):
        """Create the English, CPU-only EasyOCR reader with an app-local cache.

        The model directory and EasyOCR's user-network directory are both
        passed explicitly. In current EasyOCR releases the defaults for these
        are derived from EASYOCR_MODULE_PATH when the package is imported.
        Setting that variable after ``import easyocr`` therefore does not move
        them. Without the explicit argument, the user-network directory would
        be created under the process owner's home directory.

        Raises:
            Exception: re-raised after logging if the reader cannot be created.
        """
        try:
            # Same location the module-level EASYOCR_MODULE_PATH points at.
            cache_dir = os.environ.get('EASYOCR_MODULE_PATH', '/var/www/html/scripts/easyocr_cache')
            os.makedirs(cache_dir, exist_ok=True)

            self.reader = easyocr.Reader(
                ['en'],
                gpu=False,
                model_storage_directory=cache_dir,
                user_network_directory=os.path.join(cache_dir, 'user_network'),
            )
            logger.info("EasyOCR initialized successfully with custom cache directory")
        except Exception as e:
            logger.error(f"Failed to initialize EasyOCR: {e}")
            raise

    def get_db_connection(self):
        """Open and return a new MySQL connection using ``self.db_config``.

        Raises:
            mysql.connector.Error: re-raised after logging if the connection fails.
        """
        try:
            connection = mysql.connector.connect(**self.db_config)
            return connection
        except mysql.connector.Error as e:
            logger.error(f"Database connection failed: {e}")
            raise

    def preprocess_image(self, image_path: str) -> np.ndarray:
        """Load an image and binarise it for OCR.

        Pipeline: grayscale -> non-local-means denoising -> Gaussian adaptive
        threshold (block size 11, C=2). If any step fails, the plain grayscale
        image is returned instead.

        Raises:
            ValueError: if the file cannot be read at all, not even by the
                grayscale fallback. Previously this case silently returned None.
        """
        try:
            image = cv2.imread(image_path)
            if image is None:
                raise ValueError(f"Could not read image from {image_path}")

            gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
            denoised = cv2.fastNlMeansDenoising(gray)
            thresh = cv2.adaptiveThreshold(
                denoised, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY, 11, 2
            )

            logger.info(f"Image preprocessed successfully: {image_path}")
            return thresh

        except Exception as e:
            logger.error(f"Image preprocessing failed: {e}")
            # Fall back to the unprocessed grayscale image.
            fallback = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
            if fallback is None:
                raise ValueError(f"Could not read image from {image_path}") from e
            return fallback

    @staticmethod
    def _group_into_lines(detections: List[Tuple[list, str]]) -> List[str]:
        """Merge EasyOCR text boxes into lines, ordered top-to-bottom, left-to-right.

        Args:
            detections: ``(bbox, text)`` pairs. bbox is the list of ``[x, y]``
                corner points EasyOCR returns for a box.

        Boxes are visited in order of vertical centre. A box joins the current
        line if its centre lies within half a box height of the line's first
        box; the larger of the two heights is used. Otherwise it starts a new
        line. The boxes within a line are joined with single spaces, sorted by
        their left edge.
        """
        def join_row(row: List[Tuple[float, str]]) -> str:
            return ' '.join(word for _, word in sorted(row, key=lambda item: item[0]))

        boxes = []
        for bbox, text in detections:
            xs = [float(point[0]) for point in bbox]
            ys = [float(point[1]) for point in bbox]
            top, bottom = min(ys), max(ys)
            boxes.append(((top + bottom) / 2, bottom - top, min(xs), text))
        boxes.sort(key=lambda box: (box[0], box[2]))

        lines: List[str] = []
        row: List[Tuple[float, str]] = []
        row_center = row_height = 0.0
        for center, height, left, text in boxes:
            if row and abs(center - row_center) <= max(row_height, height) / 2:
                row.append((left, text))
                continue
            if row:
                lines.append(join_row(row))
            row = [(left, text)]
            row_center, row_height = center, height
        if row:
            lines.append(join_row(row))
        return lines

    def extract_text_from_image(self, image_path: str, min_confidence: float = 0.5) -> Tuple[str, float]:
        """OCR an image and return its text and mean confidence.

        Detections with confidence <= ``min_confidence`` are dropped. The
        remaining boxes are merged into physical lines. EasyOCR reports one
        entry per text box, so a label and a value far apart on the same row
        (e.g. "TOTAL" ... "123.45") would otherwise end up on separate lines.
        parse_receipt_data's line-based patterns would then miss them.

        Returns:
            ``(text, confidence)``. ``text`` has one receipt line per
            newline-separated line. ``confidence`` is the mean confidence of
            the kept boxes, as a percentage (0-100). Returns ``("", 0.0)`` if
            preprocessing or OCR raises.
        """
        try:
            processed_image = self.preprocess_image(image_path)
            results = self.reader.readtext(processed_image)

            kept = [(bbox, text, conf) for bbox, text, conf in results if conf > min_confidence]
            lines = self._group_into_lines([(bbox, text) for bbox, text, _ in kept])
            full_text = '\n'.join(lines)

            confidences = [conf for _, _, conf in kept]
            avg_confidence = sum(confidences) / len(confidences) if confidences else 0.0

            logger.info(f"OCR completed. Extracted {len(kept)} text elements with avg confidence {avg_confidence:.2f}")
            return full_text, avg_confidence * 100

        except Exception as e:
            logger.error(f"OCR text extraction failed: {e}")
            return "", 0.0

    @classmethod
    def _find_store_name(cls, lines: List[str]) -> Optional[str]:
        """Return the canonical name of the first known retailer in the first five lines."""
        for line in lines[:5]:
            upper = line.upper()
            for name, pattern in cls.RETAILER_PATTERNS:
                if pattern.search(upper):
                    return name
        return None

    @classmethod
    def _find_date(cls, lines: List[str]) -> Optional[str]:
        """Return the first date that parses with DATE_FORMATS, as 'YYYY-MM-DD'."""
        # NOTE: a two-digit-year date such as 12/05/23 is read year-first
        # (2012-05-23), so DD/MM/YY receipts may be misdated. Revisit with
        # real receipt samples before changing the order.
        for line in lines:
            for pattern in cls.DATE_PATTERNS:
                match = pattern.search(line)
                if not match:
                    continue
                for fmt in cls.DATE_FORMATS:
                    try:
                        return datetime.strptime(match.group(1), fmt).strftime('%Y-%m-%d')
                    except ValueError:
                        continue
        return None

    @classmethod
    def _find_time(cls, lines: List[str]) -> Optional[str]:
        """Return the first valid HH:MM time as 'HH:MM:00', or None."""
        for line in lines:
            match = cls.TIME_PATTERN.search(line)
            if match:
                return f"{int(match.group(1)):02d}:{match.group(2)}:00"
        return None

    @staticmethod
    def _parse_amount(text: str) -> float:
        """Convert a matched amount ('1,234.56', '12,50', '7.00') to a float."""
        # The last three characters are always the decimal separator and two
        # decimals. Everything before them is digits plus thousands separators.
        whole = re.sub(r'\D', '', text[:-3])
        return float(f"{whole}.{text[-2:]}")

    @classmethod
    def _find_total(cls, lines: List[str]) -> Optional[float]:
        """Return the amount on the first TOTAL/AMOUNT line (SUBTOTAL ignored)."""
        for line in lines:
            # Blank out SUBTOTAL so that it neither qualifies the line nor
            # feeds the patterns. "SUBTOTAL x TOTAL y" still yields y.
            upper = cls.SUBTOTAL_PATTERN.sub(' ', line.upper())
            if 'TOTAL' not in upper and 'AMOUNT' not in upper:
                continue
            for pattern in cls.TOTAL_PATTERNS:
                match = pattern.search(upper)
                if match:
                    # First hit wins, even 0.00. Later lines must not override it.
                    return cls._parse_amount(match.group(1))
        return None

    @classmethod
    def _count_items(cls, lines: List[str]) -> int:
        """Rough item count: lines with an item marker and a price, halved, at least 1."""
        priced = sum(
            1 for line in lines
            if any(marker in line.upper() for marker in cls.ITEM_INDICATORS)
            and cls.ITEM_PRICE_PATTERN.search(line)
        )
        return max(1, priced // 2)

    @classmethod
    def _find_receipt_number(cls, lines: List[str]) -> Optional[str]:
        """Return the receipt number from the most specific pattern that matches any line."""
        for pattern, skip_contact_lines in cls.RECEIPT_NUMBER_PATTERNS:
            for line in lines:
                upper = line.upper()
                if skip_contact_lines and cls.NOT_RECEIPT_NUMBER_PATTERN.search(upper):
                    continue
                match = pattern.search(upper)
                if match:
                    return match.group(1)
        return None

    def parse_receipt_data(self, ocr_text: str) -> Dict:
        """Extract structured fields from newline-separated OCR text.

        Returns:
            A dict with keys store_name, store_address, receipt_date
            ('YYYY-MM-DD'), receipt_time ('HH:MM:SS'), total_amount (float),
            currency (always 'ZAR'), items_count (rough estimate, >= 1),
            receipt_number and pos_terminal. Fields that are not found are
            None. store_address and pos_terminal are never populated here.
            If parsing raises, an empty dict is returned instead.
        """
        try:
            lines = ocr_text.split('\n')
            data = {
                'store_name': self._find_store_name(lines),
                'store_address': None,
                'receipt_date': self._find_date(lines),
                'receipt_time': self._find_time(lines),
                'total_amount': self._find_total(lines),
                'currency': 'ZAR',
                'items_count': self._count_items(lines),
                'receipt_number': self._find_receipt_number(lines),
                'pos_terminal': None,
            }

            logger.info(f"Receipt data parsed: {data}")
            return data

        except Exception as e:
            logger.error(f"Receipt data parsing failed: {e}")
            return {}

    def save_ocr_data(self, receipt_id: int, raw_text: str, parsed_data: Dict, confidence: float):
        """Insert one row into ``receipt_ocr_data`` and return its id.

        The full ``parsed_data`` dict is stored as JSON in extracted_data, and
        its individual fields go into their own columns.
        processing_status is set to 'completed'.

        Returns:
            ``cursor.lastrowid`` of the inserted row.

        Raises:
            Exception: any connection/query error, re-raised after logging.
                The cursor and connection are closed in every case.
        """
        connection = None
        cursor = None
        try:
            connection = self.get_db_connection()
            cursor = connection.cursor()

            # Prepare data for insertion
            extracted_data_json = json.dumps(parsed_data)

            insert_query = """
                INSERT INTO receipt_ocr_data (
                    receipt_id, raw_ocr_text, extracted_data, store_name, store_address,
                    receipt_date, receipt_time, total_amount, currency, items_count,
                    receipt_number, pos_terminal, ocr_confidence, processing_status, created_at
                ) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
            """

            values = (
                receipt_id,
                raw_text,
                extracted_data_json,
                parsed_data.get('store_name'),
                parsed_data.get('store_address'),
                parsed_data.get('receipt_date'),
                parsed_data.get('receipt_time'),
                parsed_data.get('total_amount'),
                parsed_data.get('currency', 'ZAR'),
                parsed_data.get('items_count', 0),
                parsed_data.get('receipt_number'),
                parsed_data.get('pos_terminal'),
                confidence,
                'completed',
                datetime.now()
            )

            cursor.execute(insert_query, values)
            connection.commit()

            ocr_data_id = cursor.lastrowid
            logger.info(f"OCR data saved successfully with ID: {ocr_data_id}")
            return ocr_data_id

        except Exception as e:
            logger.error(f"Failed to save OCR data: {e}")
            raise
        finally:
            # Always release DB resources, so a failed INSERT cannot leak the connection.
            if cursor is not None:
                cursor.close()
            if connection is not None:
                connection.close()

    def _fetch_uploaded_receipt(self, receipt_id: int) -> Optional[Dict]:
        """Return the receipt_uploads row for ``receipt_id`` if it is in 'uploaded' status, else None.

        The cursor and connection are closed even if the query raises.
        """
        connection = self.get_db_connection()
        cursor = None
        try:
            cursor = connection.cursor(dictionary=True)
            cursor.execute("""
                SELECT id, user_id, file_path, file_name, upload_status
                FROM receipt_uploads 
                WHERE id = %s AND upload_status = 'uploaded'
            """, (receipt_id,))
            return cursor.fetchone()
        finally:
            if cursor is not None:
                cursor.close()
            connection.close()

    def process_receipt(self, receipt_id: int) -> bool:
        """OCR one uploaded receipt and store the result.

        Returns:
            True if text was extracted and the OCR row was saved. Returns
            False if the upload row or image file is missing, no text was
            extracted, or any step raised (the exception is logged with its
            traceback).
        """
        try:
            receipt = self._fetch_uploaded_receipt(receipt_id)

            if not receipt:
                logger.error(f"Receipt {receipt_id} not found or not in uploaded status")
                return False

            # file_path in the DB is relative to the uploads directory.
            file_path = f"/var/www/html/uploads/{receipt['file_path']}"

            if not os.path.exists(file_path):
                logger.error(f"Receipt file not found: {file_path}")
                return False

            logger.info(f"Processing receipt {receipt_id}: {receipt['file_name']}")

            raw_text, confidence = self.extract_text_from_image(file_path)

            if not raw_text:
                logger.error(f"No text extracted from receipt {receipt_id}")
                return False

            parsed_data = self.parse_receipt_data(raw_text)
            self.save_ocr_data(receipt_id, raw_text, parsed_data, confidence)

            logger.info(f"Receipt {receipt_id} processed successfully")
            return True

        except Exception as e:
            logger.exception(f"Receipt processing failed for ID {receipt_id}: {e}")
            return False

def load_db_config(environ: Optional[Dict[str, str]] = None) -> Dict[str, str]:
    """Build the MySQL connection settings for the OCR processor.

    Environment variables override the built-in defaults: OCR_DB_HOST,
    OCR_DB_USER, OCR_DB_PASSWORD and OCR_DB_NAME.

    Args:
        environ: mapping to read the variables from; defaults to ``os.environ``.

    Returns:
        Keyword arguments for ``mysql.connector.connect``.
    """
    env = os.environ if environ is None else environ
    return {
        'host': env.get('OCR_DB_HOST', 'localhost'),
        'user': env.get('OCR_DB_USER', 'prime_usr'),
        # TODO(security): this fallback looks like a real credential committed
        # to source. Set OCR_DB_PASSWORD in the service environment, then
        # remove this default and rotate the password.
        'password': env.get('OCR_DB_PASSWORD', '!1945@Tata!'),
        'database': env.get('OCR_DB_NAME', 'prime_dbs_loyalty'),
        'charset': 'utf8mb4',
        'collation': 'utf8mb4_unicode_ci',
    }


def main():
    """Command-line entry point: OCR one receipt and exit with its status.

    Usage: ``--receipt-id <id>``. Exits 0 if the receipt was processed, and 1
    if processing failed or the processor could not be created.
    """
    parser = argparse.ArgumentParser(description='Process receipt with OCR')
    parser.add_argument('--receipt-id', type=int, required=True, help='Receipt ID to process')

    args = parser.parse_args()

    db_config = load_db_config()

    try:
        processor = ReceiptOCRProcessor(db_config)
        success = processor.process_receipt(args.receipt_id)

        if success:
            logger.info(f"Receipt {args.receipt_id} processed successfully")
            sys.exit(0)
        else:
            logger.error(f"Failed to process receipt {args.receipt_id}")
            sys.exit(1)

    # SystemExit is not an Exception subclass, so the exits above pass through.
    except Exception as e:
        logger.error(f"OCR processor failed: {e}")
        sys.exit(1)

if __name__ == "__main__":
    main()

