Source code for gpt4docstrings.generate_docstrings

import ast
import asyncio
import difflib
import os
import pathlib
import sys
from fnmatch import fnmatch
from typing import List
from typing import Union

import click
from colorama import Fore
from tabulate import tabulate
from tqdm.asyncio import tqdm_asyncio

from gpt4docstrings.ascii_title import title
from gpt4docstrings.config import GPT4DocstringsConfig
from gpt4docstrings.docstring import Docstring
from gpt4docstrings.docstrings_generators import ChatGPTDocstringGenerator
from gpt4docstrings.docstrings_translators import ChatGPTDocstringTranslator
from gpt4docstrings.utils.helpers import get_common_base
from gpt4docstrings.visit import GPT4DocstringsNode
from gpt4docstrings.visit import GPT4DocstringsVisitor


class GPT4Docstrings:
    """Generate and translate docstrings for the Python files under some paths.

    For every collected file the source is parsed, its nodes are visited with
    ``GPT4DocstringsVisitor``, docstrings are generated for classes/functions
    that are not yet covered and (optionally) existing docstrings are
    translated to the requested style. The result is either written back to
    the file or accumulated into a single unified-diff patch file, depending
    on ``config.overwrite``.
    """

    # Docstring styles accepted by the constructor.
    _DOCSTRING_STYLES = ("google", "numpy", "reStructuredText", "epytext")
    # AST node types that can carry a docstring we generate / translate.
    _DOCUMENTABLE_NODE_TYPES = ("ClassDef", "FunctionDef", "AsyncFunctionDef")
    # Patch file written (in the current working directory) when not overwriting.
    _PATCH_FILENAME = "gpt4docstring_docstring_generator_patch.diff"

    def __init__(
        self,
        paths: Union[str, List[str]],
        excluded=None,
        model: str = "gpt-3.5-turbo",
        docstring_style: str = "google",
        translate: bool = True,
        api_key: str = None,
        verbose: int = 0,
        config: GPT4DocstringsConfig = None,
    ):
        """Set up the generator / translator helpers.

        Args:
            paths: A single path or a list of paths (files or directories).
            excluded: Iterable of path prefixes/patterns to skip; each one is
                matched with ``fnmatch`` after appending ``"*"``.
            model: Model name forwarded to the generator and translator.
            docstring_style: One of "google", "numpy", "reStructuredText",
                "epytext".
            translate: Whether ``run`` also translates existing docstrings.
            api_key: API key forwarded to the generator and translator.
            verbose: When greater than 0, ``run`` prints a summary table.
            config: Run configuration; a default ``GPT4DocstringsConfig()`` is
                used when omitted.

        Raises:
            ValueError: If ``docstring_style`` is not a supported style.
        """
        self.paths = paths
        self.excluded = excluded or ()
        self.common_base = pathlib.Path("/")

        if docstring_style not in self._DOCSTRING_STYLES:
            raise ValueError(
                "Docstring Style must be one of the following: "
                '["google", "numpy", "reStructuredText", "epytext"]'
            )

        self.docstring_generator = ChatGPTDocstringGenerator(
            api_key=api_key, model_name=model, docstring_style=docstring_style
        )
        self.docstring_translator = ChatGPTDocstringTranslator(
            api_key=api_key, model_name=model, docstring_style=docstring_style
        )

        self.verbose = verbose
        self.documented_nodes = []
        # `run` reads `self.config.overwrite`, so never leave this as None.
        self.config = config if config is not None else GPT4DocstringsConfig()
        self.translate = translate
        self.patches = []

    def __print_pretty_documentation_table(self):
        """Print a table of the documented functions and classes."""
        headers = ["Filename", "Documented Functions / Classes"]
        table = list(self.documented_nodes)
        # Reset the foreground colour so later terminal output is not green.
        print(Fore.GREEN + tabulate(table, headers, tablefmt="outline") + Fore.RESET)

    def __filter_files(self, files: List[str]):
        """Filter the input files based on the excluded patterns.

        Non-``.py`` files and ``__init__.py`` files are always skipped.

        Args:
            files (List[str]): The list of file paths to filter.

        Yields:
            str: The file paths that passed the filter.
        """
        for f in files:
            if not f.endswith(".py"):
                continue
            # By default, we will ignore __init__.py files
            if os.path.basename(f) == "__init__.py":
                continue
            if any(fnmatch(f, exc + "*") for exc in self.excluded):
                continue
            yield f

    def get_filenames_from_paths(self) -> List[str]:
        """Collect the Python filenames reachable from ``self.paths``.

        Files are taken as-is (if not excluded); directories are walked
        recursively. As a side effect ``self.common_base`` is updated from the
        collected filenames.

        Returns:
            List[str]: The list of filenames.

        Raises:
            SystemExit: With status 1 if an explicit file path is not a
                ``.py`` file, or if no file was collected at all.
        """
        # A bare string would otherwise be iterated character by character.
        paths = [self.paths] if isinstance(self.paths, str) else self.paths

        filenames = []
        for path in paths:
            if path.startswith("./"):
                path = path[2:]

            if os.path.isfile(path):
                if not path.endswith(".py"):
                    click.echo(f"Error: '{path}' is not a Python file.", err=True)
                    return sys.exit(1)
                if not any(fnmatch(path, exc + "*") for exc in self.excluded):
                    filenames.append(path)
                continue

            for root, _, fs in os.walk(path):
                full_paths = [os.path.join(root, f) for f in fs]
                filenames.extend(self.__filter_files(full_paths))

        if not filenames:
            click.echo("Error: no Python files found to document.", err=True)
            return sys.exit(1)

        self.common_base = get_common_base(filenames)
        return filenames

    @staticmethod
    def _filter_nodes_generation(nodes):
        """Keep class/function nodes that are NOT covered (need a new docstring)."""
        return [
            node
            for node in nodes
            if node.node_type in GPT4Docstrings._DOCUMENTABLE_NODE_TYPES
            and not node.covered
        ]

    @staticmethod
    def _filter_nodes_translation(nodes):
        """Keep class/function nodes that ARE covered (docstring to translate)."""
        return [
            node
            for node in nodes
            if node.node_type in GPT4Docstrings._DOCUMENTABLE_NODE_TYPES
            and node.covered
        ]

    @staticmethod
    def _filter_inner_nested(nodes):
        """Drop nested classes and the nodes whose parent is a nested class."""
        nested_cls = [n for n in nodes if n.is_nested_cls]
        inner_nested_nodes = [n for n in nodes if n.parent in nested_cls]

        filtered_nodes = [n for n in nodes if n not in inner_nested_nodes]
        filtered_nodes = [n for n in filtered_nodes if n not in nested_cls]
        return filtered_nodes

    @staticmethod
    def _read_file(filename: str, read_lines: bool = False) -> Union[str, List[str]]:
        """Read a UTF-8 file, as a list of lines if ``read_lines`` else as one string."""
        with open(filename, encoding="utf-8") as file:
            return file.readlines() if read_lines else file.read()

    @staticmethod
    def _write_to_file(filename: str, content: str):
        """Overwrite ``filename`` with ``content`` using UTF-8."""
        with open(filename, "w", encoding="utf-8") as f:
            f.write(content)

    @staticmethod
    def _build_file_with_docstrings(source_file: str, docstrings: List[Docstring]):
        """Return ``source_file`` with every docstring inserted into it.

        Each docstring's text (``to_str()``) is inserted right after the
        source line at 0-based index ``docstring.lineno - 1``.

        Args:
            source_file: Full source text of the file.
            docstrings: Docstrings to insert.

        Returns:
            The new source text, with lines joined by ``"\\n"``.
        """
        docstrings_positions = {
            docstring.lineno - 1: docstring.to_str() for docstring in docstrings
        }

        lines = []
        for i, line in enumerate(source_file.split("\n")):
            lines.append(line)
            if i in docstrings_positions:
                lines.extend(docstrings_positions[i].splitlines())

        return "\n".join(lines)

    @staticmethod
    def _get_patch_lines(src: str, target: str, filename: str):
        """Return the unified diff between ``src`` and ``target`` as a list of lines.

        The diff headers use ``a/<filename>`` and ``b/<filename>``. An empty
        list is returned when both texts have the same lines.
        """
        src_lines = [line + "\n" for line in src.splitlines()]
        target_lines = [line + "\n" for line in target.splitlines()]

        return list(
            difflib.unified_diff(
                src_lines,
                target_lines,
                fromfile="a/" + filename,
                tofile="b/" + filename,
            )
        )

    def _generate_patch_file(self, src: str, target: str, filename: str):
        """Compute the diff for one file and queue it in ``self.patches``."""
        self.patches.append(self._get_patch_lines(src, target, filename))

    def _write_concatenated_patch_file(self):
        """Write all queued patches to a single patch file.

        Every patch is followed by a blank line. Nothing is written when no
        patch has been queued.
        """
        concatenated_patch = []
        for patch in self.patches:
            concatenated_patch.extend(patch)
            concatenated_patch.append("\n")

        if concatenated_patch:
            # Explicit UTF-8, consistent with _read_file / _write_to_file:
            # the diffed sources may contain non-ASCII text.
            with open(self._PATCH_FILENAME, "w", encoding="utf-8") as patch_file:
                patch_file.writelines(concatenated_patch)

    async def generate_file_docstrings(
        self, filename: str, file_content: str, nodes: List[GPT4DocstringsNode]
    ) -> str:
        """Generate docstrings for a single file.

        Only class/function nodes without a docstring are processed; nested
        classes and their children are skipped. Every processed node is
        recorded in ``self.documented_nodes``.

        Args:
            filename (str): The path of the file to generate docstrings for.
            file_content (str): The content of the file to be processed.
            nodes (List[GPT4DocstringsNode]): The nodes visited in the file.

        Returns:
            The new file content.
        """
        nodes = self._filter_inner_nested(self._filter_nodes_generation(nodes))
        if not nodes:
            # Nothing to insert: the rebuilt content would be identical.
            return file_content

        tasks = []
        for node in nodes:
            tasks.append(self.docstring_generator.generate_docstring(node))
            self.documented_nodes.append([filename, node.name])

        docstrings = await tqdm_asyncio.gather(*tasks)
        return self._build_file_with_docstrings(file_content, docstrings)

    async def translate_file_docstrings(
        self, filename: str, file_content: str, nodes: List[GPT4DocstringsNode]
    ) -> str:
        """Translate the existing docstrings of a single file.

        Only class/function nodes that already have a docstring are
        processed; nested classes and their children are skipped. Every
        processed node is recorded in ``self.documented_nodes``.

        Args:
            filename (str): The path of the file whose docstrings are translated.
            file_content (str): The content of the file to be processed.
            nodes (List[GPT4DocstringsNode]): The nodes visited in the file.

        Returns:
            The new file content.
        """
        nodes = self._filter_inner_nested(self._filter_nodes_translation(nodes))
        if not nodes:
            return file_content

        tasks = []
        for node in nodes:
            tasks.append(self.docstring_translator.translate_docstring(node))
            self.documented_nodes.append([filename, node.name])

        docstrings = await tqdm_asyncio.gather(*tasks)

        new_file_content = file_content
        for node, docstring in zip(nodes, docstrings, strict=True):
            original = node.ast_node.body[0].value.value
            if not original:
                # str.replace("", x) would insert x between every character.
                continue
            new_file_content = new_file_content.replace(
                original,
                docstring.to_str(add_triple_quotes=False),
            )
        return new_file_content

    def run(self):
        """Generate (and optionally translate) docstrings for the input paths.

        Each file is either overwritten in place or, when
        ``self.config.overwrite`` is false, diffed into one concatenated patch
        file written at the end. With ``verbose > 0`` a summary table is
        printed.
        """
        filenames = self.get_filenames_from_paths()
        click.echo(click.style(title, fg="green"))

        # asyncio.get_event_loop() raises RuntimeError on recent Python
        # versions when no loop is set; fall back to creating one.
        try:
            loop = asyncio.get_event_loop()
        except RuntimeError:
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)

        for filename in filenames:
            click.echo(f"\n\n Documenting filename {filename} ... ")
            file_content = self._read_file(filename)

            parsed_tree = ast.parse(file_content)
            # NOTE(review): the visitor gets a fresh default config rather
            # than self.config - confirm whether that is intended.
            visitor = GPT4DocstringsVisitor(
                filename=filename, config=GPT4DocstringsConfig()
            )
            visitor.visit(parsed_tree)

            new_file_content = loop.run_until_complete(
                self.generate_file_docstrings(filename, file_content, visitor.nodes)
            )

            if self.translate:
                new_file_content = loop.run_until_complete(
                    self.translate_file_docstrings(
                        filename, new_file_content, visitor.nodes
                    )
                )

            if self.config.overwrite:
                self._write_to_file(filename, new_file_content)
            else:
                self._generate_patch_file(file_content, new_file_content, filename)

        if not self.config.overwrite:
            self._write_concatenated_patch_file()

        if self.verbose > 0:
            self.__print_pretty_documentation_table()