Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- import re
- from copy import copy
- from functools import wraps
- from typing import List, Dict, Tuple, Optional, Union, Any, Iterable, Set, Protocol
- from openpyxl import load_workbook
- from openpyxl.cell.cell import TYPE_FORMULA
- from openpyxl.drawing.image import Image
- from openpyxl.formula.translate import Translator
- from openpyxl.utils import get_column_letter
- from openpyxl.worksheet.cell_range import CellRange
- from base64 import b64decode
- from io import BytesIO
- from pydantic import BaseModel
def make_tuple(func):
    """Decorator: collect whatever ``func`` returns (typically a generator) into a tuple."""

    @wraps(func)
    def _as_tuple(*args, **kwargs):
        result = func(*args, **kwargs)
        return tuple(result)

    return _as_tuple
- @make_tuple
- def split_fields(t, sep='.'):
- s = 0
- i = 0
- in_bracket = 0
- for i, c in enumerate(t):
- if c == '(':
- in_bracket += 1
- elif c == ')':
- in_bracket -= 1
- elif not in_bracket and c == sep:
- yield t[s:i]
- s = i + 1
- if s != i:
- yield t[s:]
def xlsx_func(fp: str, data: BaseModel) -> BytesIO:
    """Render ``data`` into the xlsx template at ``fp`` and return the result in memory.

    :param fp: path of the template workbook.
    :param data: model whose attributes fill the template placeholders.
    :return: an in-memory xlsx file, rewound to the start for reading.
    """
    template = SheetTemplate(fp)
    template.parse(data)
    output = BytesIO()
    template.save(output)
    output.seek(0)  # rewind so callers can read/stream from the beginning
    return output
class TemplateFunc(Protocol):
    """Call signature of the functions stored in ``SheetTemplate.template_funcs``.

    A template function receives the template being rendered, the placeholder cells
    (``coords``), the text inside the call's parentheses (``params``), the path depth
    (``level``) and the object the call is resolved against (``value``), and returns ``None``.
    """

    def __call__(self, template: 'SheetTemplate', coords=None, params=None, level=None, value=None, **kwargs) -> None:
        ...
def func_row(template: 'SheetTemplate', coords, params, level, value: 'Union[BaseModel, Any]', **kwargs):
    """Template function ``row(<list>[])``: number the rows of a list field 1..n.

    For each placeholder cell, asks ``template.add_rows`` to make room for the list
    (it may skip insertion if rows were already added for this path), then writes
    1, 2, ..., n downward starting at that cell.

    :param template: the SheetTemplate being rendered.
    :param coords: (row, col) cells containing the placeholder.
    :param params: text inside the parentheses, e.g. ``"children[]"``.
    :param level: depth of this placeholder in its dotted path; forwarded to add_rows.
    :param value: object that owns the list attribute.
    """
    # Strip the list marker only when present ("items" must not become "ite").
    field = params[:-2] if params.endswith('[]') else params
    # A missing or None list attribute numbers nothing instead of failing on len(None).
    value_list = getattr(value, field, None) or []
    count = len(value_list)
    for row, col in coords:
        template.add_rows(field, level, row, count)
        for offset in range(count):
            template.ws.cell(row + offset, col).value = offset + 1
def _sum_attribute_path(items, path):
    """Resolve the getattr chain ``path`` on each item and add up the truthy results.

    Falsy results (``None``, ``0``, ``''``) are skipped. An empty ``path`` sums the items
    themselves.
    """
    total = 0
    for item in items:
        for name in path:
            item = getattr(item, name, None)
        if item:
            total += item
    return total


