import base64
import functools
import gc
import io
import os
import pathlib
import re

from flask import Flask, request, jsonify
import cv2
import numpy as np
from PIL import Image
import cv2
import pytesseract
from transformers import AutoTokenizer, AutoModelForTokenClassification
from transformers import pipeline
import torch
import torchvision.transforms as torchvision_T
from torchvision.models.segmentation import deeplabv3_resnet50, deeplabv3_mobilenet_v3_large
from pynlp import StanfordCoreNLP
from termcolor import colored

# Flask application object; the vCard endpoints below are registered on it.
app = Flask(import_name=__name__)


def load_model(num_classes=2, model_name="mbv3", device=torch.device("cpu")):
    """Build a DeepLabV3 segmentation model and load its fine-tuned weights.

    Args:
        num_classes: number of output classes of the segmentation head.
        model_name: ``"mbv3"`` selects the MobileNetV3-Large backbone; any
            other value selects ResNet-50.
        device: device the model weights (and the warm-up input) are put on.

    Returns:
        The model in eval mode, after one warm-up forward pass.
    """
    if model_name == "mbv3":
        model = deeplabv3_mobilenet_v3_large(num_classes=num_classes, aux_loss=True)
        checkpoint_path = os.path.join(os.getcwd(), "model_mbv3_iou_mix_2C049.pth")
    else:
        model = deeplabv3_resnet50(num_classes=num_classes, aux_loss=True)
        checkpoint_path = os.path.join(os.getcwd(), "model_r50_iou_mix_2C020.pth")

    model.to(device)
    # Checkpoints are resolved relative to the current working directory.
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    checkpoints = torch.load(checkpoint_path, map_location=device)
    # strict=False: missing/unexpected keys in the checkpoint are ignored
    # instead of raising.
    model.load_state_dict(checkpoints, strict=False)
    model.eval()

    # Warm-up pass: no autograd graph is needed, and the dummy input must live
    # on the same device as the weights (a CPU tensor fails on a CUDA model).
    with torch.no_grad():
        model(torch.randn((1, 3, 384, 384), device=device))

    return model

def order_points(pts):
    """Sort four 2-D points into top-left, top-right, bottom-right, bottom-left.

    Top-left has the smallest ``x + y`` and bottom-right the largest;
    top-right has the smallest ``y - x`` and bottom-left the largest.
    Coordinates pass through a float32 buffer and are truncated to ``int``.

    Returns:
        A list of four ``[x, y]`` lists of Python ints.
    """
    points = np.array(pts)
    coord_sum = points.sum(axis=1)
    coord_diff = np.diff(points, axis=1)  # y - x for every point

    picks = (
        np.argmin(coord_sum),   # top-left
        np.argmin(coord_diff),  # top-right
        np.argmax(coord_sum),   # bottom-right
        np.argmax(coord_diff),  # bottom-left
    )
    ordered = np.zeros((4, 2), dtype="float32")
    for slot, idx in enumerate(picks):
        ordered[slot] = points[idx]
    return ordered.astype("int").tolist()


def find_dest(pts):
    """Return the target rectangle for a perspective warp of quadrilateral *pts*.

    *pts* are four corners ordered top-left, top-right, bottom-right,
    bottom-left. The rectangle is anchored at the origin; its width is the
    longer of the bottom and top edges and its height the longer of the right
    and left edges, each truncated to ``int``. The corners are returned in the
    same order via ``order_points``.
    """
    tl, tr, br, bl = pts

    def edge_length(p, q):
        dx = p[0] - q[0]
        dy = p[1] - q[1]
        return np.sqrt((dx ** 2) + (dy ** 2))

    width = max(int(edge_length(br, bl)), int(edge_length(tr, tl)))
    height = max(int(edge_length(tr, br)), int(edge_length(tl, bl)))

    return order_points([[0, 0], [width, 0], [width, height], [0, height]])


