Skip to content

Add pre-commit To Prevent Usage of session.query #47275

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,14 @@ repos:
files: ^airflow-core/src/airflow/models/taskinstance\.py$|^airflow-core/src/airflow/models/taskinstancehistory\.py$
pass_filenames: false
require_serial: true
- id: prevent-usage-of-session.query
  name: Prevent usage of session.query
  entry: ./scripts/ci/pre_commit/usage_session_query.py
  language: python
  additional_dependencies: ['rich>=12.4.4']
  # Anchor both alternatives with a trailing "$" — the second one previously
  # lacked it, which made the pattern inconsistent (and would also match
  # paths like "task_sdk/foo.pyc" via an unanchored ".py" substring).
  files: ^airflow.*\.py$|^task_sdk.*\.py$
  exclude: ^task_sdk/tests/.*\.py$
  pass_filenames: true
- id: check-deferrable-default
name: Check and fix default value of default_deferrable
language: python
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -246,11 +246,9 @@ def ti_run(

xcom_keys = list(session.scalars(query))
task_reschedule_count = (
session.query(
func.count(TaskReschedule.id) # or any other primary key column
)
.filter(TaskReschedule.ti_id == ti_id_str)
.scalar()
session.execute(
select(func.count(TaskReschedule.id)).where(TaskReschedule.ti_id == ti_id_str)
).scalar()
or 0
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -400,7 +400,7 @@ def set_xcom(
if not run_id:
raise HTTPException(status.HTTP_404_NOT_FOUND, f"Run with ID: `{run_id}` was not found")

dag_run_id = session.query(DagRun.id).filter_by(dag_id=dag_id, run_id=run_id).scalar()
dag_run_id = session.scalar(DagRun.id).where(dag_id=dag_id, run_id=run_id)
if dag_run_id is None:
raise HTTPException(status.HTTP_404_NOT_FOUND, f"DAG run not found on DAG {dag_id} with ID {run_id}")

Expand Down
7 changes: 5 additions & 2 deletions airflow-core/src/airflow/dag_processing/bundles/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
from typing import TYPE_CHECKING

from sqlalchemy import delete
from sqlalchemy import select
from sqlalchemy import select, update

from airflow.configuration import conf
from airflow.exceptions import AirflowConfigException
Expand Down Expand Up @@ -124,8 +126,9 @@ def parse_config(self) -> None:
@provide_session
def sync_bundles_to_db(self, *, session: Session = NEW_SESSION) -> None:
self.log.debug("Syncing DAG bundles to the database")
stored = {b.name: b for b in session.query(DagBundleModel).all()}
for name in self._bundle_config.keys():
stored = {b.name: b for b in session.scalars(select(DagBundleModel)).all()}
active_bundle_names = set(self._bundle_config.keys())
for name in active_bundle_names:
if bundle := stored.pop(name, None):
bundle.active = True
else:
Expand Down
17 changes: 7 additions & 10 deletions airflow-core/src/airflow/jobs/scheduler_job_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -2028,34 +2028,31 @@ def _get_num_times_stuck_in_queued(self, ti: TaskInstance, session: Session = NE

We can then use this information to determine whether to reschedule a task or fail it.
"""
return (
session.query(Log)
.where(
return session.execute(
select(func.count(Log.id)).where(
Log.task_id == ti.task_id,
Log.dag_id == ti.dag_id,
Log.run_id == ti.run_id,
Log.map_index == ti.map_index,
Log.try_number == ti.try_number,
Log.event == TASK_STUCK_IN_QUEUED_RESCHEDULE_EVENT,
)
.count()
)
).scalar()

previous_ti_running_metrics: dict[tuple[str, str, str], int] = {}

@provide_session
def _emit_running_ti_metrics(self, session: Session = NEW_SESSION) -> None:
running = (
session.query(
running = session.execute(
select(
TaskInstance.dag_id,
TaskInstance.task_id,
TaskInstance.queue,
func.count(TaskInstance.task_id).label("running_count"),
)
.filter(TaskInstance.state == State.RUNNING)
.where(TaskInstance.state == State.RUNNING)
.group_by(TaskInstance.dag_id, TaskInstance.task_id, TaskInstance.queue)
.all()
)
).all()

ti_running_metrics = {(row.dag_id, row.task_id, row.queue): row.running_count for row in running}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def upgrade():
if not context.is_offline_mode():
session = get_session()
try:
for trigger in session.query(Trigger).options(lazyload(Trigger.task_instance)):
for trigger in session.scalars(select(Trigger).options(lazyload(Trigger.task_instance))):
trigger.kwargs = trigger.kwargs
session.commit()
finally:
Expand All @@ -81,7 +81,7 @@ def downgrade():
else:
session = get_session()
try:
for trigger in session.query(Trigger).options(lazyload(Trigger.task_instance)):
for trigger in session.scalars(select(Trigger).options(lazyload(Trigger.task_instance))):
trigger.encrypted_kwargs = json.dumps(BaseSerialization.serialize(trigger.kwargs))
session.commit()
finally:
Expand Down
11 changes: 7 additions & 4 deletions airflow-core/src/airflow/models/dagrun.py
Original file line number Diff line number Diff line change
Expand Up @@ -788,10 +788,13 @@ def fetch_task_instances(
def _check_last_n_dagruns_failed(self, dag_id, max_consecutive_failed_dag_runs, session):
"""Check if last N dags failed."""
dag_runs = (
session.query(DagRun)
.filter(DagRun.dag_id == dag_id)
.order_by(DagRun.logical_date.desc())
.limit(max_consecutive_failed_dag_runs)
session.execute(
select(DagRun)
.where(DagRun.dag_id == dag_id)
.order_by(DagRun.logical_date.desc())
.limit(max_consecutive_failed_dag_runs)
)
.scalars()
.all()
)
""" Marking dag as paused, if needed"""
Expand Down
8 changes: 3 additions & 5 deletions airflow-core/src/airflow/models/serialized_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -479,8 +479,8 @@ def get_latest_serialized_dags(
"""
# Subquery to get the latest serdag per dag_id
latest_serdag_subquery = (
session.query(cls.dag_id, func.max(cls.created_at).label("created_at"))
.filter(cls.dag_id.in_(dag_ids))
select(cls.dag_id, func.max(cls.created_at).label("created_at"))
.where(cls.dag_id.in_(dag_ids))
.group_by(cls.dag_id)
.subquery()
)
Expand All @@ -504,9 +504,7 @@ def read_all_dags(cls, session: Session = NEW_SESSION) -> dict[str, SerializedDA
:returns: a dict of DAGs read from database
"""
latest_serialized_dag_subquery = (
session.query(cls.dag_id, func.max(cls.created_at).label("max_created"))
.group_by(cls.dag_id)
.subquery()
select(cls.dag_id, func.max(cls.created_at).label("max_created")).group_by(cls.dag_id).subquery()
)
serialized_dags = session.scalars(
select(cls).join(
Expand Down
57 changes: 30 additions & 27 deletions airflow-core/src/airflow/models/taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,16 +280,14 @@ def clear_task_instances(
for instance in tis:
run_ids_by_dag_id[instance.dag_id].add(instance.run_id)

drs = (
session.query(DagRun)
.filter(
drs = session.scalars(
select(DagRun).where(
or_(
and_(DagRun.dag_id == dag_id, DagRun.run_id.in_(run_ids))
for dag_id, run_ids in run_ids_by_dag_id.items()
)
)
.all()
)
).all()
dag_run_state = DagRunState(dag_run_state) # Validate the state value.
for dr in drs:
if dr.state in State.finished_dr_states:
Expand Down Expand Up @@ -804,22 +802,22 @@ def get_task_instance(
session: Session = NEW_SESSION,
) -> TaskInstance | None:
query = (
session.query(TaskInstance)
.options(lazyload(TaskInstance.dag_run)) # lazy load dag run to avoid locking it
.filter_by(
dag_id=dag_id,
run_id=run_id,
task_id=task_id,
map_index=map_index,
select(TaskInstance)
.options(lazyload(TaskInstance.dag_run))
.where(
TaskInstance.dag_id == dag_id,
TaskInstance.run_id == run_id,
TaskInstance.task_id == task_id,
TaskInstance.map_index == map_index,
)
)

if lock_for_update:
for attempt in run_with_db_retries(logger=cls.logger()):
with attempt:
return query.with_for_update().one_or_none()
return session.execute(query.with_for_update()).one_or_none()
else:
return query.one_or_none()
return session.execute(query).one_or_none()

return None

Expand Down Expand Up @@ -969,13 +967,15 @@ def are_dependents_done(self, session: Session = NEW_SESSION) -> bool:
if not task.downstream_task_ids:
return True

ti = session.query(func.count(TaskInstance.task_id)).filter(
TaskInstance.dag_id == self.dag_id,
TaskInstance.task_id.in_(task.downstream_task_ids),
TaskInstance.run_id == self.run_id,
TaskInstance.state.in_((TaskInstanceState.SKIPPED, TaskInstanceState.SUCCESS)),
ti = session.execute(
select(func.count(TaskInstance.task_id)).where(
TaskInstance.dag_id == self.dag_id,
TaskInstance.task_id.in_(task.downstream_task_ids),
TaskInstance.run_id == self.run_id,
TaskInstance.state.in_((TaskInstanceState.SKIPPED, TaskInstanceState.SUCCESS)),
)
)
count = ti[0][0]
count = ti.scalar()
return count == len(task.downstream_task_ids)

@provide_session
Expand Down Expand Up @@ -1157,7 +1157,7 @@ def ready_for_retry(self) -> bool:
def _get_dagrun(dag_id, run_id, session) -> DagRun:
from airflow.models.dagrun import DagRun # Avoid circular import

dr = session.query(DagRun).filter(DagRun.dag_id == dag_id, DagRun.run_id == run_id).one()
dr = session.execute(select(DagRun).where(DagRun.dag_id == dag_id, DagRun.run_id == run_id)).one()
return dr

@provide_session
Expand Down Expand Up @@ -2209,16 +2209,19 @@ def xcom_pull(
def get_num_running_task_instances(self, session: Session, same_dagrun: bool = False) -> int:
"""Return Number of running TIs from the DB."""
# .count() is inefficient
num_running_task_instances_query = session.query(func.count()).filter(
TaskInstance.dag_id == self.dag_id,
TaskInstance.task_id == self.task_id,
TaskInstance.state == TaskInstanceState.RUNNING,

num_running_task_instances_query = select(
func.count().where(
TaskInstance.dag_id == self.dag_id,
TaskInstance.task_id == self.task_id,
TaskInstance.state == TaskInstanceState.RUNNING,
)
)
if same_dagrun:
num_running_task_instances_query = num_running_task_instances_query.filter(
num_running_task_instances_query = num_running_task_instances_query.where(
TaskInstance.run_id == self.run_id
)
return num_running_task_instances_query.scalar()
return session.execute(num_running_task_instances_query).scalar()

@staticmethod
def filter_for_tis(tis: Iterable[TaskInstance | TaskInstanceKey]) -> BooleanClauseList | None:
Expand Down
37 changes: 20 additions & 17 deletions airflow-core/src/airflow/models/xcom.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,11 +142,11 @@ def clear(
if not run_id:
raise ValueError(f"run_id must be passed. Passed run_id={run_id}")

query = session.query(cls).filter_by(dag_id=dag_id, task_id=task_id, run_id=run_id)
query = select(cls).where(dag_id=dag_id, task_id=task_id, run_id=run_id)
if map_index is not None:
query = query.filter_by(map_index=map_index)
query = query.where(map_index=map_index)

for xcom in query:
for xcom in session.scalars(query):
# print(f"Clearing XCOM {xcom} with value {xcom.value}")
session.delete(xcom)

Expand Down Expand Up @@ -186,7 +186,9 @@ def set(
if not run_id:
raise ValueError(f"run_id must be passed. Passed run_id={run_id}")

dag_run_id = session.query(DagRun.id).filter_by(dag_id=dag_id, run_id=run_id).scalar()
dag_run_id = session.execute(
select(DagRun.id).where(DagRun.dag_id == dag_id, DagRun.run_id == run_id)
).scalar()
if dag_run_id is None:
raise ValueError(f"DAG run not found on DAG {dag_id!r} with ID {run_id!r}")

Expand Down Expand Up @@ -287,42 +289,43 @@ def get_many(
if not run_id:
raise ValueError(f"run_id must be passed. Passed run_id={run_id}")

query = session.query(cls).join(XComModel.dag_run)
query = select(cls).join(XComModel.dag_run)

if key:
query = query.filter(XComModel.key == key)
query = query.where(XComModel.key == key)

if is_container(task_ids):
query = query.filter(cls.task_id.in_(task_ids))
query = query.where(cls.task_id.in_(task_ids))
elif task_ids is not None:
query = query.filter(cls.task_id == task_ids)
query = query.where(cls.task_id == task_ids)

if is_container(dag_ids):
query = query.filter(cls.dag_id.in_(dag_ids))
query = query.where(cls.dag_id.in_(dag_ids))

elif dag_ids is not None:
query = query.filter(cls.dag_id == dag_ids)
query = query.where(cls.dag_id == dag_ids)

if isinstance(map_indexes, range) and map_indexes.step == 1:
query = query.filter(cls.map_index >= map_indexes.start, cls.map_index < map_indexes.stop)
query = query.where(cls.map_index >= map_indexes.start, cls.map_index < map_indexes.stop)
elif is_container(map_indexes):
query = query.filter(cls.map_index.in_(map_indexes))
query = query.where(cls.map_index.in_(map_indexes))
elif map_indexes is not None:
query = query.filter(cls.map_index == map_indexes)
query = query.where(cls.map_index == map_indexes)

if include_prior_dates:
dr = (
session.query(
select(
func.coalesce(DagRun.logical_date, DagRun.run_after).label("logical_date_or_run_after")
)
.filter(DagRun.run_id == run_id)
.where(DagRun.run_id == run_id)
.subquery()
)

query = query.filter(
query = query.where(
func.coalesce(DagRun.logical_date, DagRun.run_after) <= dr.c.logical_date_or_run_after
)
else:
query = query.filter(cls.run_id == run_id)
query = query.where(cls.run_id == run_id)

query = query.order_by(DagRun.logical_date.desc(), cls.timestamp.desc())
if limit:
Expand Down
25 changes: 25 additions & 0 deletions airflow-core/src/airflow/utils/log/file_task_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@

import pendulum
from pydantic import BaseModel, ConfigDict, ValidationError
from sqlalchemy import select

from airflow.configuration import conf
from airflow.executors.executor_loader import ExecutorLoader
Expand Down Expand Up @@ -179,6 +180,30 @@ def _interleave_logs(*logs: str | LogMessages) -> Iterable[StructuredLogMessage]
last = msg


def _ensure_ti(ti: TaskInstanceKey | TaskInstance, session) -> TaskInstance:
    """
    Given TI | TIKey, return a TI object.

    Will raise exception if no TI is found in the database.

    :param ti: either a full TaskInstance (returned as-is) or a TaskInstanceKey
        to be resolved against the database.
    :param session: SQLAlchemy ORM session used for the lookup.
    :return: the resolved TaskInstance, with ``try_number`` taken from the key.
    :raises AirflowException: if no matching TaskInstance exists.
    """
    from airflow.models.taskinstance import TaskInstance

    if isinstance(ti, TaskInstance):
        return ti
    # Use session.scalars() so we get the TaskInstance ORM object itself;
    # session.execute() would yield a Row wrapper.  Note ``.one_or_none()``
    # must be *called* — referencing the attribute (as the original did)
    # returns a bound method, which is always truthy and then blows up on
    # the ``val.try_number`` assignment below.
    val = session.scalars(
        select(TaskInstance).where(
            TaskInstance.task_id == ti.task_id,
            TaskInstance.dag_id == ti.dag_id,
            TaskInstance.run_id == ti.run_id,
            TaskInstance.map_index == ti.map_index,
        )
    ).one_or_none()
    if val is None:
        raise AirflowException(f"Could not find TaskInstance for {ti}")
    # A TIKey carries the try_number the caller cares about; propagate it
    # onto the freshly loaded row so log-path rendering uses the right attempt.
    val.try_number = ti.try_number
    return val


class FileTaskHandler(logging.Handler):
"""
FileTaskHandler is a python log handler that handles and reads task instance logs.
Expand Down
2 changes: 2 additions & 0 deletions contributing-docs/08_static_code_checks.rst
Original file line number Diff line number Diff line change
Expand Up @@ -352,6 +352,8 @@ require Breeze Docker image to be built locally.
+-----------------------------------------------------------+--------------------------------------------------------+---------+
| pretty-format-json | Format JSON files | |
+-----------------------------------------------------------+--------------------------------------------------------+---------+
| prevent-usage-of-session.query | Prevent usage of session.query | |
+-----------------------------------------------------------+--------------------------------------------------------+---------+
| pylint | pylint | |
+-----------------------------------------------------------+--------------------------------------------------------+---------+
| python-no-log-warn | Check if there are no deprecate log warn | |
Expand Down
Loading
Loading