def func_sum(template: 'SheetTemplate', coords, params, value: 'Union[BaseModel, Iterable]', **kwargs):
    """Template function ``sum(<path>)``: write the total of a field over a list.

    ``params`` is a dotted path such as ``children[].quotation.offer_price``. The path
    is walked from ``value``. At the first list encountered, the remaining segments are
    resolved on each element and the truthy results are added up. A path that ends on a
    list (e.g. ``prices[]``) sums that list's elements. If no list is reached, or the
    total is 0, the target cells are cleared.

    :param template: the SheetTemplate being rendered.
    :param coords: (row, col) cells to write the total into.
    :param params: dotted path inside the parentheses.
    :param value: object the path is resolved against.
    """
    fields = split_fields(params)
    total = None
    current = value
    for level, field in enumerate(fields):
        if current is None:
            break
        if isinstance(current, list):
            total = _sum_attribute_path(current, fields[level:])
            break
        if field.endswith('[]'):
            field = field[:-2]
        current = getattr(current, field, None)
    else:
        # The path was fully walked without meeting a list part-way. If it ended on a
        # list (e.g. ``sum(prices[])``), sum the list's elements directly.
        if isinstance(current, list):
            total = _sum_attribute_path(current, ())
    for row, col in coords:
        template.ws.cell(row, col).value = total or ''