def get_image_download_link(img, filename, text):
    """Return *img* encoded as a base64 JPEG string.

    Args:
        img: image object with ``save(fp, format=...)`` (e.g. a PIL image).
        filename: unused; kept for backward compatibility.
        text: unused; kept for backward compatibility.

    Returns:
        The base64-encoded JPEG bytes as a ``str`` (no ``data:`` URI prefix).
    """
    # JPEG cannot store alpha or palette images; flatten those to RGB first.
    if getattr(img, "mode", None) in ("RGBA", "LA", "P"):
        img = img.convert("RGB")

    buffered = io.BytesIO()
    img.save(buffered, format="JPEG")
    return base64.b64encode(buffered.getvalue()).decode()


def image_preprocess_transforms(mean=(0.4611, 0.4359, 0.3905), std=(0.2193, 0.2150, 0.2109)):
    """Build the input pipeline applied to images before segmentation.

    The pipeline is ``ToTensor`` (image array -> float tensor) followed by
    per-channel ``Normalize`` with *mean* and *std*.
    """
    steps = [torchvision_T.ToTensor(), torchvision_T.Normalize(mean, std)]
    return torchvision_T.Compose(steps)


# Side length (pixels) of the square image the routes ask scan() to resize to.
IMAGE_SIZE = 384
# Shared preprocessing pipeline read by scan().
preprocess_transforms = image_preprocess_transforms()
# NOTE(review): these module-level names are not read by the visible code;
# the route handlers bind their own locals with the same names.
image = final = result = None


