import json
from datetime import datetime
from io import BytesIO
from typing import Optional

import numpy as np
import pandas as pd
import requests
from django.db.models import Q
from django.utils.translation import gettext as _
from sqlalchemy import bindparam, text
from sqlalchemy.orm import Session

from config.settings import EMAIL_FROM, FRONTEND_SCENARIO_RESULT_URL, FORECAST_LEVELS_SEPARATOR
from forecasts.clickhouse.db import get_session
from forecasts.clickhouse.models import ForecastResultData, ForecastResultDataCorrected
from forecasts.exceptions.classes import BlankDatesInPanelData, InvalidPartitionInPanelData, \
    NoActualShipmentsForMLPayloadCreation, MLServerUnavailable, MissingMlSettings, KeyCountExceededInMLBatch
from forecasts.models import ForecastScenario, ForecastResult, MasterDataElement, ActualShipment, MasterDataLevel, \
    MasterDataHierarchy, MasterDataInnovation, MasterDataElementStatus
from forecasts.services.ml_interaction.timeline_extender import TimelineExtender
from forecasts.services.scenario_result.aggregated_scenario_data import AggregatedScenarioData
from general.tasks import task_send_email

class MLServerDataSender:
    default_dimensions = ['key', 'date', 'partition']
    partition_forecast = 'fc'
    partition_train = 'train'
    min_train_rows = 3
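    # The 'partition' column marks each panel row for the ML server: 'train'
    # rows carry known actuals, 'fc' rows form the horizon to be forecast.
    # Every key must contribute at least `min_train_rows` training rows
    # (enforced in _check_partitions).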
    def __init__(
            self,
            aggregated_scenario_data: AggregatedScenarioData,
            timeline_extender: TimelineExtender,
            session: Optional[Session] = None
    ):
        self._ml_settings = aggregated_scenario_data.scenario.project.forecast_general_settings.ml_settings
        # A `get_session()` call in the signature default would be evaluated
        # once at import time and shared between instances; resolve per call.
        self._session = session or get_session()
        self._scenario_data = aggregated_scenario_data
        self._timeline_extender = timeline_extender
        project = self._scenario_data.scenario.project
        # Only honour the cannibalization level when it belongs to the product hierarchy.
        cannibalization_level = project.general_settings.cannibalization_level
        is_product_level = bool(
            cannibalization_level
            and MasterDataLevel.objects.filter(
                id=cannibalization_level.id,
                hierarchy__name=MasterDataHierarchy.PRODUCT_HIERARCHY
            ).exists()
        )
        self._cannibalization_level = cannibalization_level if is_product_level else None
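    # High-level flow: collect the panel data, send it to the ML server in
    # batches bounded by max_calculation_key_count, replace any previous
    # result rows in ClickHouse, and notify project users by email.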
    def execute(self):
        max_key_count = self._scenario_data.scenario.project.forecast_general_settings.max_calculation_key_count
        headers = {
            'api_key': self._ml_settings.api_key,
            'feature_types': json.dumps({
                'can_level': self._cannibalization_level.name if self._cannibalization_level else None
            }),
        }
        scenario_id = self._scenario_data.scenario.id
        scenario = ForecastScenario.objects.filter(id=scenario_id).first()
        scenario.last_process_error = None
        scenario.last_process_start = datetime.now()
        try:
            data = self._collect_data()
        except Exception as ex:
            scenario.last_process_error = str(ex)
            scenario.is_in_process = False
            scenario.save()
            raise
        scenario.save()
        all_records = []
        result_model = self._create_forecast_result(scenario_id)
        try:
            if self._cannibalization_level:
                # One ML call per cannibalization group; each group must fit
                # into the configured key budget. The loop variable is named
                # `_level` so it does not shadow the gettext alias `_`.
                grouped_by_level = data.groupby([self._cannibalization_level.name])
                for _level, group in grouped_by_level:
                    group_key_count = group['key'].nunique()
                    if group_key_count > max_key_count:
                        raise KeyCountExceededInMLBatch(
                            group_key_count, max_key_count,
                            language_code=self._scenario_data.scenario.project.language.code
                        )
                    all_records.extend(self._call_ml(group, headers, result_model))
            else:
                # Without cannibalization, batch the keys in chunks of max_key_count.
                grouped_by_key = data.groupby(['key'])
                group_names = list(grouped_by_key.groups.keys())
                for i in range(0, len(group_names), max_key_count):
                    sliced = pd.concat(
                        grouped_by_key.get_group(name) for name in group_names[i:i + max_key_count]
                    )
                    all_records.extend(self._call_ml(sliced, headers, result_model))
        except Exception as ex:
            scenario.last_process_error = str(ex)
            scenario.is_in_process = False
            scenario.save()
            result_model.delete()
            raise
        self._session.query(ForecastResultData) \
            .filter(ForecastResultData.forecast_result_id == result_model.id) \
            .delete()
        self._session.query(ForecastResultDataCorrected) \
            .filter(ForecastResultDataCorrected.forecast_result_id == result_model.id) \
            .delete()
        cleaned_records = self._remove_actuals_from_forecast_results(all_records)
        if cleaned_records:
            self._session.execute(
                ForecastResultData.__table__.insert(),
                cleaned_records
            )
        now = datetime.now()
        scenario.last_process_end = now
        scenario.processed = now
        scenario.is_in_process = False
        scenario.is_proposed = False
        scenario.save()
        self._send_email_on_forecast_processing_finish(scenario)
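    # The ML server speaks parquet over HTTP: the panel slice goes out as
    # gzip-compressed parquet in the request body, and the response body is
    # parquet whose forecast columns (fc/uplift/bl/cn) are renamed for storage.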
    def _call_ml(self, df: pd.DataFrame, headers: dict, result: ForecastResult) -> list[dict]:
        bytes_io = BytesIO()
        df.to_parquet(bytes_io, index=False, compression='gzip')
        payload = bytes_io.getvalue()
        try:
            response = self._make_request(payload, headers)
        except Exception:
            language_code = (
                self._scenario_data.scenario.project.language.code
                if self._scenario_data.scenario.project.language else None
            )
            raise MLServerUnavailable(language_code=language_code)
        result_df = pd.read_parquet(BytesIO(response.content))
        result_df['date'] = pd.to_datetime(result_df['date'], format='%Y-%m-%d')
        result_df['forecast_result_id'] = result.id
        result_df.rename(
            columns={'fc': 'value', 'uplift': 'promo', 'bl': 'baseline', 'cn': 'cannibalization'},
            inplace=True
        )
        return result_df.to_dict('records')
    def _create_forecast_result(self, scenario_id: int) -> ForecastResult:
        return ForecastResult.objects.create(
            scenario_id=scenario_id,
            cannibalization_level=self._cannibalization_level
        )
    def _collect_data(self) -> pd.DataFrame:
        if not all((self._ml_settings.host, self._ml_settings.endpoint, self._ml_settings.api_key)):
            raise MissingMlSettings(language_code=self._scenario_data.scenario.project.language.code)
        actual_shipments = self._get_actual_shipments_df(self._scenario_data.scenario.project_id)
        if actual_shipments.empty:
            raise NoActualShipmentsForMLPayloadCreation(
                language_code=self._scenario_data.scenario.project.language.code
            )
        if self._cannibalization_level:
            actual_shipments = self._apply_cannibalization(actual_shipments)
        resulting_df = self._update_data_based_on_forecast_horizon(actual_shipments)
        resulting_df.rename(columns={'actual': 'target'}, inplace=True)
        # Rows without a target are future periods to forecast; the rest are training data.
        resulting_df['partition'] = np.where(
            resulting_df['target'].isnull(), self.partition_forecast, self.partition_train
        )
        resulting_df['target'] = resulting_df['target'].fillna(0)
        resulting_df['date'] = pd.to_datetime(resulting_df['date'], format='%Y-%m-%d')
        resulting_df.sort_values(['key', 'date'], inplace=True)
        granularity_errors = self._check_granularity(resulting_df)
        if granularity_errors:
            raise BlankDatesInPanelData(
                ', '.join(str(error) for error in granularity_errors[:3]),
                len(granularity_errors),
                language_code=self._scenario_data.scenario.project.language.code
            )
        partition_errors = self._check_partitions(resulting_df)
        if partition_errors:
            raise InvalidPartitionInPanelData(
                ', '.join(str(error) for error in partition_errors[:3]),
                len(partition_errors),
                language_code=self._scenario_data.scenario.project.language.code
            )
        for dimension in self.default_dimensions:
            resulting_df[dimension] = pd.Categorical(resulting_df[dimension])
        without_innovations = self._remove_innovations_from_df(resulting_df)
        print(f"""
        Method _collect_data:
        actual_shipments:
        {actual_shipments}
        resulting_df:
        {resulting_df}
        without_innovations:
        {without_innovations}
        ---------- End _collect_data ---------
        """)
        return without_innovations
    def _update_data_based_on_forecast_horizon(self, actual_shipments: pd.DataFrame) -> pd.DataFrame:
        if self._scenario_data.df.empty:
            return self._timeline_extender.execute(
                actual_shipments,
                self._cannibalization_level.name if self._cannibalization_level else None
            )
        return self._remove_excess_rows(actual_shipments)
    def _make_request(self, payload: bytes, headers: dict) -> requests.Response:
        response = requests.post(
            url=f'{self._ml_settings.host}/{self._ml_settings.endpoint}',
            data=payload,
            headers=headers,
            timeout=self._ml_settings.timeout
        )
        response.raise_for_status()
        return response
    def _check_granularity(self, df: pd.DataFrame) -> list:
        # iterrows() yields label-indexed Series, so columns are accessed by
        # name; positional integer indexing on these rows is deprecated.
        iterator = df.iterrows()
        _first_index, first_row = next(iterator)
        previous_date = first_row['date']
        previous_key = first_row['key']
        granularity_validator = self._scenario_data.scenario.project.forecast_general_settings.\
            granularity.validators.adjacent_dates
        errors = []
        for _index, row in iterator:
            is_same_key = row['key'] == previous_key
            is_valid = granularity_validator(previous_date, row['date'])
            if is_same_key and not is_valid:
                errors.append(
                    f'{row["key"]} '
                    f'({previous_date.strftime("%Y-%m-%d")} - {row["date"].strftime("%Y-%m-%d")})'
                )
            previous_date = row['date']
            previous_key = row['key']
        return errors
    def _check_partitions(self, df: pd.DataFrame) -> list:
        keys_with_errors = []
        grouped = df.groupby(['key'])
        for _name, group in grouped:
            has_forecast_rows = self.partition_forecast in set(group['partition'])
            train_row_count = len(group[group['partition'] == self.partition_train])
            if not has_forecast_rows or train_row_count < self.min_train_rows:
                keys_with_errors.append(group.iloc[0]['key'])
        return keys_with_errors
    def _remove_excess_rows(self, actuals_df: pd.DataFrame) -> pd.DataFrame:
        forecast_horizon = self._scenario_data.scenario.project.forecast_general_settings.forecast_horizon
        self._scenario_data.df.sort_values(['key', 'date'], inplace=True)
        grouped = self._scenario_data.df.groupby('key')

        def _remove_from_tail_for_group(group):
            # Keep at most `forecast_horizon` rows per key (groups are sorted by date).
            return group.head(forecast_horizon)

        with_cut_timeline = grouped.apply(_remove_from_tail_for_group).reset_index(drop=True)
        concat = pd.concat([with_cut_timeline, actuals_df], ignore_index=True)
        # Actuals win over scenario rows for overlapping (key, date) pairs.
        concat.drop_duplicates(subset=['key', 'date'], keep='last', inplace=True)
        concat.drop(labels=list(self._scenario_data.fields.values()), axis=1, inplace=True)
        return concat
    def _remove_actuals_from_forecast_results(self, results: list[dict]) -> list[dict]:
        # Drop forecast rows that overlap historical (key, date) pairs, so the
        # stored result contains only genuinely forecasted periods.
        actual_shipments = self._get_actual_shipments_df(self._scenario_data.scenario.project_id)
        actual_pairs = {
            (key, date.strftime('%Y-%m-%d'))
            for key, date in zip(actual_shipments['key'], actual_shipments['date'])
        }
        return [
            result for result in results
            if (result['key'], result['date'].strftime('%Y-%m-%d')) not in actual_pairs
        ]
    def _send_email_on_forecast_processing_finish(self, scenario):
        scenario_in_db = ForecastScenario.objects.filter(id=scenario.id).first()
        if scenario_in_db:
            frontend_url = f': {FRONTEND_SCENARIO_RESULT_URL.format(scenario.id)}' if FRONTEND_SCENARIO_RESULT_URL else ''
            task_send_email.delay({
                'subject': _('Scenario "{}" has been calculated').format(scenario.project.title),
                'recipients': list(scenario.project.users.values_list('email', flat=True)),
                'sender': EMAIL_FROM,
                'body': _('Forecast scenario "{}" calculation was completed{}').format(
                    scenario.project.title, frontend_url)
            })
    def _apply_cannibalization(self, actual_shipments: pd.DataFrame) -> pd.DataFrame:
        order_by_key = self._scenario_data.scenario.project.product_hierarchy.order_by_key
        # Extract the product-level component from the composite forecast key.
        actual_shipments['product_key'] = (
            actual_shipments['key'].str.split(FORECAST_LEVELS_SEPARATOR).str[order_by_key]
        )
        system_keys = set(actual_shipments['product_key'])
        fc_levels = MasterDataElement.get_forecast_level_elements_by_system_keys_with_cannibalization_level(
            system_keys,
            self._scenario_data.scenario.project_id,
            self._cannibalization_level.id
        )
        df = pd.DataFrame(list(fc_levels), columns=['product_key', self._cannibalization_level.name])
        merged_df = actual_shipments.merge(df, on=['product_key'], how='left')
        merged_df.drop(labels=['product_key'], axis=1, inplace=True)
        return merged_df
    def _get_actual_shipments_df(self, project_id: int) -> pd.DataFrame:
        actual_shipments_ids = list(ActualShipment.objects.filter(
            project_id=project_id,
            deleted__isnull=True).values_list('id', flat=True))
        # An expanding bind parameter is required so the id list is rendered
        # as a proper IN (...) clause.
        query = text(
            """
            SELECT
                key,
                date,
                value,
                created
            FROM actual_shipment_data
            WHERE actual_shipment_id IN :actual_shipments_ids
            """
        ).bindparams(bindparam('actual_shipments_ids', expanding=True))
        actual_shipments_data = self._session.execute(
            query, {'actual_shipments_ids': actual_shipments_ids}
        ).all()
        actual_shipments = pd.DataFrame.from_records(
            actual_shipments_data,
            columns=['key', 'date', 'value', 'created']
        )
        actual_shipments['actual'] = actual_shipments['value'].replace(r'^\s*$', np.nan, regex=True).astype(float)
        # Keep the most recently created value per (key, date) pair.
        actual_shipments.sort_values(['date', 'key', 'created'], inplace=True)
        actual_shipments.drop_duplicates(subset=['key', 'date'], keep='last', inplace=True)
        actual_shipments.drop(labels=['value', 'created'], axis=1, inplace=True)
        print(f"""
        Method _get_actual_shipments_df:
        {actual_shipments}
        ---------- End _get_actual_shipments_df ---------
        """)
        return actual_shipments
    def _processing_transition_elements(self, df: pd.DataFrame) -> pd.DataFrame:
        order_by_key = self._scenario_data.scenario.project.product_hierarchy.order_by_key
        df['product_key'] = df['key'].str.split(FORECAST_LEVELS_SEPARATOR).str[order_by_key]
        system_keys = set(df['product_key'])
        elements_with_transition_status = MasterDataElement.objects.filter(
            Q(info__manual_status_id=MasterDataElementStatus.Options.RELAUNCH.value) |
            Q(info__calculated_status_id=MasterDataElementStatus.Options.RELAUNCH.value)
        ).filter(system_key__in=system_keys).values_list('system_key', flat=True)
        filtered_df = df[~df['product_key'].isin(set(elements_with_transition_status))]
        return filtered_df
    def _remove_innovations_from_df(self, df: pd.DataFrame) -> pd.DataFrame:
        order_by_key = self._scenario_data.scenario.project.product_hierarchy.order_by_key
        df['product_key'] = df['key'].str.split(FORECAST_LEVELS_SEPARATOR).str[order_by_key]
        system_keys = set(df['product_key'])
        elements_with_innovation_status = MasterDataElement.objects.filter(
            Q(info__manual_status_id=MasterDataElementStatus.Options.INNOVATION.value) |
            Q(info__calculated_status_id=MasterDataElementStatus.Options.INNOVATION.value)
        ).filter(system_key__in=system_keys).exclude(
            Q(innovation_as_dummy__is_archived=True) | Q(innovation_as_real__is_archived=True)
        ).values_list('system_key', flat=True)
        filtered_df = df[~df['product_key'].isin(set(elements_with_innovation_status))]
        print(f"""
        Method _remove_innovations_from_df:
        system_keys:
        {system_keys}
        elements_with_innovation_status:
        {elements_with_innovation_status}
        filtered_df:
        {filtered_df}
        ---------- End _remove_innovations_from_df ---------
        """)
        return filtered_df
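
# A minimal usage sketch (hypothetical wiring): how AggregatedScenarioData and
# TimelineExtender are constructed depends on the surrounding project, so the
# constructor arguments below are assumptions, not the project's actual API.
#
#     scenario = ForecastScenario.objects.get(id=scenario_id)
#     sender = MLServerDataSender(
#         aggregated_scenario_data=AggregatedScenarioData(scenario),
#         timeline_extender=TimelineExtender(),
#     )
#     sender.execute()  # raises MLServerUnavailable, MissingMlSettings, etc. on failure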