Advertisement
elliottchong

Strict GPT

Jul 23rd, 2023 (edited)
2,355
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 5.33 KB | None | 0 0
  import { Configuration, OpenAIApi } from "openai";

  // Module-level OpenAI client, created once when the module is loaded.
  // The API key is read from the OPENAI_API_KEY environment variable.
  const configuration = new Configuration({
    apiKey: process.env.OPENAI_API_KEY,
  });
  const openai = new OpenAIApi(configuration);
  7.  
  /**
   * Describes the JSON shape requested from the model.
   *
   * Each key maps to one of:
   * - a string describing the value to generate,
   * - an array of strings, which the validator treats as a list of allowed choices,
   * - a nested format object.
   *
   * Keys containing `<...>` are treated as dynamic and are skipped during validation.
   */
  interface OutputFormat {
    [field: string]: string | string[] | OutputFormat;
  }
  11.  
  /**
   * Asks an OpenAI chat model for a JSON reply shaped like `output_format`,
   * validates the reply, and retries with the failure appended to the system
   * prompt when validation fails.
   *
   * The system message is `system_prompt` followed by generated formatting
   * instructions (plus the previous attempt's result and error, if any). The
   * user message is `user_prompt.toString()`.
   *
   * Validation applied to every returned object:
   * - each key of `output_format` that does not contain `<...>` must be present;
   * - when the format value for a key is an array of choices, a list reply is
   *   reduced to its first element, a value outside the choices is replaced by
   *   `default_category` (when non-empty), and a `label: description` value is
   *   cut down to the text before the first `:`.
   *
   * @param system_prompt - Instructions placed at the start of the system message.
   * @param user_prompt - One prompt, or an array of prompts. With an array the
   *   model's reply must parse to a JSON array.
   * @param output_format - Required shape of each returned object.
   * @param default_category - Replacement for a choice value that is not among
   *   the allowed choices; an empty string disables the replacement.
   * @param output_value_only - When true, each object is replaced by its values
   *   (a single value is returned unwrapped).
   * @param model - Chat model name passed to the API.
   * @param temperature - Sampling temperature passed to the API.
   * @param num_tries - Maximum number of attempts.
   * @param verbose - When true, logs the prompts and the raw model reply.
   * @returns The validated object, or the validated array for array input.
   *   Returns `[]` when no attempt produced a valid reply.
   * @throws Whatever `openai.createChatCompletion` rejects with; the request is
   *   made outside the validation `try`, so request failures are not retried.
   */
  export async function strict_output(
    system_prompt: string,
    user_prompt: string | string[],
    output_format: OutputFormat,
    default_category: string = "",
    output_value_only: boolean = false,
    model: string = "gpt-3.5-turbo",
    temperature: number = 1,
    num_tries: number = 3,
    verbose: boolean = false
  ) {
    // An array prompt means the reply must be a JSON array as well.
    const list_input = Array.isArray(user_prompt);
    const format_json = JSON.stringify(output_format);
    // `<...>` anywhere in the format marks content or keys the model must invent.
    const dynamic_elements = /<.*?>/.test(format_json);
    // `[...]` anywhere in the format marks a list of choices to classify into.
    const list_output = /\[.*?\]/.test(format_json);

    // The formatting instructions do not depend on the attempt, so build them once.
    // A ternary is required here: `${flag && "text"}` would render the word
    // "false" into the prompt whenever the flag is false.
    let output_format_prompt = `\nYou are to output${
      list_output ? " an array of objects in" : ""
    } the following in json format: ${format_json}. \nDo not put quotation marks or escape character \\ in the output fields.`;

    if (list_output) {
      output_format_prompt += `\nIf output field is a list, classify output into the best element of the list.`;
    }

    if (dynamic_elements) {
      output_format_prompt += `\nAny text enclosed by < and > indicates you must generate content to replace it. Example input: Go to <location>, Example output: Go to the garden\nAny output key containing < and > indicates you must generate the key name to replace it. Example input: {'<location>': 'description of location'}, Example output: {school: a place for education}`;
    }

    if (list_input) {
      output_format_prompt += `\nGenerate an array of json, one json for each input element.`;
    }

    // Feedback from the previous failed attempt; empty on the first attempt.
    let error_msg = "";

    for (let attempt = 0; attempt < num_tries; attempt++) {
      const system_content = system_prompt + output_format_prompt + error_msg;

      const response = await openai.createChatCompletion({
        temperature: temperature,
        model: model,
        messages: [
          { role: "system", content: system_content },
          { role: "user", content: user_prompt.toString() },
        ],
      });

      // A missing choice or message is treated as an empty reply, which fails
      // parsing below and triggers a retry.
      const raw: string = response.data.choices[0]?.message?.content ?? "";
      // Fallback for replies that use single quotes as JSON delimiters: turn
      // every ' into ", then restore apostrophes sitting between word characters.
      const normalized = raw
        .replace(/'/g, '"')
        .replace(/(\w)"(\w)/g, "$1'$2");
      // Text reported back to the model if this attempt fails.
      let res = raw;

      if (verbose) {
        console.log("System prompt:", system_content);
        console.log("\nUser prompt:", user_prompt);
        console.log("\nGPT response:", raw);
      }

      try {
        let output: any;
        try {
          // Prefer the reply exactly as sent: rewriting quotes would corrupt
          // valid JSON whose strings contain apostrophes.
          output = JSON.parse(raw);
        } catch {
          res = normalized;
          output = JSON.parse(normalized);
        }

        if (list_input) {
          if (!Array.isArray(output)) {
            throw new Error("Output format not in an array of json");
          }
        } else {
          // Wrap a single reply so both cases share the validation loop.
          output = [output];
        }

        for (let index = 0; index < output.length; index++) {
          const item = output[index];
          if (item === null || typeof item !== "object") {
            throw new Error("Output element is not a json object");
          }

          for (const key in output_format) {
            // Dynamic keys are generated by the model, so they cannot be checked.
            if (/<.*?>/.test(key)) {
              continue;
            }

            if (!(key in item)) {
              throw new Error(`${key} not in json output`);
            }

            const choices = output_format[key];
            if (!Array.isArray(choices)) {
              continue;
            }

            let value = item[key];
            // The model sometimes answers with a list; keep its first element.
            if (Array.isArray(value)) {
              value = value[0];
            }
            // Fall back to the default category (if any) for unknown choices.
            if (!choices.includes(value) && default_category) {
              value = default_category;
            }
            if (typeof value !== "string") {
              throw new Error(`${key} is not a string in json output`);
            }
            // Choices may be written as "label: description"; keep only the label.
            if (value.includes(":")) {
              value = value.split(":")[0];
            }
            item[key] = value;
          }

          if (output_value_only) {
            const values = Object.values(item);
            // A single value is returned without the surrounding list.
            output[index] = values.length === 1 ? values[0] : values;
          }
        }

        return list_input ? output : output[0];
      } catch (e) {
        // Feed the failed result and the reason back to the model on the next attempt.
        error_msg = `\n\nResult: ${res}\n\nError message: ${e}`;
        console.log("An exception occurred:", e);
        console.log("Current invalid json format ", res);
      }
    }

    return [];
  }
  145.  
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement