Skip to content

Add parameters and volatility to Function #104

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 17 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 43 additions & 7 deletions docs/source/functions.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,11 @@

```python
from sqlalchemy.orm import declarative_base
from sqlalchemy_declarative_extensions import declarative_database, Function, Functions
from sqlalchemy_declarative_extensions import declarative_database, Functions

# Import dialect-specific Function for full feature support
from sqlalchemy_declarative_extensions.dialects.postgresql import Function
# from sqlalchemy_declarative_extensions.dialects.mysql import Function

_Base = declarative_base()

Expand All @@ -22,22 +26,49 @@ class Base(_Base):
""",
language="plpgsql",
returns="trigger",
),
Function(
"gimme_rows",
'''
SELECT id, name
FROM dem_rowz
WHERE group_id = _group_id;
''',
language="sql",
parameters=["_group_id int"],
returns="TABLE(id int, name text)",
volatility='stable', # PostgreSQL specific characteristic
)

# Example MySQL function
# Function(
# "gimme_concat",
# "RETURN CONCAT(label, ': ', CAST(val AS CHAR));",
# parameters=["val INT", "label VARCHAR(50)"],
# returns="VARCHAR(100)",
# deterministic=True, # MySQL specific
# data_access='NO SQL', # MySQL specific
# security='INVOKER', # MySQL specific
# ),
)
```

```{note}
Functions options are wildly different across dialects. As such, you should likely always use
the diaelect-specific `Function` object.
the dialect-specific `Function` object (e.g., `sqlalchemy_declarative_extensions.dialects.postgresql.Function`
or `sqlalchemy_declarative_extensions.dialects.mysql.Function`) to access all available features.
The base `Function` provides only the most common subset of options.
```

```{note}
Function behavior (for eaxmple...arguments) is not fully implemented at current time,
although it **should** be functional for the options it does support. Any ability to instantiate
an object which produces a syntax error should be considered a bug. Additionally, feature requests
for supporting more function options are welcome!
Function comparison logic now supports parsing and comparing function parameters (including name and type)
and various dialect-specific characteristics:

* **PostgreSQL:** `LANGUAGE`, `VOLATILITY`, `SECURITY`, `RETURNS TABLE(...)` syntax.
* **MySQL:** `DETERMINISTIC`, `DATA ACCESS`, `SECURITY`.

In particular, the current function support is heavily oriented around support for defining triggers.
The comparison logic handles normalization (e.g., mapping `integer` to `int4` in PostgreSQL) to ensure
accurate idempotency checks during Alembic autogeneration.
```

```{eval-rst}
Expand All @@ -52,3 +83,8 @@ any dialect-specific options.
.. autoapimodule:: sqlalchemy_declarative_extensions.dialects.postgresql.function
:members: Function, Procedure
```

```{eval-rst}
.. autoapimodule:: sqlalchemy_declarative_extensions.dialects.mysql.function
:members: Function
```
34 changes: 32 additions & 2 deletions src/sqlalchemy_declarative_extensions/dialects/mysql/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,12 +42,17 @@ def from_unknown_function(cls, f: base.Function) -> Self:
language=f.language,
schema=f.schema,
returns=f.returns,
parameters=f.parameters,
)

def to_sql_create(self) -> list[str]:
components = ["CREATE FUNCTION"]

components.append(self.qualified_name + "()")
parameter_str = ""
if self.parameters:
parameter_str = ", ".join(self.parameters)

components.append(f"{self.qualified_name}({parameter_str})")
components.append(f"RETURNS {self.returns}")

if self.deterministic:
Expand Down Expand Up @@ -85,9 +90,34 @@ def modifies_sql(self):

def normalize(self) -> Function:
definition = textwrap.dedent(self.definition).strip()

# Remove optional trailing semicolon for comparison robustness
if definition.endswith(";"):
definition = definition[:-1]

returns = self.returns.lower()
normalized_returns = type_map.get(returns, returns)

normalized_parameters = None
if self.parameters:
normalized_parameters = []
for param in self.parameters:
# Naive split, assumes 'name type' format
parts = param.split(maxsplit=1)
if len(parts) == 2:
name, type_str = parts
norm_type = type_map.get(type_str.lower(), type_str.lower())
normalized_parameters.append(f"{name} {norm_type}")
else:
normalized_parameters.append(
param
) # Keep as is if format unexpected

return replace(
self, definition=definition, returns=type_map.get(returns, returns)
self,
definition=definition,
returns=normalized_returns,
parameters=normalized_parameters,
)


Expand Down
5 changes: 5 additions & 0 deletions src/sqlalchemy_declarative_extensions/dialects/mysql/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,10 +91,15 @@ def get_functions_mysql(connection: Connection) -> Sequence[BaseFunction]:

functions = []
for f in connection.execute(functions_query, {"schema": database}).fetchall():
parameters = None
if f.parameters: # Parameter string might be None if no parameters
parameters = [p.strip() for p in f.parameters.split(",")]

functions.append(
Function(
name=f.name,
definition=f.definition,
parameters=parameters,
security=(
FunctionSecurity.definer
if f.security == "DEFINER"
Expand Down
25 changes: 25 additions & 0 deletions src/sqlalchemy_declarative_extensions/dialects/mysql/schema.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from sqlalchemy import bindparam, column, table
from sqlalchemy.sql import func, text

from sqlalchemy_declarative_extensions.sqlalchemy import select

Expand Down Expand Up @@ -81,6 +82,23 @@
.where(routine_table.c.routine_type == "PROCEDURE")
)

# Need to query PARAMETERS separately to reconstruct the parameter list
parameters_subquery = (
select(
column("SPECIFIC_NAME").label("routine_name"),
func.group_concat(
text(
"concat(PARAMETER_NAME, ' ', DTD_IDENTIFIER) ORDER BY ORDINAL_POSITION SEPARATOR ', '"
),
).label("parameters"),
)
.select_from(table("PARAMETERS", schema="INFORMATION_SCHEMA"))
.where(column("SPECIFIC_SCHEMA") == bindparam("schema"))
.where(column("ROUTINE_TYPE") == "FUNCTION")
.group_by(column("SPECIFIC_NAME"))
.alias("parameters_sq")
)

functions_query = (
select(
routine_table.c.routine_name.label("name"),
Expand All @@ -89,6 +107,13 @@
routine_table.c.dtd_identifier.label("return_type"),
routine_table.c.is_deterministic.label("deterministic"),
routine_table.c.sql_data_access.label("data_access"),
parameters_subquery.c.parameters.label("parameters"),
)
.select_from( # Join routines with the parameter subquery
routine_table.outerjoin(
parameters_subquery,
routine_table.c.routine_name == parameters_subquery.c.routine_name,
)
)
.where(routine_table.c.routine_schema == bindparam("schema"))
.where(routine_table.c.routine_type == "FUNCTION")
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
from sqlalchemy_declarative_extensions.dialects.postgresql.function import (
Function,
FunctionParam,
FunctionReturn,
FunctionSecurity,
FunctionVolatility,
)
from sqlalchemy_declarative_extensions.dialects.postgresql.grant import (
DefaultGrant,
Expand Down Expand Up @@ -42,7 +45,10 @@
"DefaultGrantTypes",
"Function",
"FunctionGrants",
"FunctionParam",
"FunctionReturn",
"FunctionSecurity",
"FunctionVolatility",
"Grant",
"Grant",
"GrantStatement",
Expand Down
Loading