rishabh-mondal
Initial HF Space
423d6f2
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import io
import os
from typing import Tuple, List
import gradio as gr
import pandas as pd
import numpy as np
from sklearn.neighbors import BallTree
import folium
from folium.plugins import MarkerCluster, HeatMap
import shapely.wkt
import matplotlib.pyplot as plt
# ------------------------------
# Utilities
# ------------------------------
# Earth radius in kilometres; haversine results (central angles in radians)
# are multiplied by this to obtain distances in km.
EARTH_RADIUS_KM = 6371.0088
def _to_radians(latlon: np.ndarray) -> np.ndarray:
    """Return `latlon` (angles in degrees) converted to radians.

    The array is cast to float before conversion; shape and column order
    are left untouched.
    """
    degrees = latlon.astype(float)
    return np.deg2rad(degrees)
def _balltree_haversine_min_km(a_latlon_deg: np.ndarray, b_latlon_deg: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    """
    Nearest-neighbour great-circle distance from every point in A to set B.

    Args:
        a_latlon_deg: (n, 2+) array of query points; column 0 is latitude and
            column 1 is longitude, both in degrees.
        b_latlon_deg: (m, 2+) array of candidate points in the same layout.

    Returns:
        (min_distance_km, index_in_B): two 1-D arrays of length n. When either
        input is empty, two empty arrays are returned instead.
    """
    if len(a_latlon_deg) == 0 or len(b_latlon_deg) == 0:
        return np.array([]), np.array([], dtype=int)
    # scikit-learn's haversine metric expects [latitude, longitude] in radians,
    # so the (lat, lon) column order must be preserved. Swapping the columns
    # yields wrong distances everywhere except near the equator.
    a_rad = _to_radians(np.asarray(a_latlon_deg)[:, :2])
    b_rad = _to_radians(np.asarray(b_latlon_deg)[:, :2])
    tree = BallTree(b_rad, metric="haversine")
    dist_rad, idx = tree.query(a_rad, k=1)
    # Haversine returns the central angle; scale by the Earth radius for km.
    dist_km = dist_rad.ravel() * EARTH_RADIUS_KM
    return dist_km, idx.ravel()
def _lines_to_vertices_df(lines_like: pd.DataFrame) -> pd.DataFrame:
    """
    Flatten waterway geometries into a vertex cloud with columns (lon, lat).

    If the frame already has `lon` and `lat` columns they are returned as-is
    (rows with missing values dropped). Otherwise the `geometry` column is
    read: each entry may be a WKT string or an already-parsed geometry object
    exposing `geom_type`. LineString, MultiLineString and Point are supported;
    any other type, unparsable WKT, and empty geometries are skipped.
    Geometry x is stored as `lon` and y as `lat`; a third (z) ordinate, if
    present, is ignored.

    Returns:
        DataFrame with columns ["lon", "lat"] and a fresh 0..n-1 index; empty
        when there is neither a lon/lat pair nor a `geometry` column.
    """
    if {"lon", "lat"}.issubset(lines_like.columns):
        return lines_like[["lon", "lat"]].dropna().reset_index(drop=True)
    if "geometry" not in lines_like.columns:
        return pd.DataFrame(columns=["lon", "lat"])

    vertices = []
    for geom in lines_like["geometry"]:
        if isinstance(geom, str):
            try:
                geom = shapely.wkt.loads(geom)
            except Exception:
                # Best effort: a row with malformed WKT is skipped, not fatal.
                continue
        if geom is None or getattr(geom, "is_empty", False):
            continue
        gtype = getattr(geom, "geom_type", "")
        if gtype == "Point":
            vertices.append((geom.x, geom.y))
            continue
        if gtype == "LineString":
            parts = [geom]
        elif gtype == "MultiLineString":
            parts = geom.geoms
        else:
            continue
        for part in parts:
            # Index rather than unpack so 3D (x, y, z) coordinates do not raise.
            vertices.extend((coord[0], coord[1]) for coord in part.coords)
    return pd.DataFrame(vertices, columns=["lon", "lat"]).dropna().reset_index(drop=True)
def _ensure_cols(df: pd.DataFrame, needed: List[str], name_for_error: str):
    """Validate that `df` has every column listed in `needed`.

    Raises:
        ValueError: if one or more columns are absent; the message is
            prefixed with `name_for_error` and lists the missing columns.
    """
    absent = [col for col in needed if col not in df.columns]
    if not absent:
        return
    raise ValueError(f"{name_for_error}: missing columns {absent}. Expected at least {needed}.")
def _read_csv(file) -> pd.DataFrame:
    """
    Load a CSV from whatever shape of "file" the UI hands over.

    Supported inputs:
      * None                 -> empty DataFrame
      * str                  -> treated as a path and read directly
      * dict (older gradio)  -> first existing "path"/"name" entry is read,
                                otherwise the raw "data" bytes; empty if neither
      * object with `.name`  -> read from that path when it exists on disk
      * anything else        -> handed to pandas; any failure yields an empty
                                DataFrame
    """
    if file is None:
        return pd.DataFrame()

    if isinstance(file, str):
        return pd.read_csv(file)

    if isinstance(file, dict):
        # Prefer an on-disk path; fall back to the in-memory payload.
        for candidate in (file.get(key) for key in ("path", "name")):
            if isinstance(candidate, str) and os.path.exists(candidate):
                return pd.read_csv(candidate)
        payload = file.get("data")
        if payload is None:
            return pd.DataFrame()
        return pd.read_csv(io.BytesIO(payload))

    # File-like object that remembers where it lives on disk.
    named_path = getattr(file, "name", None)
    if isinstance(named_path, str) and os.path.exists(named_path):
        return pd.read_csv(named_path)

    # Last resort: let pandas try to interpret the object itself.
    try:
        return pd.read_csv(file)
    except Exception:
        return pd.DataFrame()
def _center_from_points(latlon: np.ndarray) -> Tuple[float, float]:
    """Return the mean (lat, lon) of an (n, 2) array.

    An empty array yields the fixed fallback centre (28.6, 77.2).
    """
    if len(latlon) > 0:
        mean_lat = latlon[:, 0].mean()
        mean_lon = latlon[:, 1].mean()
        return float(mean_lat), float(mean_lon)
    return 28.6, 77.2  # fallback (Delhi-ish)
# ------------------------------
# Core: compute compliance
# ------------------------------
def compute_compliance(
    kilns_csv,
    hospitals_csv=None,
    waterways_csv=None,
    kiln_km_thresh: float = 1.0,
    hosp_km_thresh: float = 0.8,
    water_km_thresh: float = 0.5,
    add_heatmap: bool = False,
    cluster_points: bool = True
):
    """
    Flag each kiln as compliant / non-compliant and render the result on a map.

    A kiln is compliant when, for every enabled rule, its nearest-feature
    distance is at least the threshold. A rule is enabled when its threshold
    is not None and greater than 0. Unknown distances (NaN, e.g. no hospital
    file supplied) never make a kiln non-compliant.

    Args:
        kilns_csv: anything `_read_csv` accepts; must provide `lat` and `lon`.
        hospitals_csv: optional; coordinates are read from
            `Latitude`/`Longitude` or, failing that, `lat`/`lon`.
        waterways_csv: optional; points (`lon`/`lat`) or WKT in `geometry`.
        kiln_km_thresh: minimum distance to the nearest other kiln, in km.
        hosp_km_thresh: minimum distance to the nearest hospital, in km.
        water_km_thresh: minimum distance to the nearest waterway vertex, in km.
        add_heatmap: add a kiln-density heatmap layer to the map.
        cluster_points: group markers with a MarkerCluster per layer.

    Returns:
        (map_html, summary, csv_bytes): the folium map as embeddable HTML, a
        short text summary, and the kiln table with the added columns
        `nearest_kiln_km`, `nearest_hospital_km`, `nearest_water_km` and
        `compliant`, serialised as CSV bytes.

    Raises:
        ValueError: if the kilns table lacks `lat` or `lon`.
    """
    # Load data
    kilns = _read_csv(kilns_csv)
    _ensure_cols(kilns, ["lat", "lon"], "Kilns CSV")
    hospitals = _read_csv(hospitals_csv) if hospitals_csv else pd.DataFrame()
    waterways = _read_csv(waterways_csv) if waterways_csv else pd.DataFrame()

    kiln_latlon = kilns[["lat", "lon"]].to_numpy(dtype=float)
    n_kilns = len(kilns)

    # Nearest other kiln: query k=2 because the closest hit is the kiln itself.
    if n_kilns >= 2:
        # scikit-learn's haversine metric expects [lat, lon] in radians, so the
        # column order must NOT be reversed here.
        rad = _to_radians(kiln_latlon)
        tree = BallTree(rad, metric="haversine")
        dist_rad, _ = tree.query(rad, k=2)
        nearest_km = dist_rad[:, 1] * EARTH_RADIUS_KM
    else:
        nearest_km = np.full(n_kilns, np.nan)

    # Nearest hospital: accept either naming convention for the coordinates.
    hosp_km = np.full(n_kilns, np.nan)
    if not hospitals.empty:
        for lat_col, lon_col in (("Latitude", "Longitude"), ("lat", "lon")):
            if {lat_col, lon_col}.issubset(hospitals.columns):
                hosp_latlon = hospitals[[lat_col, lon_col]].to_numpy(dtype=float)
                hosp_km, _ = _balltree_haversine_min_km(kiln_latlon, hosp_latlon)
                break

    # Nearest water: lines are reduced to their vertices first.
    water_km = np.full(n_kilns, np.nan)
    if not waterways.empty:
        water_pts = _lines_to_vertices_df(waterways)
        if len(water_pts) > 0:
            water_latlon = water_pts[["lat", "lon"]].to_numpy(dtype=float)
            water_km, _ = _balltree_haversine_min_km(kiln_latlon, water_latlon)

    # Flags: a NaN distance means "no data", which must not fail the rule.
    flags = np.ones(n_kilns, dtype=bool)
    for dist_km, thresh in (
        (nearest_km, kiln_km_thresh),
        (hosp_km, hosp_km_thresh),
        (water_km, water_km_thresh),
    ):
        if thresh is not None and thresh > 0:
            flags &= (dist_km >= thresh) | np.isnan(dist_km)

    # Output DF
    out = kilns.copy()
    out["nearest_kiln_km"] = np.round(nearest_km, 4)
    out["nearest_hospital_km"] = np.round(hosp_km, 4)
    out["nearest_water_km"] = np.round(water_km, 4)
    out["compliant"] = flags

    # Summary counts
    total = len(out)
    compliant = int(out["compliant"].sum())
    non_compliant = int((~out["compliant"]).sum())

    # Folium map
    ctr_lat, ctr_lon = _center_from_points(kiln_latlon)
    m = folium.Map(
        location=[ctr_lat, ctr_lon],
        zoom_start=6,
        control_scale=True,
        tiles="CartoDB positron"
    )
    g_compliant = folium.FeatureGroup(name="Compliant kilns", show=True)
    g_noncomp = folium.FeatureGroup(name="Non-compliant kilns", show=True)

    def _add_markers(df: pd.DataFrame, group: folium.FeatureGroup, color: str):
        """Add one circle marker per row of `df` to `group` (clustered if requested)."""
        if len(df) == 0:
            return
        target = group
        if cluster_points:
            target = MarkerCluster()
            group.add_child(target)
        rows = zip(
            df["lat"],
            df["lon"],
            df["nearest_kiln_km"],
            df["nearest_hospital_km"],
            df["nearest_water_km"],
        )
        for lat, lon, d_kiln, d_hosp, d_water in rows:
            folium.CircleMarker(
                location=[lat, lon],
                radius=4,
                color=color,
                fill=True,
                fill_opacity=0.7,
                # Tooltips are rendered as HTML, where "\n" collapses to a
                # space; <br> gives the intended one-item-per-line layout.
                tooltip=(
                    "Kiln<br>"
                    f"Nearest kiln: {d_kiln} km<br>"
                    f"Nearest hospital: {d_hosp} km<br>"
                    f"Nearest water: {d_water} km"
                ),
            ).add_to(target)

    _add_markers(out[out["compliant"]], g_compliant, color="#16a34a")  # green
    _add_markers(out[~out["compliant"]], g_noncomp, color="#dc2626")  # red
    m.add_child(g_compliant)
    m.add_child(g_noncomp)
    if add_heatmap and len(out) > 0:
        HeatMap(out[["lat", "lon"]].values.tolist(), name="Kiln density").add_to(m)
    folium.LayerControl(collapsed=False).add_to(m)
    map_html = m._repr_html_()

    # Summary text
    summary = (
        f"Total kilns: {total} | "
        f"Compliant: {compliant} | "
        f"Non-compliant: {non_compliant}\n"
        f"Rules: ≥{kiln_km_thresh} km from nearest kiln, "
        f"≥{hosp_km_thresh} km from hospital, "
        f"≥{water_km_thresh} km from water"
    )

    # Return the combined table as CSV bytes so the caller can build a static
    # plot without anything being written to disk.
    buf = io.BytesIO()
    out.to_csv(buf, index=False)
    return map_html, summary, buf.getvalue()
# ------------------------------
# Static visualization (Matplotlib)
# ------------------------------
def make_scatter_figure(csv_bytes: bytes, title: str = "Kilns: Compliant vs Non-compliant"):
    """
    Build a static lon/lat scatter plot of compliant vs non-compliant kilns.

    Args:
        csv_bytes: CSV content with at least `lat`, `lon` and `compliant`
            columns.
        title: plot title.

    Returns:
        A matplotlib Figure containing a single axes.
    """
    # Local import: the figure is created directly instead of via pyplot.
    # Figures made with plt.subplots() stay registered in pyplot's global state
    # until explicitly closed, so a long-running server would accumulate one
    # per request.
    from matplotlib.figure import Figure

    df = pd.read_csv(io.BytesIO(csv_bytes))
    fig = Figure(figsize=(6.5, 5.5))
    ax = fig.subplots()  # single plot

    comp = df[df["compliant"].eq(True)]
    nonc = df[df["compliant"].eq(False)]
    # Keep default matplotlib colors (no explicit color)
    if len(comp) > 0:
        ax.scatter(comp["lon"], comp["lat"], marker="o", label=f"Compliant (n={len(comp)})")
    if len(nonc) > 0:
        ax.scatter(nonc["lon"], nonc["lat"], marker="x", label=f"Non-compliant (n={len(nonc)})")

    ax.set_xlabel("Longitude")
    ax.set_ylabel("Latitude")
    ax.set_title(title)
    ax.grid(True)
    # Only draw a legend when something was plotted; an empty legend just
    # triggers a matplotlib warning.
    handles, _ = ax.get_legend_handles_labels()
    if handles:
        ax.legend()
    return fig
# ------------------------------
# Gradio UI
# ------------------------------
# Page layout: inputs and thresholds on the left, map / summary / plot on the right.
with gr.Blocks(title="Brick Kiln Compliance Monitor (Gradio)") as demo:
    gr.Markdown(
        "## Automatic Compliance Monitoring for Brick Kilns\n"
        "Upload CSVs, set thresholds, and visualize compliant vs non-compliant kilns on an interactive map.\n"
        "- **Kilns CSV** must include columns: `lat, lon` (WGS84).\n"
        "- Hospitals CSV can have `Latitude, Longitude` or `lat, lon`.\n"
        "- Waterways CSV may be points (`lat, lon`) or WKT LineString/MultiLineString in `geometry`."
    )
    with gr.Row():
        with gr.Column(scale=1):
            use_demo = gr.Checkbox(value=True, label="Use bundled demo data (skip uploads)")
            kilns_csv = gr.File(label="Kilns CSV (required if demo OFF)", file_types=[".csv"])
            hospitals_csv = gr.File(label="Hospitals CSV (optional)", file_types=[".csv"])
            waterways_csv = gr.File(label="Waterways CSV or WKT (optional)", file_types=[".csv"])
            gr.Markdown("### Thresholds (km)")
            kiln_thresh = gr.Number(value=1.0, label="Min distance to nearest kiln (km)")
            hosp_thresh = gr.Number(value=0.8, label="Min distance to hospital (km)")
            water_thresh = gr.Number(value=0.5, label="Min distance to water body (km)")
            add_heatmap = gr.Checkbox(value=False, label="Add heatmap layer")
            cluster_points = gr.Checkbox(value=True, label="Cluster markers for speed")
            run_btn = gr.Button("Compute & Map", variant="primary")
        with gr.Column(scale=2):
            fmap = gr.HTML(label="Interactive Map")
            summary = gr.Textbox(label="Summary", lines=3)
            scatter = gr.Plot(label="Static Visualization: Compliant vs Non-compliant")

    def _run(use_demo_flag, k, h, w, kt, ht, wt, heat, cluster):
        """Button callback: resolve inputs, run the analysis, return (map, summary, figure)."""
        if use_demo_flag:
            # Demo mode ignores uploads and reads the bundled files instead.
            k = "data/kilns_clean.csv"
            h = "data/hospitals.csv" if os.path.exists("data/hospitals.csv") else None
            w = "data/waterways_wkt.csv" if os.path.exists("data/waterways_wkt.csv") else None
            if not os.path.exists(k):
                raise gr.Error(f"Bundled demo data not found: {k}")
        elif k is None:
            raise gr.Error("Please upload a Kilns CSV or enable the bundled demo data.")

        # A cleared Number field arrives as None; map it to 0, which
        # compute_compliance treats as "rule disabled", instead of crashing
        # in float(None).
        kt, ht, wt = (0.0 if value is None else float(value) for value in (kt, ht, wt))

        try:
            map_html, summary_text, csv_bytes = compute_compliance(
                k, h, w, kt, ht, wt, bool(heat), bool(cluster)
            )
        except ValueError as exc:
            # Show validation problems (e.g. missing lat/lon columns) in the UI.
            raise gr.Error(str(exc)) from exc
        fig = make_scatter_figure(csv_bytes)
        return map_html, summary_text, fig

    run_btn.click(
        _run,
        inputs=[use_demo, kilns_csv, hospitals_csv, waterways_csv,
                kiln_thresh, hosp_thresh, water_thresh, add_heatmap, cluster_points],
        outputs=[fmap, summary, scatter],
    )
# Launch the Gradio app only when this file is executed as a script.
if __name__ == "__main__":
    # Change port if needed: demo.launch(server_port=7861)
    demo.launch()