def scan(image_true=None, trained_model=None, image_size=384, BUFFER=10):
    """Locate a document in *image_true* and return a perspective-corrected crop.

    Steps:
      1. Resize to ``image_size`` x ``image_size``, apply the module-level
         ``preprocess_transforms`` and run ``trained_model``. The per-pixel
         arg-max over classes is the document mask.
      2. Embed the mask in a zero canvas, run Canny + dilation and keep the
         largest contour.
      3. Approximate that contour by a polygon and map its vertices back to
         the original resolution. If they fall outside the image, use the
         minimum-area rectangle instead and zero-pad the image so it fits.
      4. Warp the region onto an upright rectangle.

    Args:
        image_true: image array of shape (H, W, C), e.g. from ``cv2.imdecode``.
        trained_model: segmentation network called as ``model(batch)["out"]``
            (e.g. the result of ``load_model``).
        image_size: side length the image is resized to for inference.
        BUFFER: extra pixels of padding added on each side the document
            overflows.

    Returns:
        The warped document as a ``uint8`` array.

    Raises:
        ValueError: if the predicted mask yields no contour.
    """
    size = image_size
    half = size // 2

    imH, imW, C = image_true.shape
    scale_x = imW / size
    scale_y = imH / size

    # --- 1. Segmentation --------------------------------------------------
    image_model = cv2.resize(image_true, (size, size), interpolation=cv2.INTER_NEAREST)
    image_model = torch.unsqueeze(preprocess_transforms(image_model), dim=0)

    # Run on the model's own device so load_model(device=...) with a GPU works;
    # fall back to CPU for callables that expose no parameters.
    try:
        device = next(trained_model.parameters()).device
    except (AttributeError, StopIteration):
        device = torch.device("cpu")

    with torch.no_grad():
        out = trained_model(image_model.to(device))["out"].cpu()

    del image_model
    gc.collect()

    # Class axis is dim 1; collapse it to a 2-D class-index mask.
    out = torch.argmax(out, dim=1, keepdims=True).permute(0, 2, 3, 1)[0].numpy().squeeze().astype(np.int32)
    r_H, r_W = out.shape

    # Centre the mask in a canvas padded by `half` on every side; the corners
    # are shifted back by `half` below.
    mask = np.zeros((size + r_H, size + r_W), dtype=out.dtype)
    mask[half : half + size, half : half + size] = out * 255
    del out

    # --- 2. Largest contour -----------------------------------------------
    canny = cv2.Canny(mask.astype(np.uint8), 225, 255)
    canny = cv2.dilate(canny, cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5)))
    contours, _ = cv2.findContours(canny, cv2.RETR_LIST, cv2.CHAIN_APPROX_NONE)
    if not contours:
        raise ValueError("no document contour found in the segmentation mask")
    page = max(contours, key=cv2.contourArea)

    # --- 3. Corners in original-image coordinates -------------------------
    epsilon = 0.02 * cv2.arcLength(page, True)
    # approxPolyDP returns (N, 1, 2); flatten to (N, 2).
    corners = cv2.approxPolyDP(page, epsilon, True).reshape(-1, 2).astype(np.float32)
    if len(corners) < 4:
        # Too few vertices for a perspective transform: use the contour's
        # minimum-area rectangle (same canvas coordinates) instead.
        corners = cv2.boxPoints(cv2.minAreaRect(page)).astype(np.float32)

    corners -= half
    corners[:, 0] *= scale_x
    corners[:, 1] *= scale_y

    inside = np.all(corners.min(axis=0) >= (0, 0)) and np.all(corners.max(axis=0) <= (imW, imH))
    if not inside:
        # Enclose the corners in a minimum-area rectangle and zero-pad the
        # image on every side that rectangle overflows (plus BUFFER pixels).
        left_pad, top_pad, right_pad, bottom_pad = 0, 0, 0, 0

        box_corners = np.int32(cv2.boxPoints(cv2.minAreaRect(corners.reshape((-1, 1, 2)))))

        box_x_min = np.min(box_corners[:, 0])
        box_x_max = np.max(box_corners[:, 0])
        box_y_min = np.min(box_corners[:, 1])
        box_y_max = np.max(box_corners[:, 1])

        if box_x_min <= 0:
            left_pad = abs(box_x_min) + BUFFER
        if box_x_max >= imW:
            right_pad = (box_x_max - imW) + BUFFER
        if box_y_min <= 0:
            top_pad = abs(box_y_min) + BUFFER
        if box_y_max >= imH:
            bottom_pad = (box_y_max - imH) + BUFFER

        image_extended = np.zeros((top_pad + bottom_pad + imH, left_pad + right_pad + imW, C), dtype=image_true.dtype)
        image_extended[top_pad : top_pad + imH, left_pad : left_pad + imW, :] = image_true
        image_true = image_extended.astype(np.float32)

        # Express the box in the padded image's coordinates.
        box_corners[:, 0] += left_pad
        box_corners[:, 1] += top_pad
        corners = box_corners

    # --- 4. Perspective warp ----------------------------------------------
    corners = order_points(sorted(corners.tolist()))
    destination_corners = find_dest(corners)
    M = cv2.getPerspectiveTransform(np.float32(corners), np.float32(destination_corners))

    final = cv2.warpPerspective(
        image_true, M, (destination_corners[2][0], destination_corners[2][1]), flags=cv2.INTER_LANCZOS4
    )
    return np.clip(final, a_min=0, a_max=255).astype(np.uint8)



@functools.lru_cache(maxsize=1)
def _ner_pipeline():
    """Return the BERT NER pipeline used by the v1 endpoint, built once per process.

    Loading the tokenizer and model is expensive, so the pipeline is cached
    instead of being rebuilt on every request.
    """
    tokenizer = AutoTokenizer.from_pretrained("dslim/bert-base-NER")
    ner_model = AutoModelForTokenClassification.from_pretrained("dslim/bert-base-NER")
    return pipeline("ner", model=ner_model, tokenizer=tokenizer)