class SheetTemplate:
    """Fill ``$placeholder`` cells of an xlsx template sheet from a pydantic model.

    Placeholders take three forms:

    * Dotted attribute paths such as ``$code`` or ``$children[].quotation.brand``.
      ``[]`` marks a list: one row per element, with rows inserted below the placeholder.
    * Template-function calls such as ``$row(children[])`` or
      ``$sum(children[].quotation.offer_price)``; see ``template_funcs``.
    * ``${...}``, accepted as a braced form of the above.

    ``$$`` is not treated as a placeholder. A placeholder replaces the whole cell value,
    and only the first placeholder found in a cell is used.
    """

    delimiter = '$'
    idpattern = r'(?a:[_a-z][\(\)\[\]._a-z0-9]*)'
    braceidpattern = None
    flags = re.IGNORECASE
    # Matches one template-function path segment, e.g. ``sum(children[].quantity)``.
    _func_call_pattern = re.compile(r'(?P<func_name>[^(]+)\((?P<params>[^)]+)\)')

    # Merged ranges recorded at load time; unmerged during parse, re-merged afterwards.
    merged_cells: List[CellRange]
    # Raw placeholder path -> (row, col) cells where it appears.
    field_coord: Dict[str, List[Tuple[int, int]]]
    # Path prefixes (tuples) for which add_rows has already inserted rows.
    fields_added_rows: Set
    # Function name -> implementation, for ``$name(params)`` placeholders.
    template_funcs: Dict[str, TemplateFunc] = {
        'row': func_row,
        'sum': func_sum,
    }

    def __init_subclass__(cls):
        """Compile ``cls.pattern`` from ``delimiter``/``idpattern``/``braceidpattern``.

        Uses the same scheme as ``string.Template``: a subclass may instead define
        ``pattern`` itself. Also called explicitly on ``SheetTemplate`` right after
        the class body.
        """
        super().__init_subclass__()
        if 'pattern' in cls.__dict__:
            pattern = cls.pattern
        else:
            delim = re.escape(cls.delimiter)
            ident = cls.idpattern
            brace_ident = cls.braceidpattern or cls.idpattern
            pattern = fr"""
            {delim}(?:
              (?P<escaped>{delim}) | # Escape sequence of two delimiters
              (?P<named>{ident}) | # delimiter and a Python identifier
              {{(?P<braced>{brace_ident})}} | # delimiter and a braced identifier
              (?P<invalid>) # Other ill-formed delimiter exprs
            )
            """
        cls.pattern = re.compile(pattern, cls.flags | re.VERBOSE)

    def __init__(self, template_fp: str, sheet_index=0):
        """Load the template workbook and select the sheet to fill.

        :param template_fp: path of the xlsx template.
        :param sheet_index: index into ``wb.worksheets`` of the sheet to fill.
        """
        self.wb = load_workbook(template_fp, read_only=False)
        self.ws = self.wb.worksheets[sheet_index]
        # Record merged ranges so they can be unmerged while rows are inserted and
        # re-merged (with shifted coordinates) afterwards.
        self.merged_cells = list(self.ws.merged_cells.ranges)
        self.field_coord = {}
        self.fields_added_rows = set()

    def parse(self, data: BaseModel):
        """Replace every placeholder in the sheet with the matching value from ``data``."""
        self.unmerge()
        for row in self.ws.iter_rows():
            for cell in row:
                text = cell.value
                if not text or not isinstance(text, str):
                    continue
                m = self.pattern.search(text)
                if m is None:
                    continue
                # Both ``$name`` and ``${name}`` are placeholders. Escaped ``$$`` and
                # ill-formed ``$`` leave the cell untouched.
                field_name = m['named'] or m['braced']
                if field_name is None:
                    continue
                self.field_coord.setdefault(field_name, []).append((cell.row, cell.column))
        for raw_field, coords in self.field_coord.items():
            self.parse_field(split_fields(raw_field), 0, data, coords)
        self.merge()

    def insert_rows(self, idx, amount=1):
        """Insert ``amount`` rows before row ``idx`` and fill them from row ``idx - 1``.

        Each new row receives the values of row ``idx - 1``. Formulas are translated
        so that relative references follow the new row. Recorded merged ranges that
        start at or below ``idx``, and recorded placeholder coordinates at or below
        ``idx``, are shifted down by ``amount``.

        NOTE(review): a merged range that starts above ``idx`` but extends to or
        below it is not stretched.

        :param idx: 1-based row index at which rows are inserted.
        :param amount: number of rows to insert.
        """
        self.ws.insert_rows(idx, amount)
        base_row = idx - 1
        for j in range(amount):
            target_row = idx + j
            for col in range(1, 1 + self.ws.max_column):
                # Copy the base row's cell value into the new row.
                cell = self.ws.cell(base_row, col)
                value = cell.value
                if cell.data_type == TYPE_FORMULA:
                    letter = get_column_letter(col)
                    value = Translator(value, origin=f'{letter}{base_row}').translate_formula(f'{letter}{target_row}')
                self.ws.cell(target_row, col).value = value
        for cell_range in self.merged_cells:
            if cell_range.min_row < idx:
                continue
            cell_range.min_row += amount
            cell_range.max_row += amount
        for coords in self.field_coord.values():
            for i, (row, col) in enumerate(coords):
                if row >= idx:
                    coords[i] = (row + amount, col)

    def unmerge(self):
        """Unmerge every recorded merged range so rows can be inserted freely."""
        for cell_range in self.merged_cells:
            self.ws.unmerge_cells(str(cell_range))

    def merge(self):
        """Re-merge every recorded (possibly shifted) merged range."""
        for cell_range in self.merged_cells:
            self.ws.merge_cells(str(cell_range))

    def save(self, fp: Union[str, BytesIO]):
        """Save the workbook to a path or a writable binary buffer."""
        self.wb.save(fp)

    def add_rows(self, fields, level, row, amount=1):
        """Make room for ``amount`` list rows starting at ``row`` by inserting ``amount - 1`` rows.

        The new rows below ``row`` receive row ``row``'s values (via ``insert_rows``),
        its height and its cell styles. Rows that were pushed down keep their heights.
        Rows are added at most once per path prefix ``tuple(fields[:level])``.

        NOTE(review): two different lists at the same level share that key (e.g.
        ``()`` for every top-level list), so only the first one gets rows inserted;
        confirm this is intended. With ``amount == 0`` this calls
        ``insert_rows(row + 1, -1)``; confirm how openpyxl handles a negative amount.

        :param fields: path segments of the placeholder (func_row passes a plain string).
        :param level: depth of the list segment within ``fields``.
        :param row: row holding the placeholder; the first element is written here.
        :param amount: number of list elements.
        """
        key = tuple(fields[:level])
        if key in self.fields_added_rows:
            return
        row_height = self.ws.row_dimensions[row].height
        # Heights of every row below ``row``; they are re-applied after the rows move down.
        heights = [self.ws.row_dimensions[x].height for x in range(row + 1, self.ws.max_row + 1)]
        self.insert_rows(row + 1, amount - 1)
        for x, height in enumerate(heights):
            self.ws.row_dimensions[row + amount + x].height = height
        for i in range(row + 1, row + amount):
            self.ws.row_dimensions[i].height = row_height
        # Copy the template row's cell styles onto each inserted row.
        for j in range(1, amount):
            for i in range(1, self.ws.max_column + 1):
                letter = get_column_letter(i)
                src = self.ws[f'{letter}{row}']
                dst = self.ws[f'{letter}{row + j}']
                for attr in ('style', 'border', 'font', 'data_type', 'alignment', 'fill', 'comment'):
                    setattr(dst, attr, copy(getattr(src, attr)))
        self.fields_added_rows.add(key)

    def get_func(self, func_name):
        """Return the template function registered under ``func_name``.

        :raises ValueError: if ``func_name`` is not a key of ``template_funcs``.
        """
        if func_name not in self.template_funcs:
            raise ValueError(f'{func_name} is an invalid template function')
        return self.template_funcs[func_name]

    def parse_field(self, fields: Tuple, level: int, value: Union[BaseModel, Any], coords):
        """Resolve ``fields[level:]`` against ``value`` and write the result into ``coords``.

        How the current segment is handled depends on its form:

        * ``name(params)``: call the registered template function.
        * ``name[]``: one row per list element, with rows inserted via ``add_rows``.
          The rest of the path is resolved per element, or the element itself is
          written if the path ends here.
        * ``name``: descend into a nested BaseModel if more segments follow;
          otherwise write the value.

        :param fields: path segments produced by ``split_fields``.
        :param level: index of the segment to resolve now.
        :param value: object the segment is looked up on.
        :param coords: (row, col) cells to write into.
        """
        field = fields[level]
        if m := self._func_call_pattern.match(field):
            # Use the coordinates handed down by the caller. ``field_coord`` is keyed by
            # the full raw placeholder, so looking up ``field`` alone fails when the call
            # is not the whole path (e.g. ``children[].sum(...)``).
            func = self.get_func(m['func_name'])
            func(self, coords=coords, params=m['params'], level=level, value=value)
            return
        is_last = level + 1 >= len(fields)
        if field.endswith('[]'):
            field = field[:-2]
            # Treat a missing or None list as empty instead of failing on len(None).
            list_value = getattr(value, field, None) or []
            for row, col in coords:
                self.add_rows(fields, level, row, len(list_value))
                for i, item in enumerate(list_value):
                    if is_last:
                        # The path ends on the list itself (e.g. ``$tags[]``): write each element.
                        self.set_cell_value(row + i, col, item)
                    else:
                        self.parse_field(fields, level + 1, item, ((row + i, col),))
        else:
            value = getattr(value, field, None)
            if isinstance(value, BaseModel) and not is_last:
                self.parse_field(fields, level + 1, value, coords)
            else:
                for row, col in coords:
                    self.set_cell_value(row, col, value)

    def set_cell_value(self, row, col, value):
        """Write ``value`` into cell (``row``, ``col``).

        * A ``str`` starting with ``"data:image/png;base64"`` is decoded and anchored
          as an image at the cell. The image is scaled to fit the cell while keeping
          its aspect ratio, and the cell text is cleared.
        * Any other truthy value is written as ``str(value)``.
        * Falsy values (``None``, ``''``, ``0``, ``False``) clear the cell.
        """
        if isinstance(value, str) and value.startswith("data:image/png;base64"):
            payload = value.split(",")[-1]
            img = Image(BytesIO(b64decode(payload)))
            w, h = img.width, img.height
            # NOTE(review): rough size estimates. Column width * 7 and row height * 1.33
            # (presumably points -> pixels, 96/72), with 120 as the fallback when the row
            # has no explicit height. Confirm against how Excel renders the template.
            cell_w = self.ws.column_dimensions[get_column_letter(col)].width * 7
            cell_h = (self.ws.row_dimensions[row].height or 120) * 1.33
            # Scale by the limiting dimension so the whole image fits in the cell.
            ratio = cell_h / h if cell_w / cell_h > w / h else cell_w / w
            img.width, img.height = w * ratio, h * ratio
            self.ws.add_image(img, f'{get_column_letter(col)}{row}')
            self.ws.cell(row, col).value = ''
        elif value:
            self.ws.cell(row, col).value = str(value)
        else:
            self.ws.cell(row, col).value = ''


