import argparse
import os
import tempfile

import cv2
import numpy as np
import pymeshlab
from stl import mesh

def read_depth_image(file_name):
    """Load a depth image, rotate it and invert its values for meshing.

    The image is read unchanged (bit depth and channel count preserved),
    rotated 90 degrees clockwise in the (row, column) plane, and inverted so
    that each value becomes ``max(image) - value``.

    Args:
        file_name: Path of the image file to read.

    Returns:
        A NumPy array with the rotated, inverted pixel values.

    Raises:
        ValueError: If OpenCV cannot read ``file_name`` (missing file,
            unreadable file, or unsupported format).
    """
    im = cv2.imread(file_name, cv2.IMREAD_UNCHANGED)
    if im is None:
        # cv2.imread reports failure by returning None instead of raising;
        # without this check the None would surface later as an unrelated
        # "axes out of range" error from np.rot90.
        raise ValueError(f"Could not read depth image: {file_name!r}")

    # k=-1 rotates clockwise in the (row, column) plane.
    im_array = np.rot90(np.asarray(im), -1, (0, 1))
    # Subtracting from the global maximum cannot underflow, even for
    # unsigned integer images, since every element is <= the max.
    return np.max(im_array) - im_array

def create_initial_mesh(im_array, height_div_width):
    """Triangulate a depth array into a height-field surface mesh.

    Pixel ``(i, j)`` becomes the vertex ``(i, j, z[i, j])``. Each 2x2 cell of
    neighbouring pixels is split into two triangles, so an ``R x C`` array
    yields ``2 * (R - 1) * (C - 1)`` faces. Within a cell the first triangle
    is ``(p10, p01, p00)`` and the second ``(p11, p01, p10)``, where
    ``pab = (i + a, j + b, z)``.

    Args:
        im_array: 2-D depth array, or a 3-D array whose first channel
            (``[:, :, 0]``) is used as depth.
        height_div_width: Scale factor for the relief. The largest depth
            value is mapped to ``im_array.shape[0] * height_div_width``.

    Returns:
        A ``stl.mesh.Mesh`` holding the surface triangles.
    """
    rows, cols = im_array.shape[:2]
    depth = im_array[:, :, 0] if im_array.ndim == 3 else im_array
    depth = np.asarray(depth, dtype=np.float64)

    # Normalise by the maximum of the channel that is actually meshed, so the
    # tallest vertex always lands at exactly rows * height_div_width.
    depth_max = np.max(depth)
    if depth_max != 0:
        scaled = rows * float(height_div_width) * depth / depth_max
    else:
        # A flat, all-zero depth map would otherwise compute 0 / 0 and fill
        # every z coordinate with NaN.
        scaled = np.zeros_like(depth)

    # Vertex grid of shape (rows, cols, 3) holding (i, j, z) for every pixel.
    grid_i, grid_j = np.meshgrid(np.arange(rows), np.arange(cols), indexing="ij")
    verts = np.stack([grid_i, grid_j, scaled], axis=-1)

    # Corners of every cell, flattened in row-major (i, j) order.
    p00 = verts[:-1, :-1].reshape(-1, 3)
    p01 = verts[:-1, 1:].reshape(-1, 3)
    p10 = verts[1:, :-1].reshape(-1, 3)
    p11 = verts[1:, 1:].reshape(-1, 3)

    first = np.stack([p10, p01, p00], axis=1)
    second = np.stack([p11, p01, p10], axis=1)
    # Interleave so the faces of cell k sit at indices 2k and 2k + 1.
    triangles = np.stack([first, second], axis=1).reshape(-1, 3, 3)

    mesh_shape = mesh.Mesh(np.zeros(len(triangles), dtype=mesh.Mesh.dtype))
    mesh_shape.vectors[:] = triangles
    return mesh_shape

def optimize_mesh(mesh_file, target_faces):
    """Simplify a mesh file with pymeshlab and return the resulting MeshSet.

    The mesh is first decimated with quadric edge collapse towards
    ``target_faces`` faces (``preservenormal=True``). It then receives one
    iteration of isotropic explicit remeshing.

    Args:
        mesh_file: Path of a mesh file that pymeshlab can load.
        target_faces: Requested face count for the decimation step.

    Returns:
        The ``pymeshlab.MeshSet`` whose current mesh is the processed result.
    """
    mesh_set = pymeshlab.MeshSet()
    mesh_set.load_new_mesh(mesh_file)

    mesh_set.meshing_decimation_quadric_edge_collapse(
        targetfacenum=target_faces,
        preservenormal=True,
    )
    mesh_set.meshing_isotropic_explicit_remeshing(iterations=1)

    return mesh_set

