Skip to content
Prev Previous commit
Next Next commit
fix: address third round of review feedback
- Coerce None values from provides.commands/scripts via `or []` and
  validate they are lists, preventing TypeError on null YAML values
- Discover both .sh and .ps1 scripts in ExtensionResolver and
  PresetResolver.list_available() instead of only .sh
- Remove unused ExtensionRegistry import (ruff F401)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
  • Loading branch information
iamaeroplane and claude committed Mar 26, 2026
commit 6901e64b6d861224e9729b015b09c158544ab7cd
72 changes: 39 additions & 33 deletions src/specify_cli/extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,15 +140,19 @@ def _validate(self):

# Validate provides section
provides = self.data["provides"]
has_commands = "commands" in provides and provides["commands"]
has_scripts = "scripts" in provides and provides["scripts"]
if not has_commands and not has_scripts:
commands = provides.get("commands") or []
scripts = provides.get("scripts") or []
if not isinstance(commands, list):
raise ValidationError("provides.commands must be a list")
if not isinstance(scripts, list):
raise ValidationError("provides.scripts must be a list")
Comment thread
mbachorik marked this conversation as resolved.
if not commands and not scripts:
raise ValidationError(
"Extension must provide at least one command or script"
)

# Validate commands
for cmd in provides.get("commands", []):
for cmd in commands:
if "name" not in cmd or "file" not in cmd:
raise ValidationError("Command missing 'name' or 'file'")

Expand All @@ -160,7 +164,7 @@ def _validate(self):
)
Comment thread
mbachorik marked this conversation as resolved.

# Validate scripts
for script in provides.get("scripts", []):
for script in scripts:
if "name" not in script or "file" not in script:
raise ValidationError("Script missing 'name' or 'file'")
Comment thread
mbachorik marked this conversation as resolved.

Expand Down Expand Up @@ -214,12 +218,12 @@ def requires_speckit_version(self) -> str:
@property
def commands(self) -> List[Dict[str, Any]]:
"""Get list of provided commands."""
return self.data["provides"].get("commands", [])
return self.data["provides"].get("commands") or []

@property
def scripts(self) -> List[Dict[str, Any]]:
"""Get list of provided scripts."""
return self.data["provides"].get("scripts", [])
return self.data["provides"].get("scripts") or []

@property
def hooks(self) -> Dict[str, Any]:
Expand Down Expand Up @@ -920,19 +924,18 @@ def resolve(
Returns:
Path to the resolved template file, or None if not found
"""
subdirs, ext = self._type_config(template_type)
subdirs, exts = self._type_config(template_type)

for _priority, ext_id, _metadata in self.get_all_by_priority():
ext_dir = self.extensions_dir / ext_id
if not ext_dir.is_dir():
continue
for subdir in subdirs:
if subdir:
candidate = ext_dir / subdir / f"{template_name}{ext}"
else:
candidate = ext_dir / f"{template_name}{ext}"
if candidate.exists():
return candidate
base = ext_dir / subdir if subdir else ext_dir
for file_ext in exts:
candidate = base / f"{template_name}{file_ext}"
if candidate.exists():
return candidate

return None

Expand All @@ -950,24 +953,23 @@ def resolve_with_source(
Returns:
Dictionary with 'path' and 'source' keys, or None if not found
"""
subdirs, ext = self._type_config(template_type)
subdirs, exts = self._type_config(template_type)

for _priority, ext_id, ext_meta in self.get_all_by_priority():
ext_dir = self.extensions_dir / ext_id
if not ext_dir.is_dir():
continue
for subdir in subdirs:
if subdir:
candidate = ext_dir / subdir / f"{template_name}{ext}"
else:
candidate = ext_dir / f"{template_name}{ext}"
if candidate.exists():
if ext_meta:
version = ext_meta.get("version", "?")
source = f"extension:{ext_id} v{version}"
else:
source = f"extension:{ext_id} (unregistered)"
return {"path": str(candidate), "source": source}
base = ext_dir / subdir if subdir else ext_dir
for file_ext in exts:
candidate = base / f"{template_name}{file_ext}"
if candidate.exists():
if ext_meta:
version = ext_meta.get("version", "?")
source = f"extension:{ext_id} v{version}"
else:
source = f"extension:{ext_id} (unregistered)"
return {"path": str(candidate), "source": source}

return None

Expand All @@ -985,7 +987,7 @@ def list_templates(
Returns:
List of dicts with 'name', 'path', and 'source' keys.
"""
subdirs, ext = self._type_config(template_type)
subdirs, exts = self._type_config(template_type)
results: List[Dict[str, str]] = []
seen: set[str] = set()

Expand All @@ -1005,7 +1007,7 @@ def list_templates(
if not scan_dir.is_dir():
continue
for f in sorted(scan_dir.iterdir()):
if f.is_file() and f.suffix == ext:
if f.is_file() and f.suffix in exts:
name = f.stem
if name not in seen:
seen.add(name)
Expand All @@ -1019,14 +1021,18 @@ def list_templates(

@staticmethod
def _type_config(template_type: str) -> tuple:
"""Return (subdirs, file_extension) for a template type."""
"""Return (subdirs, file_extensions) for a template type.

Returns:
Tuple of (subdirs list, list of file extensions to match).
"""
if template_type == "template":
return ["templates", ""], ".md"
return ["templates", ""], [".md"]
elif template_type == "command":
return ["commands"], ".md"
return ["commands"], [".md"]
elif template_type == "script":
Comment thread
mbachorik marked this conversation as resolved.
return ["scripts"], ".sh"
return [""], ".md"
return ["scripts"], [".sh", ".ps1"]
return [""], [".md"]


def version_satisfies(current: str, required: str) -> bool:
Expand Down
6 changes: 3 additions & 3 deletions src/specify_cli/presets.py
Original file line number Diff line number Diff line change
Expand Up @@ -1694,8 +1694,8 @@ def list_available(
seen: set[str] = set()
results: List[Dict[str, str]] = []

# Determine file extension and subdirectory mapping
ext = ".sh" if template_type == "script" else ".md"
# Determine file extensions and subdirectory mapping
exts = [".sh", ".ps1"] if template_type == "script" else [".md"]
if template_type == "template":
subdirs = ["templates", ""]
elif template_type == "command":
Expand Down Expand Up @@ -1723,7 +1723,7 @@ def _collect(directory: Path, source: str):
if not directory.is_dir():
return
for f in sorted(directory.iterdir()):
if f.is_file() and f.suffix == ext:
if f.is_file() and f.suffix in exts:
name = f.stem
if name in seen:
continue
Expand Down
Loading