# __init_subclass__ only runs for subclasses; call it once so the base class itself
# also gets a compiled ``pattern``.
SheetTemplate.__init_subclass__()
class SalesQuotationBase(BaseModel):
    """Product data shown on one quotation line (sample model)."""

    # Explicit ``= None`` defaults: pydantic v1 already treats un-defaulted ``Optional``
    # fields as optional, but pydantic v2 makes them required.
    cn_name: Optional[str] = None
    # NOTE(review): SheetTemplate.set_cell_value only renders ``str`` values starting with
    # "data:image/png;base64" as images; a ``bytes`` value is written as ``str(bytes)`` text.
    img: Optional[bytes] = None
    brand: Optional[str] = None
    cat_no: Optional[str] = None
    offer_price: Optional[float] = None
class SalesQuotationDetail(BaseModel):
    """One quotation line: a product plus the quoted quantity (sample model)."""

    quotation: SalesQuotationBase
    # ``= None`` keeps this field optional under pydantic v2 as well as v1.
    quantity: Optional[float] = None
class SalesQuotation(BaseModel):
    """Sample top-level model rendered into the quotation template.

    Presumably referenced from the template as ``$code``, ``$customer_name``,
    ``$contact_name`` and ``$children[]....`` list placeholders; confirm against the
    actual template.
    """
    code: str
    customer_name: str
    contact_name: str
    # One entry per quotation line; a ``children[]`` placeholder renders one row per item.
    children: List[SalesQuotationDetail]