def create_base_mesh(optimized_mesh, base_height_ratio):
    """Build a rectangular box (12 triangles) to place underneath a relief.

    The box spans the XY bounding box of ``optimized_mesh``. Its top face lies
    at the mesh's lowest z value, and it extends downward by
    ``base_height_ratio * (max_x - min_x)``.

    Every triangle is wound counter-clockwise when viewed from outside the
    box. Right-hand-rule normals (``cross(v1 - v0, v2 - v0)``) therefore point
    outward, as the STL format expects.

    Args:
        optimized_mesh: Mesh exposing vertex-coordinate arrays ``x``, ``y``
            and ``z`` (as ``stl.mesh.Mesh`` does).
        base_height_ratio: Base thickness as a fraction of the mesh's
            x extent.

    Returns:
        A ``stl.mesh.Mesh`` with the 12 box triangles.
    """
    min_x, max_x = np.min(optimized_mesh.x), np.max(optimized_mesh.x)
    min_y, max_y = np.min(optimized_mesh.y), np.max(optimized_mesh.y)
    top_z = np.min(optimized_mesh.z)
    bottom_z = top_z - base_height_ratio * (max_x - min_x)

    # Corners 0-3 form the bottom rectangle and 4-7 the top one, each listed
    # counter-clockwise when viewed from +z.
    corners = np.array([
        [min_x, min_y, bottom_z],
        [max_x, min_y, bottom_z],
        [max_x, max_y, bottom_z],
        [min_x, max_y, bottom_z],
        [min_x, min_y, top_z],
        [max_x, min_y, top_z],
        [max_x, max_y, top_z],
        [min_x, max_y, top_z],
    ])
    # Two triangles per box face, each counter-clockwise as seen from outside,
    # so the computed normals face away from the box interior.
    faces = np.array([
        [2, 1, 0], [3, 2, 0],  # bottom (-z)
        [7, 3, 0], [4, 7, 0],  # min_x side (-x)
        [6, 5, 1], [2, 6, 1],  # max_x side (+x)
        [5, 4, 0], [1, 5, 0],  # min_y side (-y)
        [6, 2, 3], [7, 6, 3],  # max_y side (+y)
        [5, 6, 4], [6, 7, 4],  # top (+z), under the relief
    ])

    base_mesh = mesh.Mesh(np.zeros(len(faces), dtype=mesh.Mesh.dtype))
    base_mesh.vectors[:] = corners[faces]
    return base_mesh

def depth_image_to_stl(input_file, output_file, height_div_width, target_faces, base_height_ratio):
    """Convert a depth image into an STL relief mounted on a rectangular base.

    Pipeline:
        1. Read and invert the depth image.
        2. Triangulate it into a height field.
        3. Decimate and remesh it with pymeshlab.
        4. Add a box-shaped base below it.
        5. Write the combined mesh to ``output_file``.

    Args:
        input_file: Path of the depth image.
        output_file: Destination STL path.
        height_div_width: Relief height scale passed to ``create_initial_mesh``.
        target_faces: Face budget passed to ``optimize_mesh``.
        base_height_ratio: Base thickness ratio passed to ``create_base_mesh``.
    """
    im_array = read_depth_image(input_file)
    initial_mesh = create_initial_mesh(im_array, height_div_width)

    # Intermediate STLs live in a private temporary directory. Fixed names in
    # the current directory could overwrite user files or collide between
    # concurrent runs. The directory is removed even if a step below raises.
    with tempfile.TemporaryDirectory() as work_dir:
        temp_file = os.path.join(work_dir, "temp_mesh.stl")
        initial_mesh.save(temp_file)

        ms = optimize_mesh(temp_file, target_faces)

        optimized_temp_file = os.path.join(work_dir, "optimized_temp_mesh.stl")
        ms.save_current_mesh(optimized_temp_file)
        optimized_mesh = mesh.Mesh.from_file(optimized_temp_file)

    base_mesh = create_base_mesh(optimized_mesh, base_height_ratio)

    combined_mesh = mesh.Mesh(np.concatenate([optimized_mesh.data, base_mesh.data]))
    combined_mesh.save(output_file)

    print(f"Optimized mesh with base saved as {output_file}")

def main(argv=None):
    """Command-line entry point: parse arguments and run the conversion.

    Args:
        argv: Optional list of argument strings, excluding the program name.
            When ``None``, argparse reads ``sys.argv[1:]``, as before.
    """
    parser = argparse.ArgumentParser(description="Convert a depth image to an STL file with an added base.")
    parser.add_argument("input_file", help="Input depth image file")
    parser.add_argument("output_file", help="Output STL file")
    parser.add_argument("--height_div_width", type=float, default=0.1, help="Height to width ratio")
    parser.add_argument("--target_faces", type=int, default=500000, help="Target number of faces for optimization")
    parser.add_argument("--base_height_ratio", type=float, default=0.1, help="Base height ratio")

    args = parser.parse_args(argv)
    # Reject a meaningless face budget up front with a normal usage error,
    # instead of letting it reach the decimation filter.
    if args.target_faces < 1:
        parser.error("--target_faces must be a positive integer")

    depth_image_to_stl(args.input_file, args.output_file, args.height_div_width, args.target_faces, args.base_height_ratio)

# Run the command-line interface only when executed as a script, not on import.
if __name__ == "__main__":
    main()