@app.route('/v1/user/vCard', methods=['POST'])
def scan_document():
    """Scan a business-card photo and extract contact details (v1).

    Expects a multipart upload in form field ``image``. The card is located
    and perspective-corrected by ``scan``, saved under ``./static/``, OCR'd
    with Tesseract, and each OCR line is classified by a BERT NER model and
    matched against phone / email / website regexes.

    Every response is HTTP 200 with JSON ``status`` 200; failures are
    reported through ``message``.
    """
    try:
        uploaded_file = request.files.get('image')
        # None when the field is missing; a FileStorage is falsy when it has
        # no file name.
        if not uploaded_file:
            return jsonify({"status": 200, "message": "No image uploaded"})

        # The file name is client-controlled: keep only its last component so
        # the save below cannot escape ./static/ (path traversal).
        filename = pathlib.Path(uploaded_file.filename).name
        if filename in ('', '.', '..'):
            raise ValueError("invalid upload file name: %r" % uploaded_file.filename)

        file_bytes = np.asarray(bytearray(uploaded_file.read()), dtype=np.uint8)
        image = cv2.imdecode(file_bytes, 1)
        if image is None:
            raise ValueError("upload could not be decoded as an image")

        model = load_model(model_name="r50")
        final = scan(image_true=image, trained_model=model, image_size=IMAGE_SIZE)

        # OpenCV decodes to BGR; reverse the channel axis for PIL.
        image_path = './static/' + filename
        os.makedirs('./static', exist_ok=True)
        Image.fromarray(final[:, :, ::-1]).save(image_path)

        # OCR the saved file; the context manager closes Image.open's handle.
        with Image.open(image_path) as saved_image:
            text = pytesseract.image_to_string(saved_image)

        phone_pattern = r'\+?[\d\s()-]{11,15}'
        email_pattern = r'\S+@\S+'
        website_pattern = r'\b(?:https?://|www\.)[^\s<>"\']+'

        nlp = _ner_pipeline()
        finalList = []  # {'name' | 'organization': line, 'score': float}
        mobileList = []
        emailList = []
        websiteList = []
        addressList = []

        for textFinal in text.split('\n'):
            if textFinal in ('', '\f'):
                continue

            # Classify the whole line by the label of its first NER token.
            ner_results = nlp(textFinal)
            if ner_results:
                label = ner_results[0]['entity']
                score = ner_results[0]['score']
                if label == 'B-PER':
                    finalList.append({'name': textFinal, 'score': score})
                elif label == 'B-ORG':
                    finalList.append({'organization': textFinal, 'score': score})

            # Keep the first regex hit of each kind on this line.
            for pattern, found in (
                (phone_pattern, mobileList),
                (email_pattern, emailList),
                (website_pattern, websiteList),
            ):
                match = re.search(pattern, textFinal)
                if match:
                    found.append(match.group(0))

        # The highest-scoring PER / ORG line wins; the others go to the address.
        name_scores = [item['score'] for item in finalList if 'name' in item]
        org_scores = [item['score'] for item in finalList if 'organization' in item]
        max_score_name = max(name_scores) if name_scores else None
        max_score = max(org_scores) if org_scores else None

        details = {}
        for item in finalList:
            if 'name' in item:
                if item['score'] == max_score_name:
                    details['personName'] = item['name']
                else:
                    addressList.append(item['name'])
            if 'organization' in item:
                if item['score'] == max_score:
                    details['organizationName'] = item['organization']
                else:
                    addressList.append(item['organization'])
        details.setdefault('personName', '')
        details.setdefault('organizationName', '')

        details['Mobile'] = mobileList
        details['Email'] = emailList
        details['Website'] = websiteList
        details['address'] = ','.join(addressList)
        details['personDesignation'] = ''

        # Test the accumulated websiteList, not a per-line regex result.
        found_anything = (
            mobileList
            or emailList
            or websiteList
            or details['personName']
            or details['organizationName']
            or details['personDesignation']
        )
        if not found_anything:
            return jsonify({
                "status": 200,
                "data": "",
                "message": "Card Not Scanned Please try recapturing",
            })

        details['imageUrl'] = './static/' + filename
        return jsonify({"status": 200, "data": details})
    except Exception:
        # Keep the client-facing contract, but record the cause server-side.
        app.logger.exception("v1 vCard scan failed")
        return jsonify({"status": 200, "message": "Something went wrong", "data": ""})


@functools.lru_cache(maxsize=1)
def _easyocr_reader():
    """Return the English EasyOCR reader used by the v2 endpoint, built once per process."""
    # Local import: easyocr is only needed when this endpoint is actually used.
    import easyocr
    return easyocr.Reader(['en'])