def gen_quotation(count: int = 9):
    """Build a sample SalesQuotation with ``count`` detail lines (demo data).

    :param count: number of detail lines; line ``i`` has name ``str(i)`` and quantity ``i``.
    :return: a populated ``SalesQuotation``.
    """
    children = [
        SalesQuotationDetail(
            quotation=SalesQuotationBase(
                # Explicit str(): pydantic v2 no longer coerces int -> str.
                cn_name=str(i),
                # Passed explicitly because un-defaulted Optional fields are required in v2.
                img=None,
                brand='cato',
                cat_no=f'A{i:0>5}',
                offer_price=1000,
            ),
            quantity=i,
        )
        for i in range(count)
    ]
    return SalesQuotation(
        code='SDF200001010005',
        customer_name='客户名称',
        contact_name='联系人名称',
        children=children,
    )
if __name__ == '__main__':
    import sys

    # Usage: python <this file> [TEMPLATE.xlsx] [OUTPUT.xlsx]
    # Without arguments, the author's original local paths are used.
    template_path = sys.argv[1] if len(sys.argv) > 1 else r"H:\Users\Pandaaaa\Downloads\Telegram Desktop\UWA-报价单模板.xlsx"
    output_path = sys.argv[2] if len(sys.argv) > 2 else r'D:\test.xlsx'
    t = SheetTemplate(template_path)
    quotation = gen_quotation()
    t.parse(quotation)
    t.save(output_path)
Add Comment
Please, Sign In to add comment