Advertisement
Georgiy1108

Untitled

Jun 17th, 2024
756
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 10.15 KB | Science | 0 0
  1. import os
  2. import face_recognition
  3. from collections import defaultdict
  4. import shutil
  5. import torch
  6. import numpy as np
  7. import time
  8. import logging
  9. import json
  10.  
  11. # Глобальные переменные
  12. LIMITED_FILES_COUNT = 20
  13. THRESHOLDS = [0.43, 0.54, 0.6]
  14. EMBEDDINGS_DIR = "embeddings"
  15.  
  16. # Настройка логирования
  17. logging.basicConfig(level=logging.INFO, format='%(asctime)s %(message)s', handlers=[logging.FileHandler("log.txt"), logging.StreamHandler()])
  18.  
  19. # Функция для извлечения префикса из имени файла
  20. def get_prefix(filename):
  21.     parts = filename.split('-')
  22.     if len(parts) >= 3:
  23.         return '-'.join(parts[:2])
  24.     else:
  25.         prefix = ''
  26.         for char in filename:
  27.             if char.isdigit():
  28.                 break
  29.             prefix += char
  30.         return prefix if prefix else filename
  31.  
  32. # Функция для извлечения эмбеддингов лиц и сохранения их в файл
  33. def extract_embeddings(image_path):
  34.     embedding_path = os.path.join(EMBEDDINGS_DIR, os.path.basename(image_path) + ".json")
  35.     if os.path.exists(embedding_path):
  36.         with open(embedding_path, 'r') as f:
  37.             embedding = json.load(f)
  38.         return np.array(embedding)
  39.     else:
  40.         logging.info(f"Processing file: {image_path}")
  41.         image = face_recognition.load_image_file(image_path)
  42.         face_encodings = face_recognition.face_encodings(image)
  43.         if face_encodings:
  44.             embedding = face_encodings[0]
  45.             with open(embedding_path, 'w') as f:
  46.                 json.dump(embedding.tolist(), f)
  47.             return embedding
  48.         else:
  49.             return None
  50.  
  51. # Группировка файлов по префиксу
  52. def group_by_prefix(file_list):
  53.     groups = defaultdict(list)
  54.     for file in file_list:
  55.         prefix = get_prefix(file)
  56.         groups[prefix].append(file)
  57.     return groups
  58.  
  59. # Функция для извлечения и кэширования эмбеддингов для всех изображений
  60. def extract_and_cache_embeddings(grouped_files):
  61.     embeddings_cache = {}
  62.     for key, files in grouped_files.items():
  63.         limited_files = files[:LIMITED_FILES_COUNT // 2] + files[-LIMITED_FILES_COUNT // 2:]
  64.         embeddings_cache[key] = [(extract_embeddings(file), file) for file in limited_files if extract_embeddings(file) is not None]
  65.     return embeddings_cache
  66.  
  67. # Функция для сравнения двух эмбеддингов
  68. def compare_embeddings(embedding1, embedding2):
  69.     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  70.     embedding1 = torch.tensor(embedding1).to(device)
  71.     embedding2 = torch.tensor(embedding2).to(device)
  72.     distance = torch.nn.functional.pairwise_distance(embedding1.unsqueeze(0), embedding2.unsqueeze(0)).item()
  73.     return distance
  74.  
  75. # Функция для объединения групп на основе новых правил
  76. def merge_groups(groups, embeddings_cache, threshold):
  77.     merged_groups = []
  78.     group_keys = list(groups.keys())
  79.  
  80.     while group_keys:
  81.         current_key = group_keys.pop(0)
  82.         current_group = groups[current_key]
  83.         current_embeddings = embeddings_cache[current_key]
  84.  
  85.         i = 0
  86.         while i < len(group_keys):
  87.             comparison_key = group_keys[i]
  88.             comparison_group = groups[comparison_key]
  89.             comparison_embeddings = embeddings_cache[comparison_key]
  90.            
  91.             total_comparisons = 0
  92.             distances = []
  93.  
  94.             for (emb1, file1) in current_embeddings:
  95.                 for (emb2, file2) in comparison_embeddings:
  96.                     total_comparisons += 1
  97.                     distance = compare_embeddings(emb1, emb2)
  98.                     distances.append(distance)
  99.                     logging.info(f"Comparing {file1} and {file2}: Distance = {distance}")
  100.  
  101.             if total_comparisons == 0:
  102.                 i += 1
  103.                 continue
  104.  
  105.             # Применяем новые правила
  106.             count_45 = sum(1 for d in distances if d <= 0.45)
  107.             count_50 = sum(1 for d in distances if d <= 0.5)
  108.             count_55 = sum(1 for d in distances if d <= 0.55)
  109.             count_57 = sum(1 for d in distances if d >= 0.57)
  110.             count_65 = sum(1 for d in distances if d > 0.65)
  111.             count_70 = sum(1 for d in distances if d > 0.7)
  112.  
  113.             percent_45 = count_45 / total_comparisons
  114.             percent_50 = count_50 / total_comparisons
  115.             percent_55 = count_55 / total_comparisons
  116.             percent_57 = count_57 / total_comparisons
  117.             percent_65 = count_65 / total_comparisons
  118.             percent_70 = count_70 / total_comparisons
  119.  
  120.             logging.info(f"Group {current_key} vs Group {comparison_key}: {percent_45*100:.2f}% <= 0.45, {percent_50*100:.2f}% <= 0.5, {percent_55*100:.2f}% <= 0.55, {percent_57*100:.2f}% >= 0.57, {percent_65*100:.2f}% > 0.65, {percent_70*100:.2f}% > 0.7")
  121.  
  122.             if percent_45 >= 0.1:
  123.                 current_group.extend(comparison_group)
  124.                 del groups[comparison_key]
  125.                 del embeddings_cache[comparison_key]
  126.                 group_keys.pop(i)
  127.             elif percent_50 >= 0.4:
  128.                 current_group.extend(comparison_group)
  129.                 del groups[comparison_key]
  130.                 del embeddings_cache[comparison_key]
  131.                 group_keys.pop(i)
  132.             elif percent_55 >= 0.7:
  133.                 current_group.extend(comparison_group)
  134.                 del groups[comparison_key]
  135.                 del embeddings_cache[comparison_key]
  136.                 group_keys.pop(i)
  137.             elif percent_65 >= 0.3:
  138.                 i += 1
  139.             elif percent_70 >= 0.1:
  140.                 i += 1
  141.             elif percent_57 >= 0.8:
  142.                 i += 1
  143.             else:
  144.                 i += 1
  145.  
  146.         merged_groups.append((current_key, current_group))
  147.  
  148.     # Преобразуем список обратно в словарь для следующей итерации
  149.     new_groups = {}
  150.     new_embeddings_cache = {}
  151.     for i, (key, group) in enumerate(merged_groups):
  152.         new_key = f"group_{i}"
  153.         new_groups[new_key] = group
  154.         # Обновляем кэшированные эмбеддинги для новых групп
  155.         new_embeddings_cache[new_key] = []
  156.         for emb, file in embeddings_cache[key]:
  157.             if file in group:
  158.                 new_embeddings_cache[new_key].append((emb, file))
  159.  
  160.     return new_groups, new_embeddings_cache
  161.  
  162. # Функция для копирования файлов в выходную директорию
  163. def copy_files_to_output(merged_groups, output_directory):
  164.     if not os.path.exists(output_directory):
  165.         os.makedirs(output_directory)
  166.    
  167.     for i, (group_name, group) in enumerate(merged_groups.items()):
  168.         group_dir = os.path.join(output_directory, f"group_{i + 1}")
  169.         if not os.path.exists(group_dir):
  170.             os.makedirs(group_dir)
  171.        
  172.         for file in group:
  173.             shutil.copy(file, group_dir)
  174.  
  175. # Функция для копирования групп, содержащих только один префикс, в отдельную папку
  176. def copy_single_prefix_groups_to_output(groups, output_directory):
  177.     single_prefix_dir = os.path.join(output_directory, "single_prefix_groups")
  178.     if not os.path.exists(single_prefix_dir):
  179.         os.makedirs(single_prefix_dir)
  180.    
  181.     prefix_counts = defaultdict(int)
  182.     for group in groups.values():
  183.         for file in group:
  184.             prefix = get_prefix(file)
  185.             prefix_counts[prefix] += 1
  186.    
  187.     for i, (group_name, group) in enumerate(groups.items()):
  188.         group_prefixes = set(get_prefix(file) for file in group)
  189.         if all(prefix_counts[prefix] == len(group) for prefix in group_prefixes):
  190.             group_dir = os.path.join(single_prefix_dir, f"group_{i + 1}")
  191.             if not os.path.exists(group_dir):
  192.                 os.makedirs(group_dir)
  193.             for file in group:
  194.                 shutil.copy(file, group_dir)
  195.  
  196. # Основная функция
  197. def main(input_directory, output_directory):
  198.     # Проверка доступности GPU
  199.     if not torch.cuda.is_available():
  200.         logging.error("GPU is not available. Please ensure you have installed the correct drivers and CUDA toolkit.")
  201.         return
  202.  
  203.     logging.info(f"Using GPU: {torch.cuda.get_device_name(torch.cuda.current_device())}")
  204.  
  205.     # Создание директории для эмбеддингов, если она не существует
  206.     if not os.path.exists(EMBEDDINGS_DIR):
  207.         os.makedirs(EMBEDDINGS_DIR)
  208.  
  209.     start_time = time.time()
  210.    
  211.     files = [os.path.join(input_directory, file) for file in os.listdir(input_directory) if file.endswith('.jpg')]
  212.     grouped_files = group_by_prefix(files)
  213.    
  214.     # Вывод количества групп по префиксам
  215.     logging.info(f"Number of groups by prefix: {len(grouped_files)}")
  216.    
  217.     # Извлечение и кэширование эмбеддингов
  218. embeddings_cache = extract_and_cache_embeddings(grouped_files)
  219.  
  220. # Проводим многократную группировку с разными значениями порога
  221. for threshold in THRESHOLDS:
  222.     logging.info(f"Grouping with threshold: {threshold}")
  223.     grouped_files, embeddings_cache = merge_groups(grouped_files, embeddings_cache, threshold)
  224.     logging.info(f"Number of groups after grouping with threshold {threshold}: {len(grouped_files)}")
  225.  
  226. # Копируем файлы в выходную директорию
  227. copy_files_to_output(grouped_files, output_directory)
  228.  
  229. # Отделяем группы с одним префиксом и копируем в отдельную папку
  230. copy_single_prefix_groups_to_output(grouped_files, output_directory)
  231.  
  232. end_time = time.time()
  233. logging.info(f"Total execution time: {end_time - start_time} seconds")
  234.  
  235. for i, (group_name, group) in enumerate(grouped_files.items()):
  236.     logging.info(f"Group {i + 1}:")
  237.     for file in group:
  238.         logging.info(f"  {file}")
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement