Pandaaaa906

Untitled

Jan 12th, 2022 (edited)
317
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 10.72 KB | None | 0 0
  1. import re
  2. from copy import copy
  3. from functools import wraps
  4. from typing import List, Dict, Tuple, Optional, Union, Any, Iterable, Set, Protocol
  5. from openpyxl import load_workbook
  6. from openpyxl.cell.cell import TYPE_FORMULA
  7. from openpyxl.drawing.image import Image
  8. from openpyxl.formula.translate import Translator
  9. from openpyxl.utils import get_column_letter
  10. from openpyxl.worksheet.cell_range import CellRange
  11. from base64 import b64decode
  12. from io import BytesIO
  13.  
  14. from pydantic import BaseModel
  15.  
  16.  
  17. def make_tuple(func):
  18.     @wraps(func)
  19.     def wrapper(*args, **kwargs):
  20.         return tuple(func(*args, **kwargs))
  21.     return wrapper
  22.  
  23.  
  24. @make_tuple
  25. def split_fields(t, sep='.'):
  26.     s = 0
  27.     i = 0
  28.     in_bracket = 0
  29.     for i, c in enumerate(t):
  30.         if c == '(':
  31.             in_bracket += 1
  32.         elif c == ')':
  33.             in_bracket -= 1
  34.         elif not in_bracket and c == sep:
  35.             yield t[s:i]
  36.             s = i + 1
  37.     if s != i:
  38.         yield t[s:]
  39.  
  40.  
  41. def xlsx_func(fp: str, data: BaseModel) -> BytesIO:
  42.     sheet_template = SheetTemplate(fp)
  43.     sheet_template.parse(data)
  44.     ret = BytesIO()
  45.     sheet_template.save(ret)
  46.     ret.seek(0)
  47.     return ret
  48.  
  49.  
  50. class TemplateFunc(Protocol):
  51.     def __call__(self, template: 'SheetTemplate', coords=None, params=None, level=None, value=None, **kwargs) -> None:
  52.         pass
  53.  
  54.  
  55. def func_row(template: 'SheetTemplate', coords, params, level, value: Union[BaseModel, Any], **kwargs):
  56.     field = params[:-2]
  57.     value_list = getattr(value, field, [])
  58.     for row, col in coords:
  59.         template.add_rows(field, level, row, len(value_list))
  60.         for i in range(len(value_list)):
  61.             template.ws.cell(row + i, col).value = i + 1
  62.  
  63.  
  64. def func_sum(template: 'SheetTemplate', coords, params, value: Union[BaseModel, Iterable], **kwargs):
  65.     fields = split_fields(params)
  66.     ret = None
  67.     for level, field in enumerate(fields):
  68.         if value is None:
  69.             break
  70.         if isinstance(value, list):
  71.             ret = 0
  72.             for item in value:
  73.                 for f in fields[level:]:
  74.                     item = getattr(item, f, None)
  75.                 if not item:
  76.                     continue
  77.                 ret += item
  78.             break
  79.         if field.endswith('[]'):
  80.             field = field[:-2]
  81.         value = getattr(value, field, None)
  82.     for row, col in coords:
  83.         template.ws.cell(row, col).value = ret or ''
  84.  
  85.  
  86.  
  87. class SheetTemplate:
  88.     delimiter = '$'
  89.     idpattern = r'(?a:[_a-z][\(\)\[\]._a-z0-9]*)'
  90.     braceidpattern = None
  91.     flags = re.IGNORECASE
  92.  
  93.     merged_cells: List[CellRange]
  94.     field_coord: Dict[str, List[Tuple[int, int]]]
  95.     fields_added_rows: Set
  96.  
  97.     template_funcs: Dict[str, TemplateFunc] = {
  98.         'row': func_row,
  99.         'sum': func_sum,
  100.     }
  101.  
  102.     def __init_subclass__(cls):
  103.         super().__init_subclass__()
  104.         if 'pattern' in cls.__dict__:
  105.             pattern = cls.pattern
  106.         else:
  107.             delim = re.escape(cls.delimiter)
  108.             id = cls.idpattern
  109.             bid = cls.braceidpattern or cls.idpattern
  110.             pattern = fr"""
  111.            {delim}(?:
  112.              (?P<escaped>{delim})  |   # Escape sequence of two delimiters
  113.              (?P<named>{id})       |   # delimiter and a Python identifier
  114.              {{(?P<braced>{bid})}} |   # delimiter and a braced identifier
  115.              (?P<invalid>)             # Other ill-formed delimiter exprs
  116.            )
  117.            """
  118.         cls.pattern = re.compile(pattern, cls.flags | re.VERBOSE)
  119.  
  120.     def __init__(self, template_fp: str, sheet_index=0):
  121.         self.wb = load_workbook(template_fp, read_only=False)
  122.         self.ws = self.wb.worksheets[sheet_index]
  123.         # 记录合拼单元格
  124.         self.merged_cells = list(self.ws.merged_cells.ranges)
  125.         self.field_coord = dict()
  126.         self.fields_added_rows = set()
  127.  
  128.     def parse(self, data: BaseModel):
  129.         # unmerge cells
  130.         self.unmerge()
  131.  
  132.         for row in self.ws.iter_rows():
  133.             for cell in row:
  134.                 value = cell.value
  135.                 if not value or not isinstance(value, str):
  136.                     continue
  137.                 if (m := self.pattern.search(value)) is None:
  138.                     continue
  139.                 if (field_name := m['named']) is None:
  140.                     continue
  141.                 if field_name in self.field_coord:
  142.                     self.field_coord[field_name].append((cell.row, cell.column))
  143.                 else:
  144.                     self.field_coord[field_name] = [(cell.row, cell.column)]
  145.  
  146.         for raw_field, coords in self.field_coord.items():
  147.             fields = split_fields(raw_field)
  148.             self.parse_field(fields, 0, data, coords)
  149.  
  150.         # re-merge cells
  151.         self.merge()
  152.  
  153.     def insert_rows(self, idx, amount=1):
  154.         """
  155.        :param amount:
  156.        :param idx:
  157.        :return:
  158.        """
  159.         self.ws.insert_rows(idx, amount)
  160.  
  161.         for j in range(amount):
  162.             for i in range(1, 1+self.ws.max_column):  # 获取基础行数据
  163.                 cell = self.ws.cell(idx-1, i)
  164.                 value = cell.value
  165.                 if cell.data_type == TYPE_FORMULA:
  166.                     value = Translator(value, origin=f'{get_column_letter(i)}{idx-1}').translate_formula(f'{get_column_letter(i)}{idx+j}')
  167.                 self.ws.cell(idx+j, i).value = value
  168.  
  169.         for i, cell_range in enumerate(self.merged_cells):
  170.             if cell_range.min_row < idx:
  171.                 continue
  172.             self.merged_cells[i].min_row += amount
  173.             self.merged_cells[i].max_row += amount
  174.  
  175.         for key in self.field_coord:
  176.             coords = self.field_coord[key]
  177.             for i, (row, col) in enumerate(coords):
  178.                 if row >= idx:
  179.                     self.field_coord[key][i] = (row + amount, col)
  180.  
  181.     def unmerge(self):
  182.         """
  183.        :return:
  184.        """
  185.         for i in self.merged_cells:
  186.             self.ws.unmerge_cells(str(i))  # 解除单元格
  187.  
  188.     def merge(self):
  189.         """
  190.        :return:
  191.        """
  192.         for i in self.merged_cells:
  193.             self.ws.merge_cells(str(i))
  194.  
  195.     def save(self, fp: str):
  196.         self.wb.save(fp)
  197.  
  198.     def add_rows(self, fields, level, row, amount=1):  # 插入行并复制和格式
  199.         """
  200.        insert rows & copy style
  201.        :param fields:
  202.        :param level:
  203.        :param row:
  204.        :param amount:
  205.        :return:
  206.        """
  207.         if tuple(fields[:level]) in self.fields_added_rows:
  208.             return
  209.         row_height = self.ws.row_dimensions[row].height
  210.         heights = [self.ws.row_dimensions[x].height for x in range(row + 1, self.ws.max_row + 1)]
  211.         self.insert_rows(row + 1, amount - 1)
  212.         for x, height in enumerate(heights):
  213.             self.ws.row_dimensions[row + amount + x].height = height
  214.         for i in range(row+1, row + amount):
  215.             self.ws.row_dimensions[i].height = row_height
  216.         for j in range(1, amount):
  217.             for i in range(1, self.ws.max_column + 1):
  218.                 for f in ('style', 'border', 'font', 'data_type', 'alignment', 'fill', 'comment'):
  219.                     v = copy(getattr(self.ws[f'{get_column_letter(i)}{row}'], f))
  220.                     setattr(self.ws[f'{get_column_letter(i)}{row + j}'], f, v)
  221.         self.fields_added_rows.add(tuple(fields[:level]))
  222.  
  223.  
  224.     def get_func(self, func_name):
  225.         """
  226.        get registered func by func_name
  227.        :param func_name:
  228.        :return:
  229.        """
  230.         if func_name not in self.template_funcs:
  231.             raise ValueError(f'{func_name} is an invalid template function')
  232.         return self.template_funcs[func_name]
  233.  
  234.     def parse_field(self, fields: Tuple, level: int, value: Union[BaseModel, Any], coords):
  235.         field = fields[level]
  236.  
  237.         if m := re.match(r'(?P<func_name>[^(]+)\((?P<params>[^)]+)\)', field):
  238.             coords = self.field_coord[field]
  239.             func = self.get_func(m['func_name'])
  240.             func(self, coords=coords, params=m["params"], level=level, value=value)
  241.             return
  242.  
  243.         if field.endswith('[]'):
  244.             field = field[:-2]
  245.             list_value = getattr(value, field, [])
  246.             for row, col in coords:
  247.                 self.add_rows(fields, level, row, len(list_value))
  248.                 for i, value in enumerate(list_value):
  249.                     self.parse_field(fields, level + 1, value, ((row + i, col),))
  250.         else:
  251.             value = getattr(value, field, None)
  252.             if isinstance(value, BaseModel):
  253.                 self.parse_field(fields, level + 1, value, coords)
  254.             # TODO refactor needed
  255.             else:
  256.                 for row, col in coords:
  257.                     self.set_cell_value(row, col, value)
  258.  
  259.     def set_cell_value(self, row, col, value):
  260.         if isinstance(value, str) and value.startswith("data:image/png;base64"):
  261.             t = value.split(",")[-1]
  262.             b = b64decode(t)
  263.             img = Image(BytesIO(b))
  264.             w, h = img.width, img.height
  265.             cell_w = self.ws.column_dimensions[get_column_letter(col)].width * 7
  266.             cell_h = (self.ws.row_dimensions[row].height or 120) * 1.33  # TODO
  267.             r1 = w / h
  268.             r2 = cell_w / cell_h
  269.             ratio = cell_h / h if r2 > r1 else cell_w / w
  270.             img.width, img.height = w * ratio, h * ratio
  271.             self.ws.add_image(img, f'{get_column_letter(col)}{row}')
  272.             self.ws.cell(row, col).value = ''
  273.         elif value:
  274.             self.ws.cell(row, col).value = str(value) or ''
  275.         else:
  276.             self.ws.cell(row, col).value = ''
  277.  
  278.  
# __init_subclass__ only runs automatically for subclasses, so invoke it once
# here to compile SheetTemplate.pattern for the base class itself.
SheetTemplate.__init_subclass__()
  280.  
  281.  
class SalesQuotationBase(BaseModel):
    # Sample product entry of a sales quotation (demo schema used by gen_quotation).
    cn_name: Optional[str]  # product name; "cn" presumably means Chinese — confirm
    # NOTE(review): SheetTemplate.set_cell_value only embeds images passed as a
    # "data:image/png;base64,..." str; a bytes value would be written as str(bytes) text.
    img: Optional[bytes]
    brand: Optional[str]
    cat_no: Optional[str]  # presumably the catalogue number
    offer_price: Optional[float]
  288.  
  289.  
class SalesQuotationDetail(BaseModel):
    # One entry of SalesQuotation.children: a quoted product plus its quantity.
    quotation: SalesQuotationBase
    quantity: Optional[float]
  293.  
  294.  
class SalesQuotation(BaseModel):
    # Sample top-level model rendered into the sheet template (see gen_quotation).
    code: str
    customer_name: str
    contact_name: str
    # list field; presumably referenced from the template as "$children[]...."
    children: List[SalesQuotationDetail]
  300.  
  301.  
  302.  
  303. def gen_quotation():
  304.     children = [
  305.         SalesQuotationDetail(
  306.             quotation=SalesQuotationBase(cn_name=i, brand='cato', cat_no=f'A{i:0>5}', offer_price=1000),
  307.             quantity=i
  308.         )
  309.         for i in range(9)
  310.     ]
  311.     return SalesQuotation(
  312.         code='SDF200001010005',
  313.         customer_name='客户名称',
  314.         contact_name='联系人名称',
  315.         children=children
  316.     )
  317.  
  318.  
  319. if __name__ == '__main__':
  320.     fp = r"H:\Users\Pandaaaa\Downloads\Telegram Desktop\UWA-报价单模板.xlsx"
  321.     t = SheetTemplate(fp)
  322.  
  323.     quotation = gen_quotation()
  324.     t.parse(quotation)
  325.     t.save(r'D:\test.xlsx')
  326.     pass
  327.  
Add Comment
Please, Sign In to add comment