@app.route('/v2/user/vCard', methods=['POST'])
def scan_documentV2():
    """Scan a business-card photo and extract contact details (v2).

    Upload handling, scanning and saving match v1. OCR is done with EasyOCR,
    and each OCR line is tagged with Stanford CoreNLP NER (PERSON,
    ORGANIZATION, TITLE, LOCATION) and matched against phone / email /
    website regexes. Lines without any entity are treated as address parts.

    Every response is HTTP 200 with JSON ``status`` 200; failures are
    reported through ``message``.
    """
    try:
        uploaded_file = request.files.get('image')
        # None when the field is missing; a FileStorage is falsy when it has
        # no file name.
        if not uploaded_file:
            return jsonify({"status": 200, "message": "No image uploaded"})

        # The file name is client-controlled: keep only its last component so
        # the save below cannot escape ./static/ (path traversal).
        filename = pathlib.Path(uploaded_file.filename).name
        if filename in ('', '.', '..'):
            raise ValueError("invalid upload file name: %r" % uploaded_file.filename)

        file_bytes = np.asarray(bytearray(uploaded_file.read()), dtype=np.uint8)
        image = cv2.imdecode(file_bytes, 1)
        if image is None:
            raise ValueError("upload could not be decoded as an image")

        model = load_model(model_name="r50")
        final = scan(image_true=image, trained_model=model, image_size=IMAGE_SIZE)

        # OpenCV decodes to BGR; reverse the channel axis for PIL.
        image_path = './static/' + filename
        os.makedirs('./static', exist_ok=True)
        Image.fromarray(final[:, :, ::-1]).save(image_path)

        ocr_lines = _easyocr_reader().readtext(image_path, detail=0)
        app.logger.debug("v2 OCR lines: %s", ocr_lines)

        phone_pattern = r'\+?[\d\s()-]{11,20}'
        email_pattern = r'\S+@\S+'
        website_pattern = r'\b(?:https?://|www\.)[^\s<>"\']+'

        mobileList = []
        emailList = []
        websiteList = []
        addressList = []
        person = []
        organization = []
        designation = []

        # One CoreNLP client per request; the settings do not depend on the line.
        nlp = StanfordCoreNLP(annotators='ner', options={'openie.resolve_coref': True})

        for text in ocr_lines:
            entities = list(nlp(text).entities)
            for entity in entities:
                if entity.type == 'PERSON':
                    person.append(str(entity))
                elif entity.type == 'ORGANIZATION':
                    organization.append(str(entity))
                elif entity.type == 'TITLE':
                    designation.append(text)
                elif entity.type == 'LOCATION':
                    addressList.append(text)
            if not entities:
                addressList.append(text)

            phone = re.search(phone_pattern, text)
            if phone:
                mobileList.append(phone.group(0))

            # Store the matched address (consistent with phone/website and v1),
            # not the whole OCR line.
            email = re.search(email_pattern, text)
            if email:
                emailList.append(email.group(0))

            website = re.search(website_pattern, text)
            if website:
                websiteList.append(website.group(0))

        details = {
            'Mobile': mobileList,
            'Email': emailList,
            'Website': websiteList,
            'personName': ','.join(person),
            'organizationName': ','.join(organization),
            'address': ','.join(addressList),
            'personDesignation': ','.join(designation),
        }
        app.logger.debug("v2 extracted details: %s", details)

        # Test the accumulated websiteList, not a per-line regex result.
        found_anything = (
            mobileList
            or emailList
            or websiteList
            or details['personName']
            or details['organizationName']
            or details['personDesignation']
        )
        if not found_anything:
            return jsonify({
                "status": 200,
                "data": "",
                "message": "Card Not Scanned Please try recapturing",
            })

        details['imageUrl'] = './static/' + filename
        return jsonify({"status": 200, "data": details})
    except Exception:
        # Keep the client-facing contract, but record the cause server-side.
        app.logger.exception("v2 vCard scan failed")
        return jsonify({"status": 200, "message": "Something went wrong", "data": ""})



if __name__ == '__main__':
    # Flask development server, listening on every interface at port 8080.
    app.run(port=8080, host='0.0.0.0')
