Advertisement
formulake

launch.py

Feb 18th, 2024 (edited)
58
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 4.77 KB | Source Code | 0 0
  1. import os
  2. import gradio as gr
  3. import shutil
  4. import cv2
  5. import numpy as np
  6. import insightface
  7. import webbrowser
  8. import logging
  9. import threading
  10. from ifnude import detect
  11. import random  # Import the random module
  12.  
  13. # Initialize logging
  14. logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
  15.  
  16. # Function to generate a random port within the specified range
  17. def generate_random_port():
  18.     return random.randint(7862, 7868)
  19.  
  20. def process_images(reference_image, source_folder, selected_buckets, nsfw_check, nsfw_sensitivity):
  21.     try:
  22.         # Load the reference image using OpenCV
  23.         ref_img = cv2.imread(reference_image.name)
  24.         ref_img_rgb = cv2.cvtColor(ref_img, cv2.COLOR_BGR2RGB)
  25.  
  26.         # Initialize InsightFace model
  27.         model = insightface.app.FaceAnalysis()
  28.         model.prepare(ctx_id=-1)
  29.  
  30.         ref_faces = model.get(ref_img_rgb)
  31.         if not ref_faces:
  32.             return "No face detected in the reference image. Please use a different image."
  33.         ref_embedding = ref_faces[0].embedding
  34.  
  35.         # Create folders based on selected buckets
  36.         for bucket in selected_buckets:
  37.             os.makedirs(os.path.join(source_folder, f"bucket_{int(float(bucket) * 100)}"), exist_ok=True)
  38.         os.makedirs(os.path.join(source_folder, "rejected"), exist_ok=True)
  39.  
  40.         # Check if source folder exists
  41.         if not os.path.exists(source_folder):
  42.             return "Invalid source folder path. Please check and try again."
  43.  
  44.         # Process source images
  45.         image_count = 0
  46.         for filename in os.listdir(source_folder):
  47.             if image_count >= 1000:
  48.                 break
  49.  
  50.             filepath = os.path.join(source_folder, filename)
  51.             src_img = cv2.imread(filepath)
  52.        
  53.             # Check if the image is loaded properly
  54.             if src_img is None:
  55.                 print(f"Error loading image: {filename}. Skipping...")
  56.                 continue
  57.  
  58.             # NSFW check
  59.             if nsfw_check:
  60.                 nsfw_result = detect(filepath)
  61.                 if nsfw_result and any([res['score'] > nsfw_sensitivity for res in nsfw_result]):
  62.                     shutil.move(filepath, os.path.join(source_folder, "rejected"))
  63.                     continue
  64.  
  65.             src_img_rgb = cv2.cvtColor(src_img, cv2.COLOR_BGR2RGB)
  66.             src_faces = model.get(src_img_rgb)
  67.             if not src_faces:
  68.                 shutil.move(filepath, os.path.join(source_folder, "rejected"))
  69.                 continue
  70.  
  71.             src_embedding = src_faces[0].embedding
  72.             similarity = np.dot(ref_embedding, src_embedding) / (np.linalg.norm(ref_embedding) * np.linalg.norm(src_embedding))
  73.  
  74.             # Move to appropriate bucket
  75.             moved = False
  76.             for bucket in selected_buckets:
  77.                 if similarity >= float(bucket):
  78.                     shutil.move(filepath, os.path.join(source_folder, f"bucket_{int(float(bucket) * 100)}"))
  79.                     moved = True
  80.                     break
  81.  
  82.             if not moved:
  83.                 shutil.move(filepath, os.path.join(source_folder, "rejected"))
  84.  
  85.             image_count += 1
  86.  
  87.         return f"Processed {image_count} images."
  88.        
  89.     except Exception as e:
  90.         logging.error(f"Error processing images: {e}")
  91.         return f"Error: {e}"
  92.  
  93. def launch_gradio(event):
  94.     try:
  95.         # Generate a random port within the specified range
  96.         # port = 8800
  97.        
  98.         # Gradio Interface
  99.         iface = gr.Interface(
  100.             process_images,
  101.             [
  102.                 gr.components.File(label="Reference Image"),
  103.                 gr.components.Textbox(label="Source Folder Path"),
  104.                 gr.components.CheckboxGroup(choices=[str(i/10) for i in range(1, 11)], label="Select Buckets (10% increments)"),
  105.                 gr.components.Checkbox(label="Enable NSFW Check"),
  106.                 gr.components.Slider(minimum=0, maximum=1, label="NSFW Sensitivity")
  107.             ],
  108.             gr.components.Textbox(label="Process Status"),
  109.             # port=8800  # Use the generated random port
  110.         )
  111.         # Launch the interface using the generated random port
  112.         iface.launch(server_port=8800)
  113.         event.set()  # Signal that the Gradio server has started
  114.  
  115.         # Print the URL with the generated port
  116.         print(f"Gradio interface is available at: http://127.0.0.1:8800")
  117.        
  118.     except Exception as e:
  119.         logging.error(f"Error launching Gradio: {e}")
  120.  
  121. if __name__ == "__main__":
  122.     try:
  123.         event = threading.Event()
  124.         threading.Thread(target=launch_gradio, args=(event,)).start()
  125.         event.wait()  # Wait for the Gradio server to start
  126.     except Exception as e:
  127.         logging.error(f"Error in main: {e}")
  128